package libpomdp.solve.online;

import com.jgoodies.forms.layout.FormSpec;
import java.io.FileNotFoundException;
import java.io.PrintStream;
import libpomdp.common.BeliefState;
import libpomdp.common.CustomVector;
import libpomdp.common.Pomdp;
import libpomdp.common.Utils;
import libpomdp.common.ValueFunction;
import org.math.array.DoubleArray;

/* loaded from: input_file:libpomdp/solve/online/AndOrTree.class */
/**
 * AND-OR search tree for online POMDP planning whose expansion is guided by an
 * {@link ExpandHeuristic}.
 *
 * <p>OR nodes hold a belief state together with a lower bound {@code l} and an upper bound
 * {@code u}; AND nodes are created per action under an OR node and carry the immediate reward
 * {@code rba}. Fringe bounds are read from the bound value functions exposed by the superclass
 * ({@code getLB()} / {@code getUB()}) and are backed up toward the root by {@link #expand} and
 * {@link #updateAncestors}.
 */
public class AndOrTree extends AbstractAndOrTree {
    /** Heuristic supplying h_b, h_ba, h_bao, aStar, oStar, hANDStar and hORStar during expansion. */
    protected ExpandHeuristic expH;

    /**
     * Creates a tree over {@code pomdp}.
     *
     * @param pomdp the problem model
     * @param heuristicSearchOrNode node handed to the superclass; presumably the root later
     *     returned by {@link #getRoot()} — confirm in AbstractAndOrTree
     * @param valueFunction first bound value function handed to the superclass (see getLB/getUB)
     * @param valueFunction2 second bound value function handed to the superclass (see getLB/getUB)
     * @param expandHeuristic heuristic used to score and select nodes while expanding
     */
    public AndOrTree(Pomdp pomdp, HeuristicSearchOrNode heuristicSearchOrNode, ValueFunction valueFunction, ValueFunction valueFunction2, ExpandHeuristic expandHeuristic) {
        super(pomdp, heuristicSearchOrNode, valueFunction, valueFunction2);
        this.expH = expandHeuristic;
    }

    /**
     * (Re)initialises the root at {@code beliefState} and seeds its bounds from the value functions.
     *
     * <p>The root is given observation index {@code -1} and a {@code null} parent, mirroring the
     * (belief, observation, parent AND node) arguments used for child OR nodes in {@link #expand}.
     *
     * @param beliefState belief to place at the root
     */
    public void init(BeliefState beliefState) {
        HeuristicSearchOrNode root = getRoot();
        root.init(beliefState, -1, null);
        root.u = getUB().V(beliefState);
        root.l = getLB().V(beliefState);
    }

    /**
     * Expands a fringe OR node by one action/observation layer and backs the results up into it.
     *
     * <p>For every action an AND child is created with its expected immediate reward; under it,
     * an OR grandchild is initialised for every observation with non-zero probability, with its
     * bounds taken from the value functions and its heuristic values from {@link #expH}. The AND
     * children and then the node itself are re-evaluated, and the change of each bound is stored
     * in {@code oneStepDeltaLower} / {@code oneStepDeltaUpper}. Ancestors are NOT updated here —
     * see {@link #updateAncestors}.
     *
     * <p>If the node already has children, a message is printed to stderr and nothing changes.
     *
     * @param heuristicSearchOrNode the fringe node to expand
     */
    public void expand(HeuristicSearchOrNode heuristicSearchOrNode) {
        if (heuristicSearchOrNode.getChildren() != null) {
            System.err.println("node not on fringe");
            return;
        }
        final double oldLower = heuristicSearchOrNode.l;
        final double oldUpper = heuristicSearchOrNode.u;
        final int nActions = getProblem().nrActions();
        final int nObservations = getProblem().nrObservations();

        heuristicSearchOrNode.initChildren(nActions);
        final BeliefState belief = heuristicSearchOrNode.getBeliefState();
        for (int action = 0; action < nActions; action++) {
            expandAction(heuristicSearchOrNode, belief, action, nObservations);
        }

        // Back the freshly computed children up into the expanded node.
        heuristicSearchOrNode.l = ORpropagateL(heuristicSearchOrNode);
        heuristicSearchOrNode.u = ORpropagateU(heuristicSearchOrNode);
        heuristicSearchOrNode.h_b = this.expH.h_b(heuristicSearchOrNode);
        heuristicSearchOrNode.h_ba = this.expH.h_ba(heuristicSearchOrNode);
        heuristicSearchOrNode.aStar = this.expH.aStar(heuristicSearchOrNode);
        heuristicSearchOrNode.hStar = this.expH.hORStar(heuristicSearchOrNode);
        heuristicSearchOrNode.bStar = heuristicSearchOrNode.getChild(heuristicSearchOrNode.aStar).bStar;
        heuristicSearchOrNode.oneStepDeltaLower = heuristicSearchOrNode.l - oldLower;
        heuristicSearchOrNode.oneStepDeltaUpper = heuristicSearchOrNode.u - oldUpper;
        // ORpropagateL never returns less than the previous l, so a negative delta is unexpected.
        if (heuristicSearchOrNode.oneStepDeltaLower < 0.0d) {
            System.err.println("Hmmmmmmmmmmm");
        }
    }

    /**
     * Initialises the AND child for {@code action} under {@code parent}, plus its OR children for
     * every observation with non-zero probability, and evaluates the AND child.
     */
    private void expandAction(HeuristicSearchOrNode parent, BeliefState belief, int action, int nObservations) {
        HeuristicSearchAndNode andNode = parent.getChild(action);
        andNode.init(action, parent, getProblem().expectedImmediateReward(belief, action));
        CustomVector obsProbs = getProblem().observationProbabilities(belief, action);
        andNode.initChildren(nObservations, obsProbs);
        for (int obs = 0; obs < nObservations; obs++) {
            // Impossible observations are skipped; the propagation code tolerates null children.
            if (obsProbs.get(obs) != 0.0d) {
                HeuristicSearchOrNode orNode = andNode.getChild(obs);
                orNode.init(getProblem().nextBeliefState(belief, action, obs), obs, andNode);
                orNode.getBeliefState().setPoba(obsProbs.get(obs));
                orNode.u = getUB().V(orNode.getBeliefState());
                orNode.l = getLB().V(orNode.getBeliefState());
                orNode.h_b = this.expH.h_b(orNode);
                orNode.h_bao = this.expH.h_bao(orNode);
                // A fresh fringe node is its own best descendant.
                orNode.hStar = orNode.h_b;
                orNode.bStar = orNode;
                parent.setSubTreeSize(parent.getSubTreeSize() + 1);
            }
        }
        andNode.l = ANDpropagateL(andNode);
        andNode.u = ANDpropagateU(andNode);
        andNode.oStar = this.expH.oStar(andNode);
        andNode.hStar = this.expH.hANDStar(andNode);
        andNode.bStar = andNode.getChild(andNode.oStar).bStar;
    }

    /**
     * Re-evaluates every AND/OR ancestor of an expanded node, walking up to the root.
     *
     * <p>Does nothing if {@code heuristicSearchOrNode} has no children (i.e. was not expanded).
     *
     * @param heuristicSearchOrNode the node that was just expanded
     */
    public void updateAncestors(HeuristicSearchOrNode heuristicSearchOrNode) {
        if (heuristicSearchOrNode.getChildren() == null) {
            return;
        }
        // NOTE(review): the expanded node's whole subtree size (including whatever it held before
        // expansion) is added to every ancestor; confirm this does not double-count the node itself.
        final int addedSize = heuristicSearchOrNode.getSubTreeSize();
        HeuristicSearchOrNode node = heuristicSearchOrNode;
        while (node != getRoot()) {
            HeuristicSearchAndNode andParent = node.getParent();
            andParent.l = ANDpropagateL(andParent);
            andParent.u = ANDpropagateU(andParent);
            andParent.oStar = this.expH.oStar(andParent);
            andParent.hStar = this.expH.hANDStar(andParent);
            andParent.bStar = andParent.getChild(andParent.oStar).bStar;

            HeuristicSearchOrNode orParent = andParent.getParent();
            orParent.l = ORpropagateL(orParent);
            orParent.u = ORpropagateU(orParent);
            orParent.h_ba = this.expH.h_ba(orParent);
            orParent.aStar = this.expH.aStar(orParent);
            // NOTE(review): expand() uses expH.hORStar() here and also refreshes h_b; this inline
            // form presumably matches hORStar — confirm.
            orParent.hStar = orParent.h_ba[orParent.aStar] * orParent.getChild(orParent.aStar).hStar;
            orParent.bStar = orParent.getChild(orParent.aStar).bStar;
            orParent.setSubTreeSize(orParent.getSubTreeSize() + addedSize);
            node = orParent;
        }
    }

    /**
     * Lower-bound backup for an AND node: {@code rba + gamma * sum_o P(o|b,a) * l(child_o)}.
     * Null (impossible-observation) children are skipped.
     */
    public double ANDpropagateL(HeuristicSearchAndNode heuristicSearchAndNode) {
        return backupAnd(heuristicSearchAndNode, false);
    }

    /**
     * Upper-bound backup for an AND node: {@code rba + gamma * sum_o P(o|b,a) * u(child_o)}.
     * Null (impossible-observation) children are skipped.
     */
    public double ANDpropagateU(HeuristicSearchAndNode heuristicSearchAndNode) {
        return backupAnd(heuristicSearchAndNode, true);
    }

    /** Shared body of the AND backups; {@code upper} selects {@code u} instead of {@code l}. */
    private double backupAnd(HeuristicSearchAndNode andNode, boolean upper) {
        double expected = 0.0d;
        for (HeuristicSearchOrNode child : andNode.getChildren()) {
            if (child != null) {
                expected += child.getBeliefState().getPoba() * (upper ? child.u : child.l);
            }
        }
        return andNode.rba + getProblem().getGamma() * expected;
    }

    /**
     * Lower-bound backup for an OR node: the best AND child's {@code l}, but never below the
     * node's current {@code l} (the bound only tightens).
     */
    public double ORpropagateL(HeuristicSearchOrNode heuristicSearchOrNode) {
        double best = Double.NEGATIVE_INFINITY;
        for (HeuristicSearchAndNode child : heuristicSearchOrNode.getChildren()) {
            if (child.l > best) {
                best = child.l;
            }
        }
        return Math.max(best, heuristicSearchOrNode.l);
    }

    /**
     * Upper-bound backup for an OR node: the best AND child's {@code u}, but never above the
     * node's current {@code u} (the bound only tightens).
     */
    public double ORpropagateU(HeuristicSearchOrNode heuristicSearchOrNode) {
        double best = Double.NEGATIVE_INFINITY;
        for (HeuristicSearchAndNode child : heuristicSearchOrNode.getChildren()) {
            if (child.u > best) {
                best = child.u;
            }
        }
        return Math.min(best, heuristicSearchOrNode.u);
    }

    /** Returns the root, narrowed to {@link HeuristicSearchOrNode}. */
    @Override // libpomdp.solve.online.AbstractAndOrTree
    public HeuristicSearchOrNode getRoot() {
        return (HeuristicSearchOrNode) super.getRoot();
    }

    /**
     * Returns the action whose root AND child has the largest lower bound (ties as resolved by
     * {@code Utils.argmax}). Assumes the root has been expanded.
     */
    public int currentBestAction() {
        double[] lowerByAction = new double[getProblem().nrActions()];
        for (HeuristicSearchAndNode child : getRoot().getChildren()) {
            lowerByAction[child.getAct()] = child.l;
        }
        return Utils.argmax(lowerByAction);
    }

    /**
     * Tests whether {@code i} can be committed to at the root.
     *
     * @param i candidate action index
     * @param d tolerance
     * @return {@code true} if the root's bound gap {@code |u - l|} is below {@code d}, or if no
     *     other action's upper bound exceeds the root's lower bound
     */
    public boolean actionIsEpsOptimal(int i, double d) {
        HeuristicSearchOrNode root = getRoot();
        if (Math.abs(root.u - root.l) < d) {
            return true;
        }
        for (HeuristicSearchAndNode other : root.getChildren()) {
            if (other.getAct() != i && root.l < other.u) {
                return false;
            }
        }
        return true;
    }

    /**
     * Writes the tree in Graphviz DOT format to the file {@code str}.
     *
     * <p>If the file cannot be opened, the error is printed to stderr and nothing is written.
     *
     * @param str output file name
     */
    public void printdot(String str) {
        try (PrintStream printStream = new PrintStream(str)) {
            printStream.println("strict digraph T {");
            orprint(getRoot(), printStream);
            printStream.println("}");
        } catch (FileNotFoundException e) {
            System.err.println(e.toString());
        }
    }

    /** Emits an OR node, its b* edge, its edges to AND children, then recurses into them. */
    private void orprint(HeuristicSearchOrNode node, PrintStream out) {
        out.format("%s[label=\"U(b)= %.2f\\nL(b)= %.2f\\nH(b)= %.2f\"];\n",
                node.hashCode(), node.u, node.l, node.h_b);
        // bStar is only assigned during expansion, so it can still be unset (e.g. a fresh root).
        if (node.bStar != null) {
            out.println(node.hashCode() + "->" + node.bStar.hashCode() + "[label=\"b*\",weight=0,color=blue];");
        }
        HeuristicSearchAndNode[] children = node.getChildren();
        if (children == null) {
            return;
        }
        for (HeuristicSearchAndNode child : children) {
            out.print(node.hashCode() + "->" + child.hashCode() + "[label=\"H(b,a)=" + node.h_ba[child.getAct()] + "\"];");
        }
        out.println();
        for (HeuristicSearchAndNode child : children) {
            andprint(child, out);
        }
    }

    /**
     * Emits an AND node and its edges to non-null OR children, then recurses into them.
     * Action/observation names are passed as format arguments so a {@code %} in a name cannot
     * break the format string.
     */
    protected void andprint(HeuristicSearchAndNode heuristicSearchAndNode, PrintStream printStream) {
        printStream.format("%s[label=\"a=%s\\nU(b,a)= %.2f\\nL(b,a)= %.2f\"];\n",
                heuristicSearchAndNode.hashCode(),
                getProblem().getActionString(heuristicSearchAndNode.getAct()),
                heuristicSearchAndNode.u,
                heuristicSearchAndNode.l);
        for (HeuristicSearchOrNode child : heuristicSearchAndNode.getChildren()) {
            if (child != null) {
                printStream.format("%s->%s[label=\"obs: %s\\nP(o|b,a)= %.2f\\nH(b,a,o)= %.2f\"];\n",
                        heuristicSearchAndNode.hashCode(),
                        child.hashCode(),
                        getProblem().getObservationString(child.getObs()),
                        child.getBeliefState().getPoba(),
                        child.h_bao);
            }
        }
        printStream.println();
        for (HeuristicSearchOrNode child : heuristicSearchAndNode.getChildren()) {
            if (child != null) {
                orprint(child, printStream);
            }
        }
    }
}
