package mgjpomdp.solve;

import com.jgoodies.forms.layout.FormSpec;
import gnu.trove.iterator.TIntIterator;
import gnu.trove.set.hash.TIntHashSet;
import java.util.Arrays;
import java.util.HashSet;
import java.util.Set;
import mgjcommon.Pair;
import mgjcommon.PairDoubleObj;
import mgjpomdp.common.AlphaVector;
import mgjpomdp.common.MDPUtils;
import mgjpomdp.common.POMDPFlatMTJ;
import no.uib.cipr.matrix.sparse.SparseVector;

/* loaded from: input_file:mgjpomdp/solve/pbvi.class */
/**
 * Point-based value iteration (PBVI) over a fixed, finite set of belief points.
 *
 * <p>Each sweep performs a point-based Bellman backup at every belief point and keeps at most one
 * alpha vector per belief slot. Sweeps repeat until the largest change in value at any belief
 * point is no greater than the caller's tolerance.
 */
public class pbvi {

    /**
     * Runs PBVI sweeps over {@code beliefs} until the value function converges.
     *
     * @param pomdp the flat POMDP model; its action count, observation count, state count,
     *     discount, transition-observation rows and rewards are read
     * @param beliefs the belief points at which backups are performed; one alpha-vector slot is
     *     kept per belief
     * @param initialAlphas initial alpha vectors, cloned into the first
     *     {@code initialAlphas.length} slots (the remaining slots start empty)
     * @param tolerance iteration stops once the largest absolute change in value over all belief
     *     points is not greater than this
     * @return a pair of (alpha vectors of the final value function, belief points whose slot was
     *     left empty in the final sweep because a vector already placed during that sweep scored
     *     at least as high there)
     * @throws Exception if there are fewer belief points than initial alpha vectors
     */
    public static Pair<Set<AlphaVector>, Set<SparseVector>> solve(
            POMDPFlatMTJ pomdp, SparseVector[] beliefs, double[][] initialAlphas, double tolerance)
            throws Exception {
        final int numBeliefs = beliefs.length;
        if (numBeliefs < initialAlphas.length) {
            throw new Exception("ERROR: to make this work, the number of beliefs cannot be smaller than the number of alpha vectors");
        }

        // One alpha-vector slot per belief point; a null entry marks an empty slot.
        double[][] current = new double[numBeliefs][];
        for (int i = 0; i < initialAlphas.length; i++) {
            current[i] = initialAlphas[i].clone();
        }
        double[][] next = new double[numBeliefs][];

        // Scratch buffers reused by every backup to avoid per-belief allocation.
        double[] actionVector = new double[pomdp._numS];
        double[] projected = new double[pomdp._numS];

        TIntHashSet coveredSlots = new TIntHashSet();
        double maxDelta;
        do {
            coveredSlots.clear();
            Arrays.fill(next, null);
            maxDelta = 0.0;

            for (int b = 0; b < numBeliefs; b++) {
                SparseVector belief = beliefs[b];
                PairDoubleObj<double[]> backedUp = backupAt(pomdp, belief, current, actionVector, projected);
                PairDoubleObj<double[]> previous = evaluateAndBestAlpha(belief, current);
                // Best value at this belief among the vectors already placed during this sweep.
                double alreadyCovered = evaluateBeliefPoint(belief, next);

                double newValue;
                if (backedUp.first > alreadyCovered && backedUp.first > previous.first) {
                    // The fresh backup improves on both the previous and the in-progress sets.
                    next[b] = backedUp.second;
                    newValue = backedUp.first;
                } else if (previous.first > alreadyCovered) {
                    // Keep the best vector from the previous sweep for this belief.
                    next[b] = previous.second;
                    newValue = previous.first;
                } else {
                    // A vector placed earlier in this sweep already dominates here: leave slot empty.
                    coveredSlots.add(b);
                    newValue = alreadyCovered;
                }

                // Deliberately a plain '>' so a NaN delta is ignored, as before.
                double delta = Math.abs(previous.first - newValue);
                if (delta > maxDelta) {
                    maxDelta = delta;
                }
            }

            // Swap buffers: this sweep's result becomes the input of the next sweep.
            double[][] swap = current;
            current = next;
            next = swap;
        } while (maxDelta > tolerance);

        Set<AlphaVector> alphaVectors = new HashSet<>();
        for (double[] alpha : current) {
            if (alpha != null) {
                alphaVectors.add(new AlphaVector(alpha));
            }
        }

        Set<SparseVector> coveredBeliefs = new HashSet<>();
        TIntIterator it = coveredSlots.iterator();
        while (it.hasNext()) {
            coveredBeliefs.add(beliefs[it.next()]);
        }
        return new Pair<>(alphaVectors, coveredBeliefs);
    }

    /**
     * Computes the point-based Bellman backup at a single belief point.
     *
     * <p>For every action: for every observation, it projects each non-empty alpha vector through
     * {@code pomdp._TORow[action][observation]} scaled by the discount and keeps the projection
     * with the highest value at {@code belief}. It sums those per-observation winners and adds the
     * action's reward vector. The action vector with the highest value at {@code belief} wins.
     *
     * @param pomdp the flat POMDP model
     * @param belief the belief point to back up
     * @param alphas the current alpha-vector slots (null entries are skipped)
     * @param actionVector scratch buffer of length {@code _numS}; overwritten
     * @param projected scratch buffer of length {@code _numS}; overwritten
     * @return (value at {@code belief}, a fresh copy of the best backed-up vector). The value is
     *     negative infinity and the vector null when the model has no actions.
     */
    private static PairDoubleObj<double[]> backupAt(
            POMDPFlatMTJ pomdp, SparseVector belief, double[][] alphas,
            double[] actionVector, double[] projected) {
        double bestValue = Double.NEGATIVE_INFINITY;
        double[] bestVector = null;
        for (int a = 0; a < pomdp._numA; a++) {
            Arrays.fill(actionVector, 0.0);
            for (int o = 0; o < pomdp._numObs; o++) {
                double[] bestProjection = null;
                double bestProjectionValue = Double.NEGATIVE_INFINITY;
                for (double[] alpha : alphas) {
                    if (alpha == null) {
                        continue;
                    }
                    // NOTE(review): assumes mult overwrites 'projected' rather than accumulating
                    // into it, as reusing this buffer implies; confirm.
                    pomdp._TORow[a][o].mult(pomdp._gamma, alpha, projected);
                    double value = belief.dot(projected);
                    if (value > bestProjectionValue) {
                        bestProjectionValue = value;
                        bestProjection = projected.clone();
                    }
                }
                // NOTE(review): bestProjection stays null when every slot is empty (e.g. no
                // initial alpha vectors); behavior then depends on MDPUtils.add; confirm.
                MDPUtils.add(bestProjection, actionVector);
            }
            MDPUtils.add(pomdp._R[a], pomdp._nonZeroR[a], actionVector);
            double value = belief.dot(actionVector);
            if (value > bestValue) {
                bestValue = value;
                bestVector = actionVector.clone();
            }
        }
        return new PairDoubleObj<>(bestValue, bestVector);
    }

    /**
     * Returns the highest dot product between {@code sparseVector} and any non-null row of
     * {@code dArr}.
     *
     * @param sparseVector the belief point to evaluate
     * @param dArr alpha-vector slots; null entries are skipped
     * @return the best value, or negative infinity if every entry is null or {@code dArr} is empty
     */
    public static double evaluateBeliefPoint(SparseVector sparseVector, double[][] dArr) {
        return evaluateAndBestAlpha(sparseVector, dArr).first;
    }

    /**
     * Finds the non-null row of {@code dArr} with the highest dot product against
     * {@code sparseVector}. Ties keep the earliest row.
     *
     * @param sparseVector the belief point to evaluate
     * @param dArr alpha-vector slots; null entries are skipped
     * @return (best value, the winning row itself rather than a copy). The value is negative
     *     infinity and the row null if every entry is null or {@code dArr} is empty.
     */
    public static PairDoubleObj<double[]> evaluateAndBestAlpha(SparseVector sparseVector, double[][] dArr) {
        double bestValue = Double.NEGATIVE_INFINITY;
        double[] bestAlpha = null;
        for (double[] alpha : dArr) {
            if (alpha == null) {
                continue;
            }
            double value = sparseVector.dot(alpha);
            if (value > bestValue) {
                bestValue = value;
                bestAlpha = alpha;
            }
        }
        return new PairDoubleObj<>(bestValue, bestAlpha);
    }

    /** No-op entry point; the method body is empty. */
    public static void main(String[] strArr) throws Exception {
    }
}
