package Facemorph.psm;

import Facemorph.PCA;
import java.io.DataInputStream;
import java.io.FileInputStream;
import java.io.FileOutputStream;
import java.io.InputStreamReader;
import java.io.PrintStream;
import java.io.StreamTokenizer;
import java.util.ArrayList;
import java.util.Random;
import java.util.Vector;

/* loaded from: input_file:Facemorph/psm/DiagonalGMM.class */
/**
 * Gaussian mixture model whose components ("clusters") are {@link DiagonalGaussian}s, each paired
 * with a mixing weight stored at the same index of {@link #weights}.
 *
 * <p>The class also provides static helpers for reading, writing and transforming sequences of
 * {@code double[]} feature vectors held in raw {@link Vector}s.
 */
public class DiagonalGMM {
    /** Number of EM iterations performed by both {@code EM} overloads. */
    private static final int EM_ITERATIONS = 11;

    /** Divisor applied to the global variance before initial cluster means are drawn from it. */
    private static final double INIT_VARIANCE_DIVISOR = 1000.0d;

    /** Mixture components; index k pairs with {@code weights.get(k)}. */
    ArrayList<DiagonalGaussian> clusters;

    /** Mixing weight of each cluster; {@link #normalise()} rescales them to sum to 1. */
    ArrayList<Double> weights;

    /**
     * Creates a mixture with {@code count} placeholder clusters, each a
     * {@code new DiagonalGaussian(0)} with weight 0. Populate them with {@link #set},
     * {@link #random} or one of the {@code EM} methods.
     *
     * @param count number of clusters
     */
    public DiagonalGMM(int count) {
        this.weights = new ArrayList<>(count);
        this.clusters = new ArrayList<>(count);
        for (int k = 0; k < count; k++) {
            this.weights.add(Double.valueOf(0.0d));
            this.clusters.add(new DiagonalGaussian(0));
        }
    }

    /** Creates an empty mixture with no clusters; add them with {@link #add}. */
    public DiagonalGMM() {
        this.weights = new ArrayList<>();
        this.clusters = new ArrayList<>();
    }

    /**
     * Replaces cluster {@code index} and its weight.
     *
     * @param index existing cluster index
     * @param cluster new component
     * @param weight new mixing weight (not renormalised)
     */
    public void set(int index, DiagonalGaussian cluster, double weight) {
        this.clusters.set(index, cluster);
        this.weights.set(index, Double.valueOf(weight));
    }

    /**
     * Returns cluster {@code index}, or {@code null} when the index is out of range.
     *
     * @param index cluster index
     * @return the cluster, or {@code null} if {@code index < 0} or {@code index >= getCount()}
     */
    public DiagonalGaussian getCluster(int index) {
        if (index < 0 || index >= this.clusters.size()) {
            return null;
        }
        return this.clusters.get(index);
    }

    /**
     * Returns the weight of cluster {@code index}, or 0 when the index is out of range.
     *
     * @param index cluster index
     * @return the mixing weight, or {@code 0.0} if {@code index < 0} or beyond the last weight
     */
    public double getWeight(int index) {
        if (index < 0 || index >= this.weights.size()) {
            return 0.0d;
        }
        return this.weights.get(index).doubleValue();
    }

    /** @return number of clusters in the mixture */
    public int getCount() {
        return this.clusters.size();
    }

    /** Rescales this mixture's weights so they sum to 1 (see {@link #normalise(ArrayList)}). */
    public void normalise() {
        normalise(this.weights);
    }

    /**
     * Divides every element by the sum of all elements, in place.
     *
     * <p>If the sum is exactly zero the list is left unchanged rather than being filled with NaN.
     *
     * @param arrayList values to normalise
     */
    public static void normalise(ArrayList<Double> arrayList) {
        double total = 0.0d;
        for (Double value : arrayList) {
            total += value.doubleValue();
        }
        if (total == 0.0d) {
            // 0/0 would poison every weight with NaN; leave the list as it is instead.
            return;
        }
        for (int k = 0; k < arrayList.size(); k++) {
            arrayList.set(k, Double.valueOf(arrayList.get(k).doubleValue() / total));
        }
    }

    /**
     * Writes the mixture in text form to the file {@code str}. Errors are printed to stdout,
     * not thrown.
     *
     * @param str output file path
     */
    public void write(String str) {
        try {
            write(new PrintStream(new FileOutputStream(str)));
        } catch (Exception e) {
            System.out.println(e);
        }
    }

    /**
     * Writes the mixture in the text format read by {@link #read(StreamTokenizer)}.
     *
     * <p>Note: this method <b>closes</b> {@code printStream} when it finishes.
     *
     * @param printStream destination; closed on return
     */
    public void write(PrintStream printStream) {
        printStream.println("Clusters " + this.clusters.size());
        for (int k = 0; k < this.clusters.size(); k++) {
            printStream.println("Cluster " + k);
            printStream.println("Weight " + this.weights.get(k));
            printStream.println("DiagonalGaussian");
            this.clusters.get(k).write(printStream);
        }
        printStream.close();
    }

    /**
     * Replaces this mixture with one read from the file {@code str}. Errors are printed to
     * stdout, not thrown. The file is always closed.
     *
     * @param str input file path
     */
    public void read(String str) {
        try (DataInputStream in = new DataInputStream(new FileInputStream(str))) {
            read(in);
        } catch (Exception e) {
            System.out.println(e);
        }
    }

    /**
     * Replaces this mixture with one parsed from {@code dataInputStream} (the stream is not
     * closed here).
     *
     * @param dataInputStream text source in the format produced by {@link #write(PrintStream)}
     */
    public void read(DataInputStream dataInputStream) {
        StreamTokenizer tokenizer = new StreamTokenizer(new InputStreamReader(dataInputStream));
        tokenizer.parseNumbers();
        read(tokenizer);
    }

    /**
     * Replaces this mixture with one parsed from {@code streamTokenizer}.
     *
     * <p>Expected tokens: {@code Clusters N}, then N times {@code Cluster k Weight w
     * DiagonalGaussian <gaussian>}. Parse errors are printed to stdout, and the mixture may then be
     * only partially filled.
     *
     * @param streamTokenizer token source
     */
    public void read(StreamTokenizer streamTokenizer) {
        try {
            streamTokenizer.nextToken(); // "Clusters"
            streamTokenizer.nextToken(); // cluster count
            int count = (int) streamTokenizer.nval;
            this.clusters = new ArrayList<>(count);
            this.weights = new ArrayList<>(count);
            for (int k = 0; k < count; k++) {
                DiagonalGaussian cluster = new DiagonalGaussian(0);
                streamTokenizer.nextToken(); // "Cluster"
                streamTokenizer.nextToken(); // cluster index
                streamTokenizer.nextToken(); // "Weight"
                // NOTE(review): PCA.readDouble appears to advance the tokenizer itself.
                this.weights.add(Double.valueOf(PCA.readDouble(streamTokenizer)));
                streamTokenizer.nextToken(); // "DiagonalGaussian"
                cluster.read(streamTokenizer);
                this.clusters.add(cluster);
            }
        } catch (Exception e) {
            System.out.println(e);
        }
    }

    /**
     * Randomly re-initialises every existing cluster.
     *
     * <p>Each cluster gets a random weight from {@code random.nextDouble()} and a mean taken from
     * {@code diagonalGaussian.getRandomSample(random)}. Its variance is 10000 times
     * {@code diagonalGaussian}'s variance. Weights are normalised afterwards.
     *
     * @param random source of randomness
     * @param diagonalGaussian distribution the initial means and variances are derived from
     */
    public void random(Random random, DiagonalGaussian diagonalGaussian) {
        for (int k = 0; k < this.clusters.size(); k++) {
            this.weights.set(k, Double.valueOf(random.nextDouble()));
            double[] mean = diagonalGaussian.getRandomSample(random);
            double[] variance = new double[mean.length];
            for (int d = 0; d < mean.length; d++) {
                variance[d] = 10000.0d * diagonalGaussian.variance[d];
            }
            this.clusters.set(k, new DiagonalGaussian(mean, variance));
        }
        normalise();
    }

    /**
     * Fits the existing clusters to {@code vector} with expectation-maximisation.
     *
     * <p>The number of clusters must already be set, for example with {@link #DiagonalGMM(int)}.
     * Initialisation proceeds as follows:
     * <ul>
     *   <li>fit one {@link DiagonalGaussian} to all samples;</li>
     *   <li>divide its variance by 1000;</li>
     *   <li>seed the clusters via {@link #random} using an unseeded {@link Random}, so results
     *       differ between runs.</li>
     * </ul>
     * A fixed {@value #EM_ITERATIONS} EM iterations follow.
     *
     * @param vector non-empty raw vector of {@code double[]} samples; the dimension is taken from
     *     the first sample
     * @return raw vector holding, per sample, a {@code double[]} of per-cluster responsibilities
     *     from the final E-step
     */
    public Vector EM(Vector vector) {
        Vector responsibilities = newResponsibilities(vector);
        DiagonalGaussian global = new DiagonalGaussian(((double[]) vector.elementAt(0)).length);
        global.build(vector);
        initialiseFrom(global);
        for (int iteration = 0; iteration < EM_ITERATIONS; iteration++) {
            zeroSampleWeights(responsibilities);
            weightSamples(vector, responsibilities);
            normaliseSampleWeights(responsibilities);
            reestimate(vector, responsibilities);
            normalise();
        }
        return responsibilities;
    }

    /**
     * Weighted variant of {@link #EM(Vector)}. Each sample's responsibilities are normalised to
     * sum to that sample's weight instead of to 1.
     *
     * @param vector non-empty raw vector of {@code double[]} samples
     * @param vector2 one weight per sample, index-aligned with {@code vector}
     * @return raw vector holding, per sample, a {@code double[]} of per-cluster weighted
     *     responsibilities from the final E-step
     */
    public Vector EM(Vector vector, Vector<Double> vector2) {
        Vector responsibilities = newResponsibilities(vector);
        DiagonalGaussian global = new DiagonalGaussian(((double[]) vector.elementAt(0)).length);
        global.build(vector, vector2);
        initialiseFrom(global);
        for (int iteration = 0; iteration < EM_ITERATIONS; iteration++) {
            zeroSampleWeights(responsibilities);
            weightSamples(vector, responsibilities);
            normaliseSampleWeights(responsibilities, vector2);
            reestimate(vector, responsibilities);
            normalise();
        }
        return responsibilities;
    }

    /** Allocates one zeroed {@code double[getCount()]} responsibility row per sample. */
    private Vector newResponsibilities(Vector samples) {
        Vector rows = new Vector(samples.size());
        for (int s = 0; s < samples.size(); s++) {
            rows.add(new double[this.clusters.size()]);
        }
        return rows;
    }

    /** Shrinks {@code global}'s variance in place and seeds the clusters from it. */
    private void initialiseFrom(DiagonalGaussian global) {
        double[] variance = global.variance;
        for (int d = 0; d < variance.length; d++) {
            variance[d] = variance[d] / INIT_VARIANCE_DIVISOR;
        }
        random(new Random(), global);
    }

    /**
     * Sets every entry of every {@code double[]} row in {@code vector} to zero.
     *
     * @param vector raw vector of {@code double[]} rows
     */
    public static void zeroSampleWeights(Vector vector) {
        for (int s = 0; s < vector.size(); s++) {
            double[] row = (double[]) vector.elementAt(s);
            for (int k = 0; k < row.length; k++) {
                row[k] = 0.0d;
            }
        }
    }

    /**
     * Scales each {@code double[]} row so it sums to 1. Rows whose sum is not positive are left
     * unchanged.
     *
     * @param vector raw vector of {@code double[]} rows
     */
    public static void normaliseSampleWeights(Vector vector) {
        for (int s = 0; s < vector.size(); s++) {
            double[] row = (double[]) vector.elementAt(s);
            double total = 0.0d;
            for (double value : row) {
                total += value;
            }
            if (total > 0.0d) {
                for (int k = 0; k < row.length; k++) {
                    row[k] = row[k] / total;
                }
            }
        }
    }

    /**
     * Scales row {@code s} so it sums to {@code vector2.get(s)}. Rows whose sum is not positive
     * are left unchanged.
     *
     * @param vector raw vector of {@code double[]} rows
     * @param vector2 target sum for each row, index-aligned with {@code vector}
     */
    public static void normaliseSampleWeights(Vector vector, Vector<Double> vector2) {
        for (int s = 0; s < vector.size(); s++) {
            double[] row = (double[]) vector.elementAt(s);
            double total = 0.0d;
            for (double value : row) {
                total += value;
            }
            double target = vector2.get(s).doubleValue();
            if (total > 0.0d) {
                for (int k = 0; k < row.length; k++) {
                    row[k] = row[k] * (target / total);
                }
            }
        }
    }

    /**
     * Hard assignment: each row is replaced by a one-hot vector at its largest entry. The first
     * maximum wins ties, and empty rows are skipped.
     *
     * @param vector raw vector of {@code double[]} rows
     */
    public static void normaliseSampleWeights2(Vector vector) {
        for (int s = 0; s < vector.size(); s++) {
            double[] row = (double[]) vector.elementAt(s);
            if (row.length == 0) {
                continue;
            }
            double best = row[0];
            int bestIndex = 0;
            for (int k = 1; k < row.length; k++) {
                if (row[k] > best) {
                    best = row[k];
                    bestIndex = k;
                }
            }
            for (int k = 0; k < row.length; k++) {
                row[k] = 0.0d;
            }
            row[bestIndex] = 1.0d;
        }
    }

    /**
     * E-step numerator. For each sample s and cluster k it sets
     * {@code vector2[s][k] = weight_k * cluster_k.probability(sample_s)}.
     *
     * @param vector raw vector of {@code double[]} samples
     * @param vector2 raw vector of {@code double[getCount()]} rows to fill, index-aligned
     */
    public void weightSamples(Vector vector, Vector vector2) {
        for (int s = 0; s < vector.size(); s++) {
            double[] sample = (double[]) vector.elementAt(s);
            double[] row = (double[]) vector2.elementAt(s);
            for (int k = 0; k < this.clusters.size(); k++) {
                row[k] = this.weights.get(k).doubleValue() * this.clusters.get(k).probability(sample);
            }
        }
    }

    /**
     * M-step. Each cluster k is rebuilt from the samples and column k of the responsibilities.
     * The value returned by {@code DiagonalGaussian.build} becomes its new (unnormalised) weight.
     *
     * @param vector raw vector of {@code double[]} samples
     * @param vector2 raw vector of {@code double[]} responsibility rows
     */
    public void reestimate(Vector vector, Vector vector2) {
        for (int k = 0; k < this.clusters.size(); k++) {
            this.weights.set(k, Double.valueOf(this.clusters.get(k).build(vector, vector2, k)));
        }
    }

    /**
     * Returns the indices of {@code dArr} ordered by descending value.
     *
     * <p>Side effect: {@code dArr} itself is sorted into descending order in place.
     *
     * @param dArr values to rank; reordered by this call
     * @return original indices, largest value first
     */
    public static int[] rank(double[] dArr) {
        int[] indices = new int[dArr.length];
        for (int k = 0; k < dArr.length; k++) {
            indices[k] = k;
        }
        rank(dArr, indices, 0, dArr.length);
        return indices;
    }

    /**
     * Quicksort (first element as pivot) of {@code values[from, to)} into descending order. The
     * same moves are applied to {@code indices}. Recursion depth can grow linearly for
     * already-sorted input.
     *
     * @param values values to sort in place
     * @param indices companion array permuted alongside {@code values}
     * @param from inclusive start
     * @param to exclusive end
     */
    public static void rank(double[] values, int[] indices, int from, int to) {
        if (to - from <= 1) {
            return;
        }
        double pivot = values[from];
        int pivotIndex = indices[from];
        int pivotPos = from;
        for (int j = from + 1; j < to; j++) {
            if (values[j] > pivot) {
                // Larger element takes the pivot's slot, the element just after the pivot moves
                // out to j, and the pivot advances one place.
                values[pivotPos] = values[j];
                indices[pivotPos] = indices[j];
                values[j] = values[pivotPos + 1];
                indices[j] = indices[pivotPos + 1];
                values[pivotPos + 1] = pivot;
                indices[pivotPos + 1] = pivotIndex;
                pivotPos++;
            }
        }
        rank(values, indices, from, pivotPos);
        rank(values, indices, pivotPos + 1, to);
    }

    /**
     * Draws a sample from the mixture.
     *
     * <p>A cluster is picked with probability proportional to the weights, which are assumed to
     * be normalised; if the cumulative weight runs out, the last cluster is used. The sample is
     * then drawn from that cluster.
     *
     * @param random source of randomness
     * @return a sample from the chosen cluster
     */
    public double[] getRandomSample(Random random) {
        double remaining = random.nextDouble();
        int last = this.clusters.size() - 1;
        int chosen = 0;
        while (chosen < last) {
            remaining -= this.weights.get(chosen).doubleValue();
            if (remaining <= 0.0d) {
                break;
            }
            chosen++;
        }
        return this.clusters.get(chosen).getRandomSample(random);
    }

    /**
     * @param dArr point to evaluate
     * @return weighted sum of every cluster's {@code probability(dArr)}
     */
    public double probability(double[] dArr) {
        double total = 0.0d;
        for (int k = 0; k < this.clusters.size(); k++) {
            total += this.weights.get(k).doubleValue() * this.clusters.get(k).probability(dArr);
        }
        return total;
    }

    /**
     * @param dArr point to evaluate
     * @return each cluster's {@code probability(dArr)}, <b>without</b> mixing weights
     */
    public double[] probabilities(double[] dArr) {
        double[] result = new double[this.clusters.size()];
        for (int k = 0; k < this.clusters.size(); k++) {
            result[k] = this.clusters.get(k).probability(dArr);
        }
        return result;
    }

    /**
     * Evaluates the mixture probability at the mean of the highest-weighted cluster (first one on
     * ties). Requires at least one cluster.
     *
     * @return {@link #probability(double[])} at that mean
     */
    public double maxProbabilty() {
        double bestWeight = this.weights.get(0).doubleValue();
        int best = 0;
        for (int k = 1; k < this.weights.size(); k++) {
            if (this.weights.get(k).doubleValue() > bestWeight) {
                bestWeight = this.weights.get(k).doubleValue();
                best = k;
            }
        }
        return probability(this.clusters.get(best).mean);
    }

    /**
     * @return for each cluster, the mixture {@link #probability(double[])} evaluated at that
     *     cluster's mean
     */
    public double[] maxProbabilties() {
        double[] result = new double[this.clusters.size()];
        for (int k = 0; k < this.weights.size(); k++) {
            result[k] = probability(this.clusters.get(k).mean);
        }
        return result;
    }

    /**
     * Reads vectors from the file {@code str} (see {@link #readVectors(StreamTokenizer)}). The
     * file is always closed.
     *
     * @param str input file path
     * @return the vectors, or {@code null} on error (the error is printed to stdout)
     */
    public static Vector readVectors(String str) {
        try (DataInputStream in = new DataInputStream(new FileInputStream(str))) {
            return readVectors(in);
        } catch (Exception e) {
            System.out.println(e);
            return null;
        }
    }

    /**
     * @param dataInputStream text source (not closed here)
     * @return the vectors, or {@code null} on parse error
     */
    public static Vector readVectors(DataInputStream dataInputStream) {
        StreamTokenizer tokenizer = new StreamTokenizer(new InputStreamReader(dataInputStream));
        tokenizer.parseNumbers();
        return readVectors(tokenizer);
    }

    /**
     * Parses {@code count dimension} followed by {@code count * dimension} numbers, as written by
     * {@link #writeVectors(Vector, PrintStream)}.
     *
     * @param streamTokenizer token source
     * @return raw vector of {@code double[dimension]}, or {@code null} on error (printed to stdout)
     */
    public static Vector readVectors(StreamTokenizer streamTokenizer) {
        Vector result = new Vector();
        try {
            streamTokenizer.nextToken();
            int count = (int) streamTokenizer.nval;
            streamTokenizer.nextToken();
            int dimension = (int) streamTokenizer.nval;
            for (int s = 0; s < count; s++) {
                double[] row = new double[dimension];
                for (int d = 0; d < dimension; d++) {
                    row[d] = PCA.readDouble(streamTokenizer);
                }
                result.add(row);
            }
            return result;
        } catch (Exception e) {
            System.out.println(e);
            return null;
        }
    }

    /**
     * Writes {@code vector} to the file {@code str}. The file is closed even if writing fails.
     * Errors are printed to stdout.
     *
     * @param vector raw vector of {@code double[]}
     * @param str output file path
     */
    public static void writeVectors(Vector vector, String str) {
        try (PrintStream printStream = new PrintStream(new FileOutputStream(str))) {
            writeVectors(vector, printStream);
        } catch (Exception e) {
            System.out.println(e);
        }
    }

    /**
     * Writes the vector count, the dimension (taken from the first vector, or 0 when empty), then
     * one space-separated line per vector. The stream is not closed.
     *
     * @param vector raw vector of {@code double[]}
     * @param printStream destination
     */
    public static void writeVectors(Vector vector, PrintStream printStream) {
        printStream.println("" + vector.size());
        int dimension = vector.isEmpty() ? 0 : ((double[]) vector.elementAt(0)).length;
        printStream.println("" + dimension);
        for (int s = 0; s < vector.size(); s++) {
            for (double value : (double[]) vector.elementAt(s)) {
                printStream.print(value + " ");
            }
            printStream.println("");
        }
    }

    /**
     * Reads MFCC frames from the file {@code str} (see {@link #readMFCC(StreamTokenizer)}). The
     * file is always closed.
     *
     * @param str input file path
     * @return the frames, or {@code null} on error (printed to stdout)
     */
    public static Vector readMFCC(String str) {
        try (DataInputStream in = new DataInputStream(new FileInputStream(str))) {
            return readMFCC(in);
        } catch (Exception e) {
            System.out.println(e);
            return null;
        }
    }

    /**
     * @param dataInputStream text source (not closed here)
     * @return the frames, or {@code null} on parse error
     */
    public static Vector readMFCC(DataInputStream dataInputStream) {
        StreamTokenizer tokenizer = new StreamTokenizer(new InputStreamReader(dataInputStream));
        tokenizer.parseNumbers();
        return readMFCC(tokenizer);
    }

    /**
     * Parses a frame count, then for each frame a number {@code n} followed by {@code n + 1}
     * values. Prints the count to stdout.
     *
     * @param streamTokenizer token source
     * @return raw vector of {@code double[n + 1]} frames, or {@code null} on error
     */
    public static Vector readMFCC(StreamTokenizer streamTokenizer) {
        Vector result = new Vector();
        try {
            streamTokenizer.nextToken();
            int count = (int) streamTokenizer.nval;
            System.out.println("count = " + count);
            for (int f = 0; f < count; f++) {
                streamTokenizer.nextToken();
                int length = ((int) streamTokenizer.nval) + 1;
                double[] frame = new double[length];
                for (int d = 0; d < length; d++) {
                    frame[d] = PCA.readDouble(streamTokenizer);
                }
                result.add(frame);
            }
            return result;
        } catch (Exception e) {
            System.out.println(e);
            return null;
        }
    }

    /**
     * @param dArr first part
     * @param dArr2 second part
     * @return a new array holding {@code dArr} followed by {@code dArr2}
     */
    public static double[] append(double[] dArr, double[] dArr2) {
        double[] result = new double[dArr.length + dArr2.length];
        System.arraycopy(dArr, 0, result, 0, dArr.length);
        System.arraycopy(dArr2, 0, result, dArr.length, dArr2.length);
        return result;
    }

    /**
     * Concatenates corresponding elements of two equally sized vectors of {@code double[]}.
     *
     * @param vector first vectors
     * @param vector2 second vectors
     * @return raw vector of concatenated arrays, or {@code null} if the sizes differ
     */
    public static Vector appendVectors(Vector vector, Vector vector2) {
        if (vector.size() != vector2.size()) {
            return null;
        }
        Vector result = new Vector();
        for (int s = 0; s < vector.size(); s++) {
            result.add(append((double[]) vector.elementAt(s), (double[]) vector2.elementAt(s)));
        }
        return result;
    }

    /**
     * Resamples a sequence of frames to {@code i} frames.
     *
     * <p>Output frame {@code j} sits at input position {@code j * vector.size() / i}. It is
     * linearly interpolated between the two neighbouring input frames; the last frame is reused
     * at the end.
     *
     * @param vector raw vector of equal-length {@code double[]} frames
     * @param i number of output frames
     * @return raw vector of {@code i} interpolated frames (empty if {@code vector} is empty)
     */
    public static Vector resampleVectors(Vector vector, int i) {
        Vector result = new Vector();
        if (vector.isEmpty()) {
            return result;
        }
        int inputCount = vector.size();
        for (int j = 0; j < i; j++) {
            // Floating-point position: integer division here would zero the interpolation
            // fraction (and could overflow for long sequences).
            double position = ((double) j * inputCount) / i;
            int lower = (int) position;
            int upper = lower < inputCount - 1 ? lower + 1 : lower;
            double fraction = position - lower;
            double[] a = (double[]) vector.elementAt(lower);
            double[] b = (double[]) vector.elementAt(upper);
            double[] frame = new double[a.length];
            for (int d = 0; d < frame.length; d++) {
                frame[d] = ((1.0d - fraction) * a[d]) + (fraction * b[d]);
            }
            result.add(frame);
        }
        return result;
    }

    /**
     * Pairs each frame with its successor: output frame s is {@code frame[s]} followed by
     * {@code frame[s + 1]}. The last frame is paired with itself.
     *
     * @param vector raw vector of equal-length {@code double[]} frames
     * @return raw vector of double-length frames (empty if {@code vector} is empty)
     */
    public static Vector createDynamics(Vector vector) {
        Vector result = new Vector(vector.size());
        if (vector.isEmpty()) {
            return result;
        }
        for (int s = 0; s < vector.size() - 1; s++) {
            double[] current = (double[]) vector.elementAt(s);
            double[] next = (double[]) vector.elementAt(s + 1);
            double[] frame = new double[current.length * 2];
            for (int d = 0; d < current.length; d++) {
                frame[d] = current[d];
                frame[d + current.length] = next[d];
            }
            result.add(frame);
        }
        double[] lastFrame = (double[]) vector.elementAt(vector.size() - 1);
        double[] tail = new double[lastFrame.length * 2];
        for (int d = 0; d < lastFrame.length; d++) {
            tail[d] = lastFrame[d];
            tail[d + lastFrame.length] = lastFrame[d];
        }
        result.add(tail);
        return result;
    }

    /**
     * Forward differences: output frame s is {@code frame[s + 1] - frame[s]}. The last output
     * frame is all zeros.
     *
     * @param vector raw vector of equal-length {@code double[]} frames
     * @return raw vector of difference frames (empty if {@code vector} is empty)
     */
    public static Vector createDynamics2(Vector vector) {
        Vector result = new Vector(vector.size());
        if (vector.isEmpty()) {
            return result;
        }
        for (int s = 0; s < vector.size() - 1; s++) {
            double[] current = (double[]) vector.elementAt(s);
            double[] next = (double[]) vector.elementAt(s + 1);
            double[] delta = new double[current.length];
            for (int d = 0; d < current.length; d++) {
                delta[d] = next[d] - current[d];
            }
            result.add(delta);
        }
        double[] lastFrame = (double[]) vector.elementAt(vector.size() - 1);
        // New Java arrays are zero-filled, which is exactly the trailing frame we want.
        result.add(new double[lastFrame.length]);
        return result;
    }

    /**
     * Sums {@link #probability(double[])} over all non-null samples. {@code null} entries are
     * reported on stdout and skipped.
     *
     * @param vector raw vector of {@code double[]} samples
     * @return the summed mixture probability
     */
    public double prob_sample_given_GMM(Vector vector) {
        double total = 0.0d;
        for (int s = 0; s < vector.size(); s++) {
            double[] sample = (double[]) vector.elementAt(s);
            if (sample != null) {
                total += probability(sample);
            } else {
                System.out.println("Null sample " + s);
            }
        }
        return total;
    }

    /**
     * Makes this mixture a copy of {@code diagonalGMM}.
     *
     * <p>The cluster and weight lists are copied, so later {@code set}/{@code add} calls on
     * either model do not affect the other. The {@link DiagonalGaussian} objects themselves are
     * still shared.
     *
     * @param diagonalGMM mixture to copy from
     */
    public void copy(DiagonalGMM diagonalGMM) {
        this.clusters = new ArrayList<>(diagonalGMM.clusters);
        this.weights = new ArrayList<>(diagonalGMM.weights);
    }

    /**
     * Appends a cluster with weight 1 (call {@link #normalise()} afterwards if needed).
     *
     * @param diagonalGaussian component to add
     */
    public void add(DiagonalGaussian diagonalGaussian) {
        this.clusters.add(diagonalGaussian);
        this.weights.add(Double.valueOf(1.0d));
    }

    /** Intentionally empty entry point. */
    public static void main(String[] strArr) {
    }
}
