package mgjpomdp.solve.fsc;

import mgjcommon.Pair;
import mgjpomdp.common.POMDPFlatMTJ;
import mgjpomdp.common.jPOMDPRuntimeConfig;
import mgjpomdp.solve.GapMin;
import mgjpomdp.solve.bounds.FIBMTJ;
import mgjpomdp.solve.bounds.LBData;
import mgjpomdp.solve.bounds.UBData;
import mgjpomdp.solve.bounds.UpperBound;
import mgjpomdp.solve.fsc.IFSCBoundMTJ;
import mgjpomdp.solve.pbvi;
import no.uib.cipr.matrix.sparse.SparseVector;

/* loaded from: input_file:mgjpomdp/solve/fsc/FSCBoundFS.class */
public class FSCBoundFS implements IFSCBoundMTJ {
    public POMDPFlatMTJ _pomdp;
    public int _numA;
    public int _numObs;
    public int _numContrNodes;
    public UBData _gmUBData;
    public LBData _gmLBData;
    public double[][] _lbAlphas;
    public double[][] _ubLongAlphas;
    public Node _root;
    SparseVector _tmpNextBelief;

    /* loaded from: input_file:mgjpomdp/solve/fsc/FSCBoundFS$ContrAssignments.class */
    /**
     * One link of a backwards-linked chain of controller assignments.
     *
     * <p>NOTE(review): the field names suggest (action, node, observation, successor node),
     * but nothing in this file reads them - confirm against the callers.
     */
    public class ContrAssignments {
        /** Previous link of the chain, or {@code null} for the first link. */
        public ContrAssignments _prev;
        /** Defaults to -1 until the constructor overwrites it. */
        public int _a = -1;
        public int _n;
        public int _o;
        /** Defaults to -1 until the constructor overwrites it. */
        public int _np = -1;

        /**
         * @param i                stored in {@code _a}
         * @param i2               stored in {@code _n}
         * @param i3               stored in {@code _o}
         * @param i4               stored in {@code _np}
         * @param contrAssignments previous link, stored in {@code _prev}; may be {@code null}
         */
        ContrAssignments(int i, int i2, int i3, int i4, ContrAssignments contrAssignments) {
            _prev = contrAssignments;
            _a = i;
            _n = i2;
            _o = i3;
            _np = i4;
        }
    }

    /* loaded from: input_file:mgjpomdp/solve/fsc/FSCBoundFS$Node.class */
    /**
     * Vertex of the forward-search tree built by the enclosing bound.
     *
     * <p>{@code max} vertices carry a belief and take the maximum over their children;
     * {@code and} vertices represent one action and combine one child per observation.
     */
    public class Node {
        /** Kind of aggregation performed at this vertex. */
        public NodeType _nodeType;
        /** Children; entries may be {@code null}. */
        public Node[] _children;
        /** For {@code and} vertices: weight of each observation child, indexed like {@code _children}. */
        public double[] _child2prob;
        /** For {@code and} vertices: the action this vertex stands for. */
        public int _action;
        /** Upper bound currently stored for this vertex. */
        public double _ub;
        /** Lower bound currently stored for this vertex. */
        public double _lb;
        /** Parent vertex, {@code null} at the root. */
        public Node _parent;
        /** Belief attached to the vertex (only set by the three-argument constructor). */
        public SparseVector _b;
        /** Fringe vertex below (or equal to) this one that the search would expand next. */
        public Node _bestFringeChild;
        /** Depth used to discount the expansion priority. */
        public int _depth;
        /** Probability weight of reaching this vertex. */
        public double _prob;
        /** Priority of {@code _bestFringeChild}. */
        public NodePriority _priority;
        /** {@code true} while the vertex has not been expanded. */
        public boolean _leaf;
        /** Index of the controller node this vertex is evaluated for. */
        public int _contrNode;
        /** Not read or written anywhere in this file. */
        public ContrAssignments _assignment;

        /**
         * Creates a vertex that carries a belief.
         *
         * @param sparseVector belief stored in {@code _b} (not copied)
         * @param node         parent vertex, may be {@code null}
         * @param nodeType     kind of the vertex
         */
        Node(SparseVector sparseVector, Node node, NodeType nodeType) {
            this(node, nodeType);
            this._b = sparseVector;
        }

        /**
         * Creates a vertex without a belief.
         *
         * @param node     parent vertex, may be {@code null}
         * @param nodeType kind of the vertex
         */
        Node(Node node, NodeType nodeType) {
            this._nodeType = nodeType;
            this._parent = node;
        }
    }

    /* loaded from: input_file:mgjpomdp/solve/fsc/FSCBoundFS$NodePriority.class */
    /**
     * Expansion priority of a fringe vertex. Three figures are stored, but ordering is
     * decided by {@code _prob} alone.
     */
    public class NodePriority {
        public double _ubReduction;
        public double _lbReduction;
        public double _prob;

        /**
         * @param d  stored in {@code _ubReduction}
         * @param d2 stored in {@code _lbReduction}
         * @param d3 stored in {@code _prob}
         */
        NodePriority(double d, double d2, double d3) {
            _prob = d3;
            _lbReduction = d2;
            _ubReduction = d;
        }

        /**
         * Orders two priorities by {@code _prob} only.
         *
         * @param nodePriority priority to compare against; must not be {@code null}
         * @return 1 if this {@code _prob} is larger, -1 if it is smaller, 0 otherwise
         *         (including when either value is NaN)
         * @throws NullPointerException if {@code nodePriority} is {@code null}
         */
        int compareTo(NodePriority nodePriority) throws NullPointerException {
            final double theirs = nodePriority._prob;
            final double mine = this._prob;
            if (mine > theirs) {
                return 1;
            }
            if (mine < theirs) {
                return -1;
            }
            return 0;
        }
    }

    /* JADX INFO: Access modifiers changed from: package-private */
    /* loaded from: input_file:mgjpomdp/solve/fsc/FSCBoundFS$NodeType.class */
    /** Kind of aggregation a search-tree vertex performs over its children. */
    public enum NodeType {
        /** Vertex whose bounds are the maximum of its children's bounds. */
        max,
        /** Action vertex: immediate reward plus the discounted, probability-weighted sum over observation children. */
        and
    }

    /**
     * Binds the search to a POMDP and caches its action and observation counts.
     * The bound data ({@code _gmUBData}, {@code _lbAlphas}, {@code _ubLongAlphas}, ...)
     * is not set here and has to be assigned by the caller before solving.
     *
     * @param pOMDPFlatMTJ model to search over; must not be {@code null}
     */
    FSCBoundFS(POMDPFlatMTJ pOMDPFlatMTJ) {
        _pomdp = pOMDPFlatMTJ;
        _numA = _pomdp._numA;
        _numObs = _pomdp._numObs;
        // Scratch vector reused for every successor-belief computation.
        _tmpNextBelief = new SparseVector(_pomdp._numS);
    }

    /**
     * Not implemented by this bound.
     *
     * @return always {@code null}
     */
    @Override // mgjpomdp.solve.fsc.IFSCBoundMTJ
    public IFSCBoundMTJ.PolicyState getVmaxPolicyState() {
        return null;
    }

    /**
     * Not implemented by this bound.
     *
     * @return always {@code null}
     */
    @Override // mgjpomdp.solve.fsc.IFSCBoundMTJ
    public IFSCBoundMTJ.PolicyState getVminPolicyState() {
        return null;
    }

    /**
     * Not implemented by this bound.
     *
     * @return always {@code null}
     */
    @Override // mgjpomdp.solve.fsc.IFSCBoundMTJ
    public IFSCBoundMTJ.PolicyState getCopyOfPolicyState() {
        return null;
    }

    /**
     * Not implemented by this bound.
     *
     * @return always {@code null}
     */
    @Override // mgjpomdp.solve.fsc.IFSCBoundMTJ
    public IFSCBoundMTJ.PolicyState getCopyOfPolicyStateVOnly() {
        return null;
    }

    /**
     * Not implemented by this bound: the call does nothing.
     *
     * @param policyState ignored
     */
    @Override // mgjpomdp.solve.fsc.IFSCBoundMTJ
    public void setInitFromPolicyState(IFSCBoundMTJ.PolicyState policyState) throws Exception {
    }

    /**
     * Not implemented by this bound: nothing is printed.
     *
     * @param policyState ignored
     */
    @Override // mgjpomdp.solve.fsc.IFSCBoundMTJ
    public void printPolicyState(IFSCBoundMTJ.PolicyState policyState) {
    }

    /**
     * Interface variant of {@code solve}; not implemented, the call does nothing.
     * The forward search itself is run by the seven-argument overload of this class.
     * All parameters are ignored.
     */
    @Override // mgjpomdp.solve.fsc.IFSCBoundMTJ
    public void solve(FSController fSController, double d, int i, double d2, boolean z) throws Exception {
    }

    /**
     * Runs the forward search: repeatedly expands the best fringe vertex of the tree rooted
     * at the initial belief and propagates the new bounds back to the root.
     *
     * <p>The root is created on the first call and reused afterwards, so later calls keep
     * refining the same tree. The search ends when {@link #stop} reports termination or
     * after 1000 expansions, whichever comes first.
     *
     * @param fSController controller whose assigned actions and edges constrain the search
     * @param i            forwarded to {@code expandNode}: largest controller node index tried
     *                     for observation edges the controller leaves unassigned
     * @param d            forwarded to {@code stop}: probability threshold for the best fringe vertex
     * @param i2           verbosity; progress is printed to {@code System.out} when positive
     * @param d2           forwarded to {@code stop}: upper-bound threshold
     * @param z            forwarded to {@code stop}: enables the upper-bound threshold test
     * @param z2           forwarded to {@code stop}
     * @throws Exception if the tree's best-fringe bookkeeping is inconsistent or a bound
     *                   computation fails
     */
    public void solve(FSController fSController, int i, double d, int i2, double d2, boolean z, boolean z2) throws Exception {
        final int maxExpansions = 1000;
        final boolean verbose = i2 > 0;

        this._numContrNodes = fSController.getNumNodes();
        if (this._root == null) {
            this._root = createLeaf(fSController, this._pomdp._initBelief, 0, null, 0, 1.0d);
        }
        if (verbose) {
            System.out.println("current UB(b0)=" + this._root._ub + ", current LB(b0)=" + this._root._lb + " after iterations=0");
        }

        // Keep expanding until the stopping rule fires or the expansion budget is spent.
        int expansions = 0;
        while (expansions < maxExpansions && !stop(d, d2, z, z2)) {
            Node fringe = findBestFringeNode(this._root);
            expandNode(fringe, fSController, i);
            backPropagate(fringe);
            expansions++;
        }

        if (verbose) {
            System.out.println("num iterations " + expansions);
        }
    }

    /**
     * Expands a fringe vertex: creates one action ({@code and}) child for the action the
     * controller prescribes at the vertex's controller node, or one per action when no
     * action is prescribed ({@code -1}), and then recomputes the vertex's own bounds,
     * priority and best fringe descendant from those children.
     *
     * @param node         vertex to expand; its {@code _children} array is replaced
     * @param fSController controller supplying the prescribed actions and edges
     * @param i            forwarded to {@code expandNodesAO}
     * @throws Exception propagated from {@code expandNodesAO}
     */
    protected void expandNode(Node node, FSController fSController, int i) throws Exception {
        node._leaf = false;
        node._children = new Node[this._numA];

        final int prescribed = fSController._node2action[node._contrNode];
        if (prescribed == -1) {
            for (int action = 0; action < this._numA; action++) {
                expandAction(node, fSController, i, action);
            }
        } else {
            expandAction(node, fSController, i, prescribed);
        }

        // Max aggregation: the fringe bookkeeping follows the child with the largest upper bound.
        node._ub = Double.NEGATIVE_INFINITY;
        node._lb = Double.NEGATIVE_INFINITY;
        for (Node actionNode : node._children) {
            if (actionNode == null) {
                continue;
            }
            if (actionNode._ub > node._ub) {
                node._ub = actionNode._ub;
                node._bestFringeChild = actionNode._bestFringeChild;
                node._priority = actionNode._priority;
            }
            if (actionNode._lb > node._lb) {
                node._lb = actionNode._lb;
            }
        }
    }

    /** Builds the action child of {@code node} for {@code action}, expands its observations and evaluates it. */
    private void expandAction(Node node, FSController fSController, int i, int action) throws Exception {
        final Node andNode = new Node(node, NodeType.and);
        andNode._children = new Node[this._numObs];
        andNode._child2prob = new double[this._numObs];
        andNode._action = action;
        andNode._prob = node._prob;
        node._children[action] = andNode;

        for (int obs = 0; obs < this._numObs; obs++) {
            expandNodesAO(node, fSController, i, action, obs, andNode);
        }

        // Both bounds start from the immediate reward of the action at the parent belief.
        final double immediate = andNode._parent._b.dot(this._pomdp._R[andNode._action]);
        andNode._priority = null;
        andNode._ub = immediate;
        andNode._lb = immediate;
        for (int obs = 0; obs < this._numObs; obs++) {
            final Node successor = andNode._children[obs];
            if (successor == null) {
                continue;
            }
            andNode._ub += this._pomdp._gamma * andNode._child2prob[obs] * successor._ub;
            andNode._lb += this._pomdp._gamma * andNode._child2prob[obs] * successor._lb;
            // Track the observation child holding the highest-priority fringe vertex.
            if (andNode._priority == null || successor._priority.compareTo(andNode._priority) > 0) {
                andNode._priority = successor._priority;
                andNode._bestFringeChild = successor._bestFringeChild;
            }
        }
    }

    /**
     * Creates the child of an action vertex for one observation.
     *
     * <p>The successor belief is computed from {@code node._b}; its (unnormalised) mass is
     * recorded in {@code node2._child2prob[i3]}. When that mass does not exceed the sparsity
     * threshold no child is created. Otherwise:
     * <ul>
     *   <li>if the controller assigns a successor node to this observation, a single leaf for
     *       that controller node is attached;</li>
     *   <li>if the edge is unassigned ({@code -1}), an intermediate {@code max} vertex is
     *       attached with one leaf per controller node {@code 0..i}.</li>
     * </ul>
     *
     * <p>Leaves of the intermediate vertex are parented to that vertex, so that walking
     * {@code _parent} links from a leaf passes through it. Depths are derived from
     * {@code node._depth}, because the action vertex's own {@code _depth} is not maintained.
     *
     * @param node         belief vertex being expanded
     * @param fSController controller supplying the observation edges
     * @param i            largest controller node index tried for an unassigned edge
     * @param i2           action index
     * @param i3           observation index
     * @param node2        action vertex receiving the child; its {@code _children} and
     *                     {@code _child2prob} arrays must already be allocated
     * @throws Exception propagated from {@code createLeaf}
     */
    protected void expandNodesAO(Node node, FSController fSController, int i, int i2, int i3, Node node2) throws Exception {
        this._pomdp._TO[i2][i3].transMult(node._b, this._tmpNextBelief);
        final double sum = this._tmpNextBelief.sum();
        node2._child2prob[i3] = sum;
        if (!(sum > jPOMDPRuntimeConfig._SPARSITYTHRESHOLD)) {
            return;
        }
        this._tmpNextBelief.scale(1.0d / sum);

        final int successor = fSController._node2obs2node[node._contrNode][i3];
        if (successor != -1) {
            node2._children[i3] = createLeaf(fSController, this._tmpNextBelief.copy(), successor, node2, node._depth + 1, node._prob * sum);
            return;
        }

        // Unassigned edge: let the search choose among controller nodes 0..i.
        final Node choice = new Node(this._tmpNextBelief.copy(), node2, NodeType.max);
        // node2._depth is never assigned, so take the depth from the belief vertex instead;
        // the leaves below end up at node._depth + 1, exactly like on the assigned path.
        choice._depth = node._depth;
        choice._prob = node2._prob;
        choice._children = new Node[i + 1];
        for (int contrNode = 0; contrNode <= i; contrNode++) {
            // Parent is the intermediate vertex (not node2) so back-propagation refreshes it.
            choice._children[contrNode] = createLeaf(fSController, this._tmpNextBelief.copy(), contrNode, choice, choice._depth + 1, choice._prob * sum);
        }
        node2._children[i3] = choice;

        choice._ub = Double.NEGATIVE_INFINITY;
        choice._lb = Double.NEGATIVE_INFINITY;
        for (Node leaf : choice._children) {
            if (leaf == null) {
                continue;
            }
            if (leaf._ub > choice._ub) {
                choice._ub = leaf._ub;
                choice._bestFringeChild = leaf._bestFringeChild;
                choice._priority = leaf._priority;
            }
            if (leaf._lb > choice._lb) {
                choice._lb = leaf._lb;
            }
        }
    }

    /**
     * Creates an unexpanded {@code max} vertex for a belief and a controller node and
     * initialises its bounds from the pre-computed bound data.
     *
     * <p>When the controller prescribes an action at controller node {@code i}, the bounds are
     * the values of the vectors indexed by that action; otherwise they are the values returned
     * by {@code pbvi.evaluateBeliefPoint} over all vectors. The priority stores the difference
     * between the two, weighted by {@code d * gamma^i2}.
     *
     * @param fSController controller supplying the prescribed action
     * @param sparseVector belief of the new vertex (stored, not copied)
     * @param i            controller node index the vertex is evaluated for
     * @param node         parent vertex, {@code null} for the root
     * @param i2           depth of the vertex
     * @param d            probability weight of the vertex
     * @return the new leaf, which is its own best fringe vertex
     * @throws Exception propagated from the bound computations
     */
    protected Node createLeaf(FSController fSController, SparseVector sparseVector, int i, Node node, int i2, double d) throws Exception {
        final Node leaf = new Node(sparseVector, node, NodeType.max);
        leaf._leaf = true;
        leaf._depth = i2;
        leaf._prob = d;

        // Belief representation used with the "long" upper-bound vectors.
        final SparseVector ubBelief = UpperBound.getBound(leaf._b, this._gmUBData._beliefs, this._gmUBData._values, this._gmUBData._alphaVectors).second;
        final double unconstrainedUb = pbvi.evaluateBeliefPoint(ubBelief, this._ubLongAlphas);
        final double unconstrainedLb = pbvi.evaluateBeliefPoint(leaf._b, this._lbAlphas);

        if (fSController._node2action[i] == -1) {
            leaf._ub = unconstrainedUb;
            leaf._lb = unconstrainedLb;
        } else {
            final int action = fSController._node2action[i];
            leaf._ub = ubBelief.dot(this._ubLongAlphas[action]);
            leaf._lb = leaf._b.dot(this._lbAlphas[action]);
        }

        leaf._bestFringeChild = leaf;
        leaf._contrNode = i;

        final double weight = leaf._prob * Math.pow(this._pomdp._gamma, leaf._depth);
        final double ubReduction = weight * (unconstrainedUb - leaf._ub);
        final double lbReduction = weight * (unconstrainedLb - leaf._lb);
        leaf._priority = new NodePriority(ubReduction, lbReduction, weight);
        return leaf;
    }

    /**
     * Recomputes bounds, priority and best fringe descendant of every ancestor of
     * {@code node}, following {@code _parent} links up to the root. {@code node} itself is
     * not recomputed.
     *
     * <p>Action ({@code and}) vertices take the immediate reward plus the discounted,
     * probability-weighted child bounds, and adopt the fringe vertex of the child with the
     * highest priority among <em>all</em> their children. {@code max} vertices take the
     * largest child bounds and follow the child with the largest upper bound.
     *
     * @param node vertex whose bounds have just changed
     */
    protected void backPropagate(Node node) {
        Node cameFrom = node;
        for (Node current = node._parent; current != null; current = current._parent) {
            if (current._nodeType == NodeType.and) {
                refreshSkippedMaxChild(current, cameFrom);
                refreshAndNode(current);
            } else if (current._nodeType == NodeType.max) {
                refreshMaxNode(current);
            }
            cameFrom = current;
        }
    }

    /**
     * A vertex reached through {@code _parent} need not be a direct child: a leaf that sits
     * below an intermediate max vertex may point straight at the action vertex. If
     * {@code cameFrom} is such a grandchild, the intermediate vertex is refreshed first so
     * the action vertex does not aggregate stale bounds or a stale fringe pointer.
     */
    private void refreshSkippedMaxChild(Node andNode, Node cameFrom) {
        for (Node child : andNode._children) {
            if (child == cameFrom) {
                return;
            }
        }
        for (Node child : andNode._children) {
            if (child == null || child._nodeType != NodeType.max || child._children == null) {
                continue;
            }
            for (Node grandChild : child._children) {
                if (grandChild == cameFrom) {
                    refreshMaxNode(child);
                    return;
                }
            }
        }
    }

    /** Re-evaluates an action vertex from its parent's belief and its observation children. */
    private void refreshAndNode(Node andNode) {
        final double immediate = andNode._parent._b.dot(this._pomdp._R[andNode._action]);
        andNode._ub = immediate;
        andNode._lb = immediate;

        NodePriority bestPriority = null;
        Node bestFringe = null;
        for (int obs = 0; obs < andNode._children.length; obs++) {
            final Node child = andNode._children[obs];
            if (child == null) {
                continue;
            }
            andNode._ub += this._pomdp._gamma * andNode._child2prob[obs] * child._ub;
            andNode._lb += this._pomdp._gamma * andNode._child2prob[obs] * child._lb;
            // Compare every child, not only the one the update came from.
            if (child._priority != null && (bestPriority == null || child._priority.compareTo(bestPriority) > 0)) {
                bestPriority = child._priority;
                bestFringe = child._bestFringeChild;
            }
        }
        if (bestPriority != null) {
            andNode._priority = bestPriority;
            andNode._bestFringeChild = bestFringe;
        }
    }

    /** Re-evaluates a max vertex; fringe bookkeeping follows the child with the largest upper bound. */
    private void refreshMaxNode(Node maxNode) {
        maxNode._ub = Double.NEGATIVE_INFINITY;
        maxNode._lb = Double.NEGATIVE_INFINITY;
        for (Node child : maxNode._children) {
            if (child == null) {
                continue;
            }
            if (child._ub > maxNode._ub) {
                maxNode._ub = child._ub;
                maxNode._bestFringeChild = child._bestFringeChild;
                maxNode._priority = child._priority;
            }
            if (child._lb > maxNode._lb) {
                maxNode._lb = child._lb;
            }
        }
    }

    /**
     * Decides whether the search should terminate.
     *
     * @param d  probability threshold: the search stops once the root's best fringe vertex
     *           has a {@code _prob} that is not at least this value
     * @param d2 upper-bound threshold, only consulted when {@code z} is set
     * @param z  when {@code true}, the search also stops as soon as the root's upper bound
     *           is at most {@code d2}
     * @param z2 not used
     * @return {@code true} if the search should stop
     */
    protected boolean stop(double d, double d2, boolean z, boolean z2) {
        final boolean fringeStillLikely = this._root._bestFringeChild._prob >= d;
        if (!fringeStillLikely) {
            return true;
        }
        return z && this._root._ub <= d2;
    }

    /**
     * Returns the fringe vertex recorded as the best expansion candidate below {@code node}.
     *
     * @param node vertex whose {@code _bestFringeChild} is looked up (the root, in this class)
     * @return the recorded vertex, which is flagged as a leaf and has no children
     * @throws Exception if the recorded vertex has already been expanded
     */
    protected Node findBestFringeNode(Node node) throws Exception {
        final Node candidate = node._bestFringeChild;
        if (!candidate._leaf || candidate._children != null) {
            throw new Exception("this is not a fringe node, the best fringe node is wrong in the root " + node);
        }
        return candidate;
    }

    /**
     * Not implemented by this bound; all parameters are ignored.
     *
     * @return always {@code false}
     */
    @Override // mgjpomdp.solve.fsc.IFSCBoundMTJ
    public boolean solveLPInit(FSController fSController, double d, int i, double d2) throws Exception {
        return false;
    }

    /** Not implemented by this bound: the call does nothing and all parameters are ignored. */
    @Override // mgjpomdp.solve.fsc.IFSCBoundMTJ
    public void solvePS(FSController fSController, double d, int i, double d2, boolean z, int i2) throws Exception {
    }

    /** Not implemented by this bound: the call does nothing and all parameters are ignored. */
    @Override // mgjpomdp.solve.fsc.IFSCBoundMTJ
    public void evaluate(FSController fSController, double d, int i) {
    }

    /**
     * Returns the upper bound stored at the root of the search tree.
     *
     * @param sparseVector ignored; the value always refers to the root's belief
     * @return {@code _root._ub}
     * @throws NullPointerException if {@code _root} is still {@code null}
     */
    @Override // mgjpomdp.solve.fsc.IFSCBoundMTJ
    public double getBound(SparseVector sparseVector) {
        return this._root._ub;
    }

    /**
     * Not implemented by this bound.
     *
     * @param sparseVector ignored
     * @return always {@code 0}
     */
    @Override // mgjpomdp.solve.fsc.IFSCBoundMTJ
    public int getMaxBoundNode(SparseVector sparseVector) {
        return 0;
    }

    /**
     * Delegates to {@link #getBound(SparseVector)}.
     *
     * @param sparseVector passed through to the one-argument overload
     * @param i            ignored
     * @return the value of the one-argument overload
     */
    @Override // mgjpomdp.solve.fsc.IFSCBoundMTJ
    public double getBound(SparseVector sparseVector, int i) {
        return getBound(sparseVector);
    }

    /**
     * Not implemented by this bound; all parameters are ignored.
     *
     * @return always {@code 0}
     */
    @Override // mgjpomdp.solve.fsc.IFSCBoundMTJ
    public int getBestAction(SparseVector sparseVector, FSController fSController, int i) throws Exception {
        return 0;
    }

    /**
     * Not implemented by this bound; all parameters are ignored.
     *
     * @return always {@code null}
     */
    @Override // mgjpomdp.solve.fsc.IFSCBoundMTJ
    public FSController greedyController(int i, double d) throws Exception {
        return null;
    }

    /**
     * Not implemented by this bound.
     *
     * @param fSController ignored
     * @return always {@code false}
     */
    @Override // mgjpomdp.solve.fsc.IFSCBoundMTJ
    public boolean canBeImproved(FSController fSController) {
        return false;
    }

    /**
     * Not implemented by this bound: the flag is ignored.
     *
     * @param z ignored
     */
    @Override // mgjpomdp.solve.fsc.IFSCBoundMTJ
    public void setViPruneLB(boolean z) {
    }

    /**
     * Ad-hoc driver: loads a POMDP from a hard-coded path, computes bound data with GapMin
     * and FIB, builds a small partially specified controller, and prints the upper bounds
     * (and wall-clock seconds) obtained by {@code FSCBoundFIBMTJ} and by this forward search.
     *
     * @param strArr ignored
     * @throws Exception if loading, solving or launching the external viewer fails
     */
    public static void main(String[] strArr) throws Exception {
        final POMDPFlatMTJ pomdp = new POMDPFlatMTJ("/home/mgrzes/_data/Cassandra_POMDPs/tiger.95.POMDP", 0);
        if (pomdp._gamma == 1.0d) {
            // Undiscounted models are nudged to a discount slightly below one.
            pomdp._gamma = 0.999d;
            System.err.println("pomdp._gamma set to " + pomdp._gamma);
        }

        final GapMin gapMin = new GapMin(pomdp);
        final Pair<LBData, UBData> gapMinBounds = gapMin.solve(0);

        final FSCBoundFS searchBound = new FSCBoundFS(pomdp);
        searchBound._gmLBData = gapMinBounds.first;
        searchBound._gmUBData = gapMinBounds.second;
        searchBound._lbAlphas = gapMinBounds.first._alphaVectors;
        searchBound._ubLongAlphas = gapMin._lastLongAlphas;

        final POMDPFlatMTJ fibPomdp = gapMin._fib_pomdp;
        final FSCBoundFIBMTJ fibBound = new FSCBoundFIBMTJ(fibPomdp, 5);
        final FIBMTJ fib = new FIBMTJ();
        fib.solve(fibPomdp, 1.0E-6d, 0);
        // The FIB Q-vectors replace the GapMin vectors assigned above.
        searchBound._ubLongAlphas = fib._Q;

        // Five-node controller; only nodes 0..3 get an action and only three edges are fixed.
        final FSController controller = new FSController(5, 2);
        final int[] prescribedActions = {0, 0, 1, 0};
        for (int n = 0; n < prescribedActions.length; n++) {
            controller._node2action[n] = prescribedActions[n];
        }
        controller._node2obs2node[0][0] = 1;
        controller._node2obs2node[0][1] = 2;
        controller._node2obs2node[1][0] = 3;
        controller.print_dot("/tmp/search", pomdp._observationToName, pomdp._actionToName);
        Runtime.getRuntime().exec(new String[]{"/usr/bin/display", "/tmp/search.svg"});

        final long fibStart = System.currentTimeMillis();
        fibBound.solve(controller, 0.001d, 0, Double.POSITIVE_INFINITY, false);
        final double fibUb = fibBound.getBound(fibPomdp._initBelief, 0);
        System.out.println((System.currentTimeMillis() - fibStart) / 1000.0d);

        final long searchStart = System.currentTimeMillis();
        searchBound.solve(controller, 4, 0.001d, 1, Double.POSITIVE_INFINITY, false, false);
        final double searchUb = searchBound.getBound(pomdp._initBelief, 0);
        System.out.println((System.currentTimeMillis() - searchStart) / 1000.0d);

        System.out.println("fibUB = " + fibUb + ", fsUB = " + searchUb);
    }
}
