package mgjpomdp.solve;

import gnu.trove.iterator.TObjectDoubleIterator;
import gnu.trove.map.hash.TObjectDoubleHashMap;
import java.text.DecimalFormat;
import java.text.DecimalFormatSymbols;
import java.util.HashSet;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.Locale;
import java.util.PriorityQueue;
import mgjcommon.Pair;
import mgjpomdp.common.AlphaVector;
import mgjpomdp.common.POMDPFlatMTJ;
import mgjpomdp.solve.GapMin;
import mgjpomdp.solve.bounds.FIBMTJ;
import mgjpomdp.solve.bounds.UBData;
import mgjpomdp.solve.bounds.UpperBound;
import no.uib.cipr.matrix.sparse.SparseVector;
import org.antlr.runtime.debug.Profiler;

/* loaded from: input_file:mgjpomdp/solve/UBAug.class */
/**
 * Iteratively refines a point-based upper bound on the value of a POMDP's initial belief.
 *
 * <p>The bound is seeded with the alpha vectors obtained by solving the original POMDP with
 * {@link FIBMTJ}. Each subsequent iteration searches the belief tree reachable from the initial
 * belief for beliefs whose current upper bound looks loose, adds them to the belief/value set,
 * builds an augmented POMDP from those belief points and re-solves it with {@link FIBMTJ}. The
 * loop stops when the bound at the initial belief moves by no more than the tolerance, or when
 * the search finds no new beliefs.
 */
public class UBAug extends GapMin {
    /** Column separator for the iteration log (tab-separated, matching the header line). */
    private static final String COLUMN_SEP = "\t";

    /** Problem file used by {@link #testUBAug()} when no path is supplied. */
    private static final String DEFAULT_TEST_PROBLEM =
            "/home/mgrzes/_data/Cassandra_POMDPs/gapmin_problems/tiger-grid.POMDP";

    /**
     * Value returned by {@code POMDPFlatMTJ.computeVmin()}. It is used to discard successor
     * beliefs whose discounted upper-bound margin over it cannot exceed the search threshold.
     */
    protected double _Vmin;

    /**
     * Creates a solver for the given model.
     *
     * @param pOMDPFlatMTJ the POMDP whose initial-belief upper bound is to be computed
     * @throws Exception propagated from the {@link GapMin} constructor or {@code computeVmin()}
     */
    public UBAug(POMDPFlatMTJ pOMDPFlatMTJ) throws Exception {
        super(pOMDPFlatMTJ);
        this._Vmin = pOMDPFlatMTJ.computeVmin();
        this._tolerance = 0.001d;
    }

    /**
     * Computes an upper bound on the value of the initial belief, logging one line per iteration.
     *
     * @param verbosity when {@code > 0}, also prints the bound data and the number of pruned
     *     beliefs on every iteration
     * @return the final upper-bound data (belief/value pairs and alpha vectors)
     * @throws Exception propagated from the FIB solver and the augmented-model construction
     */
    public UBData solveUB(int verbosity) throws Exception {
        final long startMillis = System.currentTimeMillis();
        final SparseVector initBelief = this._pomdp._initBelief;
        final int numStates = this._pomdp._numS;
        UBData ubData = new UBData();
        FIBMTJ fib = new FIBMTJ();

        // Seed: one alpha vector per row of the Q table of the original (non-augmented) POMDP.
        fib.solve(this._pomdp, this._tolerance, 0);
        for (double[] qRow : fib._Q) {
            ubData._alphaVectorsSet.add(new AlphaVector(qRow));
        }
        double bound = fib.getBound(initBelief);
        ubData._beliefsValuesMap.put(initBelief, bound);
        ubData.reloadArrays();
        logIteration(0, bound, ubData, System.currentTimeMillis() - startMillis);

        int iteration = 1;
        int numBeliefs = ubData._beliefs.length;
        double previousBound = Double.POSITIVE_INFINITY;
        while (true) {
            // Converged: the bound moved by at most the tolerance since the previous iteration.
            if (iteration > 1 && Math.abs(bound - previousBound) <= this._tolerance) {
                break;
            }
            previousBound = bound;

            suboptimalBeliefsWithPath(initBelief, ubData, this._tolerance);
            if (ubData._beliefs.length == numBeliefs) {
                System.out.println("It looks that the solution cannot be improved any further under the given precision of " + this._tolerance);
                break;
            }
            numBeliefs = ubData._beliefs.length;
            if (verbosity > 0) {
                System.out.println("Current bounds data:");
                ubData.print();
            }

            this._fib_pomdp = POMDPFlatMTJ.constructAugmentedPOMDP(this._pomdp, ubData._beliefs, ubData._values, ubData._alphaVectors);
            fib.solve(this._fib_pomdp, this._tolerance, 0);

            // Q rows of the augmented model have numStates + |beliefs| entries. The first numStates
            // entries form an alpha vector over the original states; entry numStates + k is read
            // below as the value of ubData._beliefs[k].
            ubData._alphaVectorsSet.clear();
            for (double[] qRow : fib._Q) {
                double[] alpha = new double[numStates];
                System.arraycopy(qRow, 0, alpha, 0, numStates);
                ubData._alphaVectorsSet.add(new AlphaVector(alpha));
            }
            ubData.reloadAlphaArray();

            for (int k = 0; k < ubData._beliefs.length; k++) {
                double best = Double.NEGATIVE_INFINITY;
                for (int a = 0; a < this._pomdp._numA; a++) {
                    double q = fib._Q[a][numStates + k];
                    if (q > best) {
                        best = q;
                    }
                }
                ubData._beliefsValuesMap.put(ubData._beliefs[k], best);
            }
            ubData.reloadValuesArray();

            // If an earlier prune dropped the initial belief and the search did not re-add it,
            // the map has no entry for it and a plain get() would return Trove's no-entry value
            // (0) instead of a bound. Evaluate the current upper bound at that belief instead.
            bound = ubData._beliefsValuesMap.containsKey(initBelief)
                    ? ubData._beliefsValuesMap.get(initBelief)
                    : UpperBound.getBound(initBelief, ubData._beliefs, ubData._values, ubData._alphaVectors).first;

            int pruned = pruneBeliefsUB(ubData);
            if (verbosity > 0) {
                System.out.println(pruned + " beliefs pruned from the upper bound");
            }
            logIteration(iteration++, bound, ubData, System.currentTimeMillis() - startMillis);
        }
        return ubData;
    }

    /**
     * Adds to {@code uBData} beliefs reachable from {@code sparseVector} whose discounted gap
     * between their current upper bound and their one-step lookahead value exceeds {@code d}.
     * Each new belief is stored with its lookahead value.
     *
     * @param sparseVector root belief of the search
     * @param uBData bound data to extend; its belief and value arrays are reloaded afterwards
     * @param d gap threshold
     * @throws Exception propagated from the bound and lookahead computations
     */
    public void suboptimalBeliefs(SparseVector sparseVector, UBData uBData, double d) throws Exception {
        exploreSuboptimalBeliefs(sparseVector, uBData, d, false);
    }

    /**
     * Like {@link #suboptimalBeliefs}, but when a belief's gap exceeds {@code d}, the beliefs on
     * the search path leading to it are also added, each with its own lookahead value.
     *
     * @param sparseVector root belief of the search
     * @param uBData bound data to extend; its belief and value arrays are reloaded afterwards
     * @param d gap threshold
     * @throws Exception propagated from the bound and lookahead computations
     */
    public void suboptimalBeliefsWithPath(SparseVector sparseVector, UBData uBData, double d) throws Exception {
        exploreSuboptimalBeliefs(sparseVector, uBData, d, true);
    }

    /**
     * Best-first search shared by {@link #suboptimalBeliefs} and {@link #suboptimalBeliefsWithPath}.
     * It collects at least {@code max(20, 0.2 * |beliefs|)} candidates unless the queue empties
     * first; with {@code trackPath} one expansion may add several and overshoot that number.
     *
     * @param start root belief
     * @param ubData bound data to extend
     * @param threshold gap threshold
     * @param trackPath whether to record search paths and add their beliefs as candidates
     */
    private void exploreSuboptimalBeliefs(SparseVector start, UBData ubData, double threshold, boolean trackPath) throws Exception {
        final int budget = (int) Math.max(20.0d, 0.2d * ubData._beliefs.length);
        final double gamma = this._pomdp._gamma;
        PriorityQueue<GapMin.PQEntry> queue = new PriorityQueue<>();
        HashSet<SparseVector> seen = new HashSet<>();

        GapMin.PQEntry root = new GapMin.PQEntry();
        root._belief = start.copy();
        root._depth = 0;
        root._ub = UpperBound.getBound(start, ubData._beliefs, ubData._values, ubData._alphaVectors).first;
        root._priority = -root._ub;
        root._prob = 1.0d;
        if (trackPath) {
            root._path = new LinkedList<>();
        }
        queue.add(root);
        seen.add(root._belief);

        TObjectDoubleHashMap<SparseVector> candidates = new TObjectDoubleHashMap<>();
        int added = 0;
        // Scratch vector reused for every successor; entries that are kept store a copy.
        SparseVector successor = new SparseVector(this._pomdp._numS);
        while (added < budget && !queue.isEmpty()) {
            GapMin.PQEntry current = queue.remove();
            Pair<Integer, Double> greedy = Lookahead(current._belief, ubData);
            double lookaheadUB = greedy.second.doubleValue();

            if (((current._ub - lookaheadUB) / (1.0d - gamma)) * Math.pow(gamma, current._depth) > threshold) {
                if (addCandidate(current._belief, lookaheadUB, ubData, candidates)) {
                    added++;
                }
                if (trackPath) {
                    Iterator<GapMin.PathEntry> steps = current._path.iterator();
                    while (steps.hasNext()) {
                        GapMin.PathEntry step = steps.next();
                        if (addCandidate(step._belief, step._lookaheadUB, ubData, candidates)) {
                            added++;
                        }
                    }
                }
            }

            // Expand the successors of current under the lookahead's action, one per observation.
            int childDepth = current._depth + 1;
            int action = greedy.first.intValue();
            double discount = Math.pow(gamma, childDepth);
            for (int obs = 0; obs < this._pomdp._numObs; obs++) {
                this._pomdp._TO[action][obs].transMult(current._belief, successor);
                double obsProb = successor.sum();
                if (obsProb > 1.0E-8d) {
                    successor.scale(1.0d / obsProb);
                    if (!seen.contains(successor)) {
                        double childUB = UpperBound.getBound(successor, ubData._beliefs, ubData._values, ubData._alphaVectors).first;
                        // Keep the child only if its discounted margin over _Vmin can exceed the threshold.
                        if ((discount * (childUB - this._Vmin)) / (1.0d - gamma) > threshold) {
                            GapMin.PQEntry child = new GapMin.PQEntry();
                            child._belief = successor.copy();
                            child._ub = childUB;
                            child._prob = current._prob * obsProb;
                            child._priority = child._prob * (-childUB) * discount;
                            child._depth = childDepth;
                            if (trackPath) {
                                child.appendToPath(current._path);
                                child.appendToPath(new GapMin.PathEntry(current._belief, lookaheadUB));
                            }
                            queue.add(child);
                            seen.add(child._belief);
                        }
                    }
                }
            }
        }

        TObjectDoubleIterator<SparseVector> it = candidates.iterator();
        while (it.hasNext()) {
            it.advance();
            ubData._beliefsValuesMap.put(it.key(), it.value());
        }
        ubData.reloadBeliefArray();
        ubData.reloadValuesArray();
    }

    /**
     * Records {@code belief} with {@code value} unless it is already a bound point, a corner, or
     * an earlier candidate.
     *
     * @return {@code true} if the belief was added
     */
    private boolean addCandidate(SparseVector belief, double value, UBData ubData, TObjectDoubleHashMap<SparseVector> candidates) {
        if (ubData._beliefsValuesMap.containsKey(belief) || this._corners.contains(belief) || candidates.containsKey(belief)) {
            return false;
        }
        candidates.put(belief, value);
        return true;
    }

    /**
     * Prints one tab-separated log line for an iteration, preceded by a header when {@code i == 0}.
     * Decimal output uses {@link Locale#ROOT} so the separator is always '.'.
     *
     * @param i iteration number
     * @param d current upper bound at the initial belief
     * @param uBData bound data whose belief and alpha counts are reported
     * @param j elapsed wall-clock time in milliseconds
     */
    public void logIteration(int i, double d, UBData uBData, long j) {
        DecimalFormat fmt = new DecimalFormat("#.######", DecimalFormatSymbols.getInstance(Locale.ROOT));
        if (i == 0) {
            System.out.println("Iteration\tUB\tNumBeliefs(ub)\tNumAlpha(ub)\tTime[s]");
        }
        System.out.println(i + COLUMN_SEP + fmt.format(d) + COLUMN_SEP + uBData._beliefs.length + COLUMN_SEP + uBData._alphaVectors.length + COLUMN_SEP + fmt.format(j / 1000.0d));
    }

    /** Runs the solver on the default test problem. */
    public static void testUBAug() throws Exception {
        testUBAug(DEFAULT_TEST_PROBLEM);
    }

    /**
     * Runs the solver with verbosity 0 on the given problem file.
     *
     * @param pomdpFile path passed to {@code new POMDPFlatMTJ(path, 0)}
     */
    public static void testUBAug(String pomdpFile) throws Exception {
        new UBAug(new POMDPFlatMTJ(pomdpFile, 0)).solveUB(0);
    }

    /** Entry point: an optional first argument overrides the default problem file. */
    public static void main(String[] strArr) throws Exception {
        if (strArr.length > 0) {
            testUBAug(strArr[0]);
        } else {
            testUBAug();
        }
    }
}
