package mgjpomdp.solve.fsc;

import mgjcommon.CHeurSearchStats;
import mgjpomdp.common.MDPUtils;
import mgjpomdp.common.POMDPFlatMTJ;
import mgjpomdp.common.jPOMDPRuntimeConfig;

/* loaded from: input_file:mgjpomdp/solve/fsc/SearchSnfMTJ.class */
/**
 * Depth-first search with bound-based pruning over the action and edge
 * assignments of a fixed-size finite-state controller for a flat POMDP.
 *
 * <p>The search first assigns an action to every node ({@link #bbdfs_actions})
 * and then a successor node to every (node, observation) pair
 * ({@link #bbdfs_edges}). After every single assignment the partially
 * specified controller is handed to {@code _bound}; the branch is explored
 * further only if the returned bound at the initial belief is strictly greater
 * than the value of the best controller found so far ({@link #_lb}).
 *
 * <p>When a gamma ratio other than {@code -1.0} is supplied, the POMDP's
 * discount factor is temporarily multiplied by that ratio while searching and
 * restored to its original value afterwards.
 *
 * <p>NOTE(review): instances mutate the shared {@code POMDPFlatMTJ} and keep
 * per-search scratch state in fields, so one instance should not run two
 * searches concurrently.
 */
public class SearchSnfMTJ extends Search {
    /** Controller being built by the search; unassigned entries hold {@code -1}. */
    public FSController _currController;
    /**
     * Best controller found so far. {@link #search} aliases this to the
     * controller passed in by the caller, which is therefore overwritten
     * whenever an improvement is found.
     */
    public FSController _bestController;
    /** Bound value at the initial belief of {@link #_bestController}; the pruning threshold. */
    protected double _lb;
    /**
     * When {@code true}, action assignments whose first and last node carry the
     * same action are skipped without exploring any edges.
     */
    public boolean _pruneOneActionControllers;
    /** Not read anywhere in this class; kept for callers that set it. */
    public boolean _detectJointEdges;
    /** Statistics sink of the search currently running. */
    CHeurSearchStats _tmpStats;
    /** Smallest action index the next node in {@link #bbdfs_actions} may be given. */
    int _tmpNextAvailableAction;

    /**
     * Creates a search over controllers with {@code i} nodes.
     *
     * @param d                gamma ratio; {@code -1.0} disables discount scaling, any other
     *                         value is multiplied into {@code pOMDPFlatMTJ._gamma} in place
     * @param pOMDPFlatMTJ     the POMDP to solve (its {@code _gamma} may be modified, see above)
     * @param i                number of controller nodes
     * @param fSCHeuristicType which bound implementation to use ({@code qmdp} or {@code fib})
     * @throws NullPointerException if {@code fSCHeuristicType} is {@code null}
     * @throws Exception            if the heuristic type is not supported, or if constructing
     *                              the bound fails
     */
    public SearchSnfMTJ(double d, POMDPFlatMTJ pOMDPFlatMTJ, int i, FSCHeuristicType fSCHeuristicType) throws Exception {
        // Fail before touching the shared POMDP; a switch on null would throw
        // the same exception type, but only after gamma had been scaled.
        if (fSCHeuristicType == null) {
            throw new NullPointerException("fSCHeuristicType must not be null");
        }
        this._gammaRatio = d;
        this._pomdp = pOMDPFlatMTJ;
        if (this._gammaRatio != -1.0d) {
            this._originalGamma = this._pomdp._gamma;
            this._pomdp._gamma *= this._gammaRatio;
            System.out.println("Temporal gamma " + this._pomdp._gamma);
        }
        this._numNodes = i;

        // The bound is constructed while the (possibly scaled) gamma is in
        // effect, exactly as before. If construction does not complete, the
        // caller's POMDP must not be left with a scaled discount factor.
        boolean boundCreated = false;
        try {
            switch (fSCHeuristicType) {
                case qmdp:
                    this._bound = new FSCBoundQMDPMTJ(this._pomdp, this._numNodes);
                    break;
                case fib:
                    this._bound = new FSCBoundFIBMTJ(this._pomdp, this._numNodes);
                    break;
                default:
                    throw new Exception("ERROR: unknown heuristic type " + fSCHeuristicType);
            }
            boundCreated = true;
        } finally {
            if (!boundCreated && this._gammaRatio != -1.0d) {
                this._pomdp._gamma = this._originalGamma;
            }
        }
    }

    /**
     * Runs the search, starting from {@code fSController} as the incumbent.
     *
     * <p>If a gamma ratio is configured, the scaled discount factor is in effect
     * for the duration of the search and the original one is restored before
     * this method returns or throws. After a successful search the best
     * controller is re-evaluated under the original discount factor and
     * {@link #_lb} is updated to that value.
     *
     * @param fSController     initial controller; it becomes {@link #_bestController} and is
     *                         overwritten in place when a better controller is found
     * @param i                verbosity; values greater than 0 print the initial and final
     *                         lower bound
     * @param cHeurSearchStats receives the tree size, evaluation and improvement counters
     * @return the best controller found (the same object as {@code fSController})
     * @throws Exception propagated from the bound or the controller utilities
     */
    public FSController search(FSController fSController, int i, CHeurSearchStats cHeurSearchStats) throws Exception {
        final boolean scaledGamma = this._gammaRatio != -1.0d;
        if (scaledGamma) {
            // Re-apply the temporal gamma. On the first call this writes the
            // value the constructor already set; on later calls it undoes the
            // restoration performed at the end of the previous search.
            this._pomdp._gamma = this._originalGamma;
            this._pomdp._gamma *= this._gammaRatio;
        }

        double d;
        try {
            this._tmpStats = cHeurSearchStats;
            this._tmpStats._nStates = MDPUtils.dfsTreeSize(this._pomdp._numA, this._pomdp._numObs, this._numNodes);
            this._bestController = fSController;
            d = MDPUtils.tolerance(this._pomdp.computeVmax(), this._pomdp.computeVmin(), this._pomdp._gamma, FSCParams.SIGNIFICANT_DIGITS);
            this._lb = improveInitController(fSController, Math.min(d, Math.pow(0.1d, FSCParams.SIGNIFICANT_DIGITS + 1)) / 2.0d);
            if (i > 0) {
                System.out.println("Initial lower bound is: " + this._lb);
            }
            this._currController = new FSController(this._numNodes, this._pomdp._numObs);
            this._tmpNextAvailableAction = 0;
            bbdfs_actions(0, this._pomdp.computeVmax());
        } finally {
            // Always hand the POMDP back with its original discount factor,
            // including when the search aborts with an exception.
            if (scaledGamma) {
                this._pomdp._gamma = this._originalGamma;
            }
        }

        if (scaledGamma) {
            // Re-evaluate the winner under the original gamma so that _lb
            // reports its value for the unmodified problem.
            double pow = Math.pow(10.0d, (-FSCParams.SIGNIFICANT_DIGITS) - 1);
            if (d > pow) {
                d = pow;
            }
            this._bound.solve(this._bestController, d / 2.0d, 0, Double.POSITIVE_INFINITY, false);
            this._lb = this._bound.getBound(this._pomdp._initBelief);
        }
        if (i > 0) {
            System.out.println("Final lower bound is: " + this._lb);
        }
        return this._bestController;
    }

    /**
     * Assigns an action to node {@code i} and recurses to the next node; once
     * all nodes have an action, continues with the edge assignment.
     *
     * <p>Node {@code i} only tries actions starting at the action chosen for the
     * previous node, so the action sequence over nodes is non-decreasing.
     *
     * @param i index of the node whose action is assigned at this level
     * @param d bound value obtained at the parent search node; used together with
     *          {@link #_lb} to derive the solver tolerance
     * @throws Exception propagated from the bound computation
     */
    protected void bbdfs_actions(int i, double d) throws Exception {
        if (i >= this._numNodes) {
            // Actions are non-decreasing over nodes, so first == last means
            // every node carries the same action.
            if (this._pruneOneActionControllers && this._currController._node2action[0] == this._currController._node2action[this._numNodes - 1]) {
                return;
            }
            bbdfs_edges(0, 0, d);
            return;
        }
        final int firstAllowedAction = this._tmpNextAvailableAction;
        for (int action = firstAllowedAction; action < this._pomdp._numA; action++) {
            // Deeper nodes must not pick an action below this one. The field
            // equals firstAllowedAction at the top of every iteration, so this
            // is the same as raising it only when 'action' is larger.
            this._tmpNextAvailableAction = action;
            this._currController._node2action[i] = action;
            this._bound.solve(this._currController, MDPUtils.tolerance(d, this._lb, this._pomdp._gamma, FSCParams.SIGNIFICANT_DIGITS), 0, Double.POSITIVE_INFINITY, false);
            this._tmpStats._nHeurEvalsOfStates++;
            double bound = this._bound.getBound(this._pomdp._initBelief);
            if (bound > this._lb) {
                bbdfs_actions(i + 1, bound);
            }
            // Undo this level's choices before trying the next action.
            this._tmpNextAvailableAction = firstAllowedAction;
            this._currController._node2action[i] = -1;
        }
    }

    /**
     * Assigns a successor node to the edge (node {@code i}, observation
     * {@code i2}) and recurses over the remaining edges in node-major order.
     * When the last edge has been assigned and the bound still exceeds
     * {@link #_lb}, the current controller becomes the new best controller.
     *
     * @param i  index of the node whose outgoing edge is assigned
     * @param i2 index of the observation labelling that edge
     * @param d  bound value obtained at the parent search node
     * @throws Exception propagated from the bound computation or controller output
     */
    protected void bbdfs_edges(int i, int i2, double d) throws Exception {
        final boolean lastObservation = i2 == this._pomdp._numObs - 1;
        final boolean lastNode = i == this._numNodes - 1;
        for (int target = 0; target < this._numNodes; target++) {
            this._currController._node2obs2node[i][i2] = target;
            double tolerance = MDPUtils.tolerance(d, this._lb, this._pomdp._gamma, FSCParams.SIGNIFICANT_DIGITS);
            this._bound.solve(this._currController, tolerance, 0, Double.POSITIVE_INFINITY, false);
            this._tmpStats._nHeurEvalsOfStates++;
            double bound = this._bound.getBound(this._pomdp._initBelief);
            if (bound <= this._lb) {
                // Cannot beat the incumbent: prune this successor choice.
                continue;
            }
            if (!lastObservation) {
                bbdfs_edges(i, i2 + 1, bound);
            } else if (lastNode) {
                // Every action and edge is assigned: record the improvement.
                System.out.println("The lower bound improved (tolerance=" + tolerance + ") from " + this._lb + " to " + bound + " after => " + this._tmpStats._nHeurEvalsOfStates + " <= evals, n=" + this._currController.canPartialBeFullySpecified().second);
                this._lb = bound;
                this._bestController.reload(this._currController);
                this._bestController.print_dot(jPOMDPRuntimeConfig._resultsPrefix + "_lastImpr", this._pomdp._observationToName, this._pomdp._actionToName);
                this._tmpStats._nImproved++;
            } else {
                bbdfs_edges(i + 1, 0, bound);
            }
        }
        // Mark the edge as unassigned again before returning to the caller.
        this._currController._node2obs2node[i][i2] = -1;
    }

    /**
     * Small driver: loads a POMDP, searches for a 5-node controller with the
     * QMDP bound and prints timing and search statistics.
     *
     * @param strArr optional; {@code strArr[0]} is the path of the POMDP file to load.
     *               Without arguments the original built-in path is used.
     * @throws Exception propagated from loading the model or from the search
     */
    public static void main(String[] strArr) throws Exception {
        long currentTimeMillis = System.currentTimeMillis();
        System.out.println("\nStarted...\n");
        String pomdpPath = (strArr != null && strArr.length > 0)
                ? strArr[0]
                : "/home/mgrzes/_data/Cassandra_POMDPs/tiger.95.POMDP";
        POMDPFlatMTJ pOMDPFlatMTJ = new POMDPFlatMTJ(pomdpPath, 0);
        System.out.println(pOMDPFlatMTJ.toString());
        SearchSnfMTJ searchSnfMTJ = new SearchSnfMTJ(-1.0d, pOMDPFlatMTJ, 5, FSCHeuristicType.qmdp);
        searchSnfMTJ._pruneOneActionControllers = true;
        CHeurSearchStats cHeurSearchStats = new CHeurSearchStats();
        FSController fSController = new FSController(5, pOMDPFlatMTJ._numObs);
        fSController.randomInit(pOMDPFlatMTJ._numA);
        searchSnfMTJ.search(fSController, 1, cHeurSearchStats).print_dot("/tmp/bbdfs_controller", pOMDPFlatMTJ._observationToName, pOMDPFlatMTJ._actionToName);
        System.out.println("Total time: " + ((System.currentTimeMillis() - currentTimeMillis) / 1000.0d) + " seconds.");
        System.out.println("Num Heuristic Evaluations: " + cHeurSearchStats._nHeurEvalsOfStates);
        System.out.println("Num Solutions Improved: " + cHeurSearchStats._nImproved);
        System.out.println("Num of All States without any symmetry reduction: " + cHeurSearchStats._nStates.toString());
    }
}
