package mgjpomdp.solve.fsc;

import gnu.trove.set.hash.TIntHashSet;
import mgjcommon.CHeurSearchStats;
import mgjpomdp.common.MDPUtils;
import mgjpomdp.common.POMDPFlatMTJ;
import mgjpomdp.solve.GapMin;
import no.uib.cipr.matrix.sparse.SparseVector;

/* loaded from: input_file:mgjpomdp/solve/fsc/SearchStartFixedLbViAugMTJ.class */
/**
 * Fixed-size finite-state-controller search that first solves the POMDP with {@link GapMin}
 * and then searches over a POMDP augmented with GapMin's upper-bound beliefs, values and
 * alpha vectors.
 *
 * <p>The POMDP passed to the constructor is kept and returned by {@link #getOriginalPOMDP()};
 * the search itself runs on {@code _pomdp}, the augmented model.
 */
public class SearchStartFixedLbViAugMTJ extends SearchStartFixedLbViMTJ {
    /**
     * The POMDP as passed to the constructor, before augmentation. Note that its discount
     * factor is scaled in place when the constructor receives a gamma ratio other than -1.
     */
    private final POMDPFlatMTJ _nonAugPomdp;

    /**
     * Builds the search. It optionally scales the discount factor, runs GapMin, builds (or reuses)
     * the augmented POMDP and creates the requested heuristic bound.
     *
     * @param gammaRatio factor applied to {@code pomdp._gamma} before solving; {@code -1.0}
     *     leaves gamma unchanged. The caller's POMDP is modified in place and its previous
     *     gamma is stored in {@code _originalGamma}.
     * @param pomdp the POMDP to solve; retained as the non-augmented model
     * @param numNodes number of controller nodes handed to the bound
     * @param heuristicType heuristic used for the bound; only {@code qmdp} and {@code fib} are
     *     supported
     * @throws IllegalArgumentException if {@code heuristicType} is {@code null} or unsupported;
     *     this is checked before any work, so the caller's POMDP is left untouched
     * @throws Exception if any of the GapMin / POMDP / bound routines called here throws
     */
    public SearchStartFixedLbViAugMTJ(double gammaRatio, POMDPFlatMTJ pomdp, int numNodes,
            FSCHeuristicType heuristicType) throws Exception {
        // Fail fast: the steps below mutate the caller's POMDP and run a (time-limited but
        // potentially long) GapMin solve, which would be wasted on an unusable heuristic type.
        if (!isSupportedHeuristic(heuristicType)) {
            throw new IllegalArgumentException("ERROR: unknown heuristic type " + heuristicType);
        }
        this._gammaRatio = gammaRatio;
        this._nonAugPomdp = pomdp;
        // -1.0 is the "do not scale gamma" sentinel.
        if (this._gammaRatio != -1.0d) {
            this._originalGamma = pomdp._gamma;
            pomdp._gamma *= this._gammaRatio;
            System.out.println("Temporal gamma " + pomdp._gamma);
        }

        GapMin gapMin = new GapMin(pomdp);
        gapMin._timeLimitSeconds = FSCParams._GapMin_Max_Time;
        this._gapminPolicy = gapMin.solve(0);
        this._longUBAlphas = gapMin._lastLongAlphas;
        this._referenceUB = this._gapminPolicy.second._ub;

        SparseVector[] beliefs = this._gapminPolicy.second._beliefs;
        double[] values = this._gapminPolicy.second._values;
        double[][] alphaVectors = this._gapminPolicy.second._alphaVectors;

        // Reuse GapMin's FIB POMDP unless it is absent or has more states than
        // (original states + number of upper-bound beliefs); otherwise build the augmented
        // POMDP from the final GapMin result.
        POMDPFlatMTJ fibPomdp = gapMin._fib_pomdp;
        if (fibPomdp == null || fibPomdp._numS > pomdp._numS + beliefs.length) {
            this._pomdp = POMDPFlatMTJ.constructAugmentedPOMDP(pomdp, beliefs, values, alphaVectors);
        } else {
            this._pomdp = fibPomdp;
        }

        this._numNodes = numNodes;
        switch (heuristicType) {
            case qmdp:
                this._bound = new FSCBoundQMDPMTJ(this._pomdp, this._numNodes);
                break;
            case fib:
                this._bound = new FSCBoundFIBMTJ(this._pomdp, this._numNodes);
                break;
            default:
                // Unreachable while isSupportedHeuristic agrees with the cases above.
                throw new IllegalStateException("ERROR: unknown heuristic type " + heuristicType);
        }
        this._tolerance = MDPUtils.tolerance(this._gapminPolicy.second._ub,
                this._gapminPolicy.first._lb, this._pomdp._gamma, FSCParams.SIGNIFICANT_DIGITS);
        this._pruning = SearchPruningType.actionsOnly;
        this._detectJointEdges = false;
        this._tmpActions2NumNodes = new int[this._pomdp._numA];
        this._VMinPolicyState = this._bound.getVminPolicyState();
    }

    /** Returns whether the constructor can build a bound for {@code type} ({@code null} is not). */
    private static boolean isSupportedHeuristic(FSCHeuristicType type) {
        return type == FSCHeuristicType.qmdp || type == FSCHeuristicType.fib;
    }

    /** Returns the POMDP passed to the constructor, not the augmented {@code _pomdp}. */
    @Override
    public POMDPFlatMTJ getOriginalPOMDP() {
        return this._nonAugPomdp;
    }

    /**
     * Collects the actions whose matrix {@code _O[a]} in the non-augmented POMDP has all
     * entries equal (max minus min below 1e-9).
     *
     * <p>NOTE(review): presumably {@code _O[a]} is the observation function of action
     * {@code a}, so such actions produce observations that carry no state information and
     * their edges can be handled jointly — confirm against {@code SearchStartFixedLbViMTJ}.
     */
    @Override
    protected void initActionsWithJointEdges() {
        this._actionsWithJointEdges = new TIntHashSet();
        for (int a = 0; a < this._nonAugPomdp._numA; a++) {
            if (Math.abs(this._nonAugPomdp._O[a].minValue() - this._nonAugPomdp._O[a].maxValue()) < 1.0E-9d) {
                this._actionsWithJointEdges.add(a);
            }
        }
    }

    /**
     * Prints a warning to stderr when {@code ub} exceeds GapMin's reference upper bound by more
     * than {@code _tolerance}. The controller argument is not used by this implementation.
     */
    @Override
    public void checkUB(double ub, FSController controller) throws Exception {
        double excess = ub - this._referenceUB;
        if (excess > this._tolerance) {
            System.err.println("ub is larger than gap min ub by " + excess);
        }
    }

    /**
     * Command-line entry point. It loads a POMDP, builds this search with the FIB heuristic and
     * finds an initial controller from a random start with the discount factor temporarily set to
     * {@code _initGamma}. It then searches again from that controller with the POMDP's own
     * discount factor. Both controllers are written via {@code print_dot} using
     * {@code _resultsPrefix}.
     *
     * @throws IllegalArgumentException if the POMDP's gamma is below {@code _initGamma}
     */
    public static void main(String[] args) throws Exception {
        help(args);
        parseCommandLine(args);
        long startMillis = System.currentTimeMillis();
        System.out.println("\nStarted...\n");

        POMDPFlatMTJ pomdp = new POMDPFlatMTJ(_pomdpPath, 0);
        // Small models are printed in full, larger ones through toStr().
        if (pomdp._numS <= 10) {
            System.out.println(pomdp.toString());
        } else {
            System.out.println(pomdp.toStr());
        }
        if (pomdp._gamma == -1.0d || pomdp._gamma == 1.0d) {
            System.out.println("gamma was not set in this POMDP or was set to 1.0, using 0.999 by default\n");
            pomdp._gamma = 0.999d;
        }
        if (pomdp._gamma < _initGamma) {
            throw new IllegalArgumentException("gamma for initialisation should be lower than gamma in the POMDP");
        }
        FSCParams.setSigDigits(3);

        SearchStartFixedLbViAugMTJ bbSearch =
                new SearchStartFixedLbViAugMTJ(-1.0d, pomdp, _requriredNumNodes, FSCHeuristicType.fib);
        System.out.println("BB object created in: " + ((System.currentTimeMillis() - startMillis) / 1000.0d) + " seconds.");
        bbSearch._pruneOneActionControllers = true;
        bbSearch._pruning = SearchPruningType.stringSameColour;
        bbSearch._detectJointEdges = true;
        bbSearch._bound.setViPruneLB(true);
        CHeurSearchStats mainStats = new CHeurSearchStats();

        // Phase 1: initial controller from a random start, with gamma temporarily set to _initGamma.
        FSController randomController = new FSController(_requriredNumNodes, pomdp._numObs);
        randomController.randomInit(pomdp._numA);
        double searchGamma = bbSearch._pomdp._gamma;
        bbSearch._pomdp._gamma = _initGamma;
        CHeurSearchStats initStats = new CHeurSearchStats();
        FSController initialController;
        try {
            initialController = bbSearch.search(randomController, 1, initStats);
        } finally {
            // Always restore the real discount factor, even if the initial search fails.
            bbSearch._pomdp._gamma = searchGamma;
        }
        System.out.println("Num Heuristic Evaluations for initial controller: " + initStats._nHeurEvalsOfStates);
        System.out.println("Init time: " + ((System.currentTimeMillis() - startMillis) / 1000.0d) + " seconds.");
        initialController.print_dot(_resultsPrefix + "bbdfs_initial", pomdp._observationToName, pomdp._actionToName);

        // Phase 2: search again from the initial controller with the POMDP's own gamma.
        // NOTE(review): repeats the call above; kept in case the initial search changes the
        // setting — confirm whether it is needed.
        FSCParams.setSigDigits(3);
        FSController finalController = bbSearch.search(initialController, 1, mainStats);
        long endMillis = System.currentTimeMillis();
        finalController.print_dot(_resultsPrefix + "bbdfs_controller", pomdp._observationToName, pomdp._actionToName);
        System.out.println("Total time: " + ((endMillis - startMillis) / 1000.0d) + " seconds.");
        System.out.println("Num Heuristic Evaluations: " + mainStats._nHeurEvalsOfStates);
        System.out.println("Num Solutions Improved: " + mainStats._nImproved);
        System.out.println("Num of All States without any symmetry reduction: " + mainStats._nStates.toString());
    }
}
