package mgjpomdp.solve;

import mgjcommon.Pair;
import mgjcommon.PairDoubleObj;
import mgjpomdp.common.AlphaVector;
import mgjpomdp.common.POMDPFlatMTJ;
import mgjpomdp.solve.bounds.FIBMTJ;
import mgjpomdp.solve.bounds.LBData;
import mgjpomdp.solve.bounds.UBData;
import mgjpomdp.solve.bounds.UpperBound;
import no.uib.cipr.matrix.sparse.SparseVector;

/* loaded from: input_file:mgjpomdp/solve/Optimus.class */
/**
 * Experimental driver that starts from the FIB upper bound of a POMDP and repeatedly tightens it.
 *
 * <p>Each refinement pass does three things:
 * <ol>
 *   <li>builds an augmented POMDP from the current belief/value points and alpha vectors;</li>
 *   <li>re-solves FIB on it;</li>
 *   <li>takes the smaller of the stored value and two one-step lookaheads at the initial belief.</li>
 * </ol>
 * Progress is printed to standard output.
 */
public class Optimus {
    /** Tolerance used for the FIB solve on the original POMDP. */
    private static final double INITIAL_FIB_TOLERANCE = 0.001d;

    /** Tolerance used for the FIB solve on each augmented POMDP. */
    private static final double AUGMENTED_FIB_TOLERANCE = 1.0E-6d;

    /** Refinement stops once the initial-belief bound changes by no more than this. */
    private static final double CONVERGENCE_TOLERANCE = 1.0E-16d;

    /** Exclusive limit on the 1-based iteration counter, i.e. at most 19 refinement passes. */
    private static final int ITERATION_LIMIT = 20;

    /**
     * Extra belief points {P(s0), P(s1)} added by the multi-point experiments.
     * NOTE(review): these only set states 0 and 1, so they assume a two-state problem such as tiger.
     */
    private static final double[][] SEED_BELIEFS = {
        {0.97d, 0.03d},
        {0.03d, 0.97d},
        {0.85d, 0.15d},
        {0.15d, 0.85d}
    };

    private final POMDPFlatMTJ _pomdp;

    public Optimus(POMDPFlatMTJ pOMDPFlatMTJ) throws Exception {
        this._pomdp = pOMDPFlatMTJ;
    }

    /**
     * Refines the upper bound using only the initial belief as a stored belief point.
     *
     * @param i unused; kept for interface compatibility
     */
    public void optimusOnePoint(int i) throws Exception {
        UpperBound.setUBTypeLP();
        UBData uBData = new UBData();
        FIBMTJ fibmtj = new FIBMTJ();
        double bound = seedWithFib(fibmtj, uBData);
        uBData.reloadArrays();

        // A one-step lookahead over the FIB bound may already be tighter than FIB itself.
        double lookahead = Lookahead(this._pomdp, this._pomdp._initBelief, uBData).second.doubleValue();
        if (bound > lookahead) {
            bound = lookahead;
            System.out.println("the initial FIB upper bound in the initial belief state " + bound);
            uBData._beliefsValuesMap.put(this._pomdp._initBelief, bound);
            uBData.reloadArrays();
        }
        refineUntilConverged(fibmtj, uBData, bound, null, null);
    }

    /**
     * Refines the upper bound using the initial belief plus the four {@link #SEED_BELIEFS}.
     *
     * @param i unused; kept for interface compatibility
     */
    public void optimus5Points(int i) throws Exception {
        UpperBound.setUBTypeLP();
        UBData uBData = new UBData();
        FIBMTJ fibmtj = new FIBMTJ();
        double bound = seedWithFib(fibmtj, uBData);
        addSeedBeliefs(fibmtj, uBData);
        uBData.reloadArrays();
        refineUntilConverged(fibmtj, uBData, bound, null, null);
    }

    /**
     * Same as {@link #optimus5Points(int)}, and additionally prints per-iteration diagnostics.
     * The diagnostics cover a fixed query belief (0.55, 0.45) and the successor of the initial
     * belief under action 0 / observation 0.
     *
     * @param i unused; kept for interface compatibility
     */
    public void optimusAlphaVectors(int i) throws Exception {
        SparseVector query = twoStateBelief(0.55d, 0.45d);
        UpperBound.setUBTypeLP();
        UBData uBData = new UBData();
        FIBMTJ fibmtj = new FIBMTJ();
        double bound = seedWithFib(fibmtj, uBData);
        System.out.println("query on original FIB = " + fibmtj.getBound(query));
        addSeedBeliefs(fibmtj, uBData);
        uBData.reloadArrays();
        SparseVector initialWitness =
                UpperBound.getBound(query, uBData._beliefs, uBData._values, uBData._alphaVectors).second;
        refineUntilConverged(fibmtj, uBData, bound, query, initialWitness);
    }

    /**
     * Computes a one-step backup of the upper bound at {@code sparseVector}.
     *
     * <p>For every action {@code a} the value is
     * {@code belief . R[a] + gamma * sum_o UB(TO[a][o]^T . belief)}. The successor vector is
     * passed to {@link UpperBound#getBound} without renormalization.
     *
     * @param pOMDPFlatMTJ the model supplying rewards, {@code _TO} matrices and discount
     * @param sparseVector the belief to back up
     * @param uBData       belief/value points and alpha vectors defining the upper bound
     * @return the maximizing action index (-1 if there are no actions) and its backed-up value
     *         ({@code -Infinity} if there are no actions)
     */
    public Pair<Integer, Double> Lookahead(POMDPFlatMTJ pOMDPFlatMTJ, SparseVector sparseVector, UBData uBData) throws Exception {
        double bestValue = Double.NEGATIVE_INFINITY;
        int bestAction = -1;
        // Scratch vector reused for every (action, observation) pair; transMult overwrites it.
        SparseVector successor = new SparseVector(pOMDPFlatMTJ._numS);
        for (int action = 0; action < pOMDPFlatMTJ._numA; action++) {
            double value = sparseVector.dot(pOMDPFlatMTJ._R[action]);
            for (int obs = 0; obs < pOMDPFlatMTJ._numObs; obs++) {
                pOMDPFlatMTJ._TO[action][obs].transMult(sparseVector, successor);
                value += pOMDPFlatMTJ._gamma
                        * UpperBound.getBound(successor, uBData._beliefs, uBData._values, uBData._alphaVectors).first;
            }
            if (value > bestValue) {
                bestValue = value;
                bestAction = action;
            }
        }
        return new Pair<>(Integer.valueOf(bestAction), Double.valueOf(bestValue));
    }

    public static void main(String[] strArr) throws Exception {
        testGapMinCheapEaten();
        testGapMinTiger75();
        POMDPFlatMTJ pomdp = new POMDPFlatMTJ("data/tiger.75.POMDP", 0);
        Optimus optimus = new Optimus(pomdp);
        optimus.optimusOnePoint(1);
        optimus.optimus5Points(1);
        optimus.optimusAlphaVectors(1);
        Pair<LBData, UBData> solution = new GapMin(pomdp).solve(0);
        SparseVector probe = new SparseVector(pomdp._numS);
        probe.set(0, 0.66d);
        probe.set(1, 0.34d);
        // NOTE(review): these are the lower-bound (LBData) alpha vectors, despite the "ub" label.
        for (double[] alpha : solution.first._alphaVectors) {
            for (double entry : alpha) {
                System.out.print(entry + " ");
            }
            System.out.println();
            System.out.println("init ub value is " + probe.dot(alpha));
        }
    }

    public static void testGapMinTiger75() throws Exception {
        reportGapMinUpperBound("data/tiger.75.POMDP");
    }

    public static void testGapMinCheapEaten() throws Exception {
        reportGapMinUpperBound("data/tigerCheapEaten.75.POMDP");
    }

    /** Runs GapMin on the model at {@code path} and prints its upper bound at the initial belief. */
    private static void reportGapMinUpperBound(String path) throws Exception {
        POMDPFlatMTJ pomdp = new POMDPFlatMTJ(path, 0);
        Pair<LBData, UBData> solution = new GapMin(pomdp).solve(0);
        for (double[] alpha : solution.second._alphaVectors) {
            System.out.println("init ub value is " + pomdp._initBelief.dot(alpha));
        }
        System.out.println("the final upper bound in the initial belief is "
                + solution.second._beliefsValuesMap.get(pomdp._initBelief));
    }

    /**
     * Solves FIB on the original POMDP and seeds {@code uBData} with the result.
     * The Q-vectors become alpha vectors, and the FIB value at the initial belief is stored.
     * The caller is responsible for calling {@code reloadArrays()}.
     *
     * @return the FIB upper bound at the initial belief
     */
    private double seedWithFib(FIBMTJ fibmtj, UBData uBData) throws Exception {
        fibmtj.solve(this._pomdp, INITIAL_FIB_TOLERANCE, 0);
        for (double[] q : fibmtj._Q) {
            uBData._alphaVectorsSet.add(new AlphaVector(q));
        }
        double bound = fibmtj.getBound(this._pomdp._initBelief);
        System.out.println("the initial FIB upper bound in the initial belief state " + bound);
        uBData._beliefsValuesMap.put(this._pomdp._initBelief, bound);
        return bound;
    }

    /** Adds each {@link #SEED_BELIEFS} point, valued by the current FIB solution, to {@code uBData}. */
    private void addSeedBeliefs(FIBMTJ fibmtj, UBData uBData) throws Exception {
        for (double[] probs : SEED_BELIEFS) {
            SparseVector belief = twoStateBelief(probs[0], probs[1]);
            uBData._beliefsValuesMap.put(belief, fibmtj.getBound(belief));
        }
    }

    /** Builds a sparse belief with only states 0 and 1 set to the given probabilities. */
    private SparseVector twoStateBelief(double p0, double p1) {
        SparseVector belief = new SparseVector(this._pomdp._numS);
        belief.set(0, p0);
        belief.set(1, p1);
        return belief;
    }

    /**
     * Runs refinement passes until the initial-belief bound stops changing or the iteration
     * limit is reached.
     *
     * @param initialBound   bound at the initial belief before the first pass (convergence baseline)
     * @param query          belief for per-iteration diagnostics, or {@code null} to skip them
     * @param initialWitness witness returned for {@code query} before refinement; ignored when
     *                       {@code query} is {@code null}
     */
    private void refineUntilConverged(FIBMTJ fibmtj, UBData uBData, double initialBound,
                                      SparseVector query, SparseVector initialWitness) throws Exception {
        double previous = initialBound;
        for (int iteration = 1; iteration < ITERATION_LIMIT; iteration++) {
            POMDPFlatMTJ augmented = solveAugmentedAndUpdate(fibmtj, uBData);
            double bound = tightenInitialBelief(augmented, fibmtj, uBData);
            System.out.println("the upper bound in the initial belief after " + iteration + " iterations is " + bound);
            if (query != null) {
                printQueryDiagnostics(iteration, augmented, fibmtj, uBData, query, initialWitness);
            }
            // Compare against the previous pass. The old check compared the bound with itself
            // and therefore always stopped after the first iteration.
            if (Math.abs(previous - bound) <= CONVERGENCE_TOLERANCE) {
                return;
            }
            previous = bound;
        }
    }

    /**
     * Builds the augmented POMDP from {@code uBData} and solves FIB on it. Then:
     * <ul>
     *   <li>replaces the alpha vectors with the Q-vectors truncated to the original state space;</li>
     *   <li>sets each stored belief's value to the max-over-actions of its augmented Q entry.</li>
     * </ul>
     *
     * @return the augmented POMDP (its FIB solution is left in {@code fibmtj._Q})
     */
    private POMDPFlatMTJ solveAugmentedAndUpdate(FIBMTJ fibmtj, UBData uBData) throws Exception {
        POMDPFlatMTJ augmented = POMDPFlatMTJ.constructAugmentedPOMDP(
                this._pomdp, uBData._beliefs, uBData._values, uBData._alphaVectors);
        fibmtj.solve(augmented, AUGMENTED_FIB_TOLERANCE, 0);

        int numS = this._pomdp._numS;
        uBData._alphaVectorsSet.clear();
        for (double[] q : fibmtj._Q) {
            // Keep only the first numS entries, i.e. the part over the original states.
            double[] truncated = new double[numS];
            System.arraycopy(q, 0, truncated, 0, numS);
            uBData._alphaVectorsSet.add(new AlphaVector(truncated));
        }
        uBData.reloadAlphaArray();

        // Entry numS + k of each Q-vector holds the value for the k-th stored belief point.
        for (int k = 0; k < uBData._beliefs.length; k++) {
            double best = Double.NEGATIVE_INFINITY;
            for (int action = 0; action < this._pomdp._numA; action++) {
                double candidate = fibmtj._Q[action][numS + k];
                if (candidate > best) {
                    best = candidate;
                }
            }
            uBData._beliefsValuesMap.put(uBData._beliefs[k], best);
        }
        uBData.reloadValuesArray();
        return augmented;
    }

    /**
     * Lowers the stored value at the initial belief to the smallest of three candidates:
     * <ul>
     *   <li>its current value;</li>
     *   <li>a lookahead on the original POMDP;</li>
     *   <li>a lookahead on the augmented POMDP using its full Q-vectors.</li>
     * </ul>
     *
     * @return the resulting bound at the initial belief
     */
    private double tightenInitialBelief(POMDPFlatMTJ augmented, FIBMTJ fibmtj, UBData uBData) throws Exception {
        UBData augmentedData = new UBData();
        for (double[] q : fibmtj._Q) {
            augmentedData._alphaVectorsSet.add(new AlphaVector(q));
        }
        augmentedData._beliefsValuesMap.put(augmented._initBelief, fibmtj.getBound(augmented._initBelief));
        augmentedData.reloadArrays();
        double augmentedLookahead = Lookahead(augmented, augmented._initBelief, augmentedData).second.doubleValue();

        double bound = uBData._beliefsValuesMap.get(this._pomdp._initBelief);
        double originalLookahead = Lookahead(this._pomdp, this._pomdp._initBelief, uBData).second.doubleValue();
        boolean improved = false;
        if (bound > originalLookahead) {
            bound = originalLookahead;
            improved = true;
        }
        if (bound > augmentedLookahead) {
            bound = augmentedLookahead;
            improved = true;
        }
        if (improved) {
            uBData._beliefsValuesMap.put(this._pomdp._initBelief, bound);
            // Keep the cached value array in sync with the map. The next augmented POMDP and any
            // getBound query read _values, not the map.
            uBData.reloadValuesArray();
        }
        return bound;
    }

    /** Prints the diagnostic lines used by {@link #optimusAlphaVectors(int)} for one iteration. */
    private void printQueryDiagnostics(int iteration, POMDPFlatMTJ augmented, FIBMTJ fibmtj, UBData uBData,
                                       SparseVector query, SparseVector initialWitness) throws Exception {
        PairDoubleObj<SparseVector> queryBound =
                UpperBound.getBound(query, uBData._beliefs, uBData._values, uBData._alphaVectors);
        System.out.println("query eval in iteration " + iteration + " is " + queryBound.first);
        System.out.println("query eval in iteration " + iteration + " is " + fibmtj.getBound(queryBound.second));
        System.out.println("query eval in iteration " + iteration + " is " + fibmtj.getBound(initialWitness));

        // Compare the witness for the original model's successor belief with the augmented
        // model's successor. Previously the query belief was passed here by mistake, leaving the
        // computed successor unused.
        SparseVector successor = normalizedSuccessor(this._pomdp, this._pomdp._initBelief);
        PairDoubleObj<SparseVector> successorBound =
                UpperBound.getBound(successor, uBData._beliefs, uBData._values, uBData._alphaVectors);
        SparseVector augmentedSuccessor = normalizedSuccessor(augmented, augmented._initBelief);
        System.out.println(successorBound.second.toStr() + " - " + augmentedSuccessor.toStr());
    }

    /**
     * Returns {@code _TO[0][0]^T . belief}, scaled to sum to one.
     * The vector is left unscaled when its mass is zero, to avoid producing NaNs.
     */
    private static SparseVector normalizedSuccessor(POMDPFlatMTJ pomdp, SparseVector belief) throws Exception {
        SparseVector next = new SparseVector(pomdp._numS);
        pomdp._TO[0][0].transMult(belief, next);
        double mass = next.sum();
        if (mass > 0.0d) {
            next.scale(1.0d / mass);
        }
        return next;
    }
}
