package libpomdp.solve.hybrid;

import com.jgoodies.forms.layout.FormSpec;
import java.io.PrintStream;
import libpomdp.common.CustomVector;
import libpomdp.common.Pomdp;
import libpomdp.common.Utils;
import libpomdp.common.ValueFunction;
import libpomdp.common.add.PomdpAdd;
import libpomdp.common.add.ValueFunctionAdd;
import libpomdp.common.add.symbolic.DD;
import libpomdp.common.add.symbolic.DDleaf;
import libpomdp.common.add.symbolic.OP;
import libpomdp.solve.online.AndOrTree;
import libpomdp.solve.online.ExpandHeuristic;
import libpomdp.solve.online.OrNode;
import org.math.array.DoubleArray;

/* loaded from: input_file:libpomdp/solve/hybrid/AndOrTreeUpdateAdd.class */
/**
 * And-or tree whose nodes carry, in addition to the expansion-heuristic
 * bookkeeping inherited from {@link AndOrTree}, the data needed to choose
 * nodes at which the lower bound should be backed up: a backup heuristic
 * value per or-node and, per lower-bound vector, the best backup candidate
 * found in each subtree.
 */
public class AndOrTreeUpdateAdd extends AndOrTree {
    /** Heuristic that scores nodes as candidates for a lower-bound backup. */
    private final BackupHeuristic bakH;

    /**
     * For each lower-bound vector index, the number of fringe nodes created by
     * {@link #expand} whose belief reported that index; reset by {@link #moveTree}.
     */
    public CustomVector treeSupportSetSize;

    /**
     * Builds the tree and zero-initialises {@link #treeSupportSetSize} with one
     * entry per vector currently in the lower bound.
     */
    public AndOrTreeUpdateAdd(Pomdp pomdp, HybridValueIterationOrNode hybridValueIterationOrNode, ValueFunction valueFunction, ValueFunction valueFunction2, ExpandHeuristic expandHeuristic, BackupHeuristic backupHeuristic) {
        super(pomdp, hybridValueIterationOrNode, valueFunction, valueFunction2, expandHeuristic);
        this.bakH = backupHeuristic;
        // NOTE(review): sized once from getLB().size(); backupLowerAtNode later
        // appends vectors to the lower bound, so indices used in expand() may
        // exceed this size unless CustomVector grows on set — TODO confirm.
        this.treeSupportSetSize = new CustomVector(getLB().size());
        this.treeSupportSetSize.zero();
    }

    /**
     * Expands a fringe or-node one step: creates one and-node per action and,
     * under each, one or-node per observation, initialises their bounds and
     * heuristic values, and then recomputes the bounds and heuristics of the
     * expanded node itself.
     *
     * <p>If the node already has children, a message is written to
     * {@code System.err} and nothing else happens.
     *
     * @param hybridValueIterationOrNode the fringe node to expand
     */
    public void expand(HybridValueIterationOrNode hybridValueIterationOrNode) {
        final HybridValueIterationOrNode node = hybridValueIterationOrNode;
        if (node.getChildren() != null) {
            System.err.println("node cannot be expanded, it already has children");
            return;
        }

        // Bounds before expansion, kept to measure the one-step improvement.
        final double lowerBefore = node.l;
        final double upperBefore = node.u;

        final PomdpAdd problem = getProblem();
        final int nrActions = problem.nrActions();
        final int nrObservations = problem.nrObservations();
        final int lbSize = getLB().size();

        node.initChildren(nrActions);
        for (int action = 0; action < nrActions; action++) {
            HybridValueIterationAndNode andNode = node.getChild(action);
            andNode.init(action, node, problem.expectedImmediateReward(node.getBeliefState(), action));
            CustomVector obsProbs = problem.observationProbabilities(node.getBeliefState(), action);
            andNode.initChildren(nrObservations, obsProbs);

            for (int obs = 0; obs < nrObservations; obs++) {
                HybridValueIterationOrNode fringe = andNode.getChild(obs);
                if (obsProbs.get(obs) == 0.0d) {
                    // Zero-probability observation: the child is left uninitialised.
                    System.err.println("SMTHIHNGS WRRRRONG");
                    continue;
                }
                fringe.init(problem.nextBeliefState(node.getBeliefState(), action, obs), obs, andNode);
                fringe.getBeliefState().setPoba(obsProbs.get(obs));
                fringe.u = getUB().V(fringe.getBeliefState());
                fringe.l = getLB().V(fringe.getBeliefState());
                // Remember a vector index known to be valid for this and-node;
                // backupLowerAtNode falls back to it for missing children.
                andNode.validPlanid = fringe.getBeliefState().getAlphaVectorIndex();
                fringe.h_b = this.expH.h_b(fringe);
                fringe.h_bao = this.expH.h_bao(fringe);
                // A fringe node is its own best node to expand.
                fringe.hStar = fringe.h_b;
                fringe.bStar = fringe;
                node.setSubTreeSize(node.getSubTreeSize() + 1);
                int supportIndex = fringe.getBeliefState().getAlphaVectorIndex();
                this.treeSupportSetSize.set(supportIndex, this.treeSupportSetSize.get(supportIndex) + 1.0d);
            }

            andNode.l = ANDpropagateL(andNode);
            andNode.u = ANDpropagateU(andNode);
            andNode.oStar = this.expH.oStar(andNode);
            andNode.hStar = this.expH.hANDStar(andNode);
            andNode.bStar = andNode.getChild(andNode.oStar).bStar;
            andNode.bakHeuristicStar = new double[lbSize];
            andNode.bakCandidate = new HybridValueIterationOrNode[lbSize];
        }

        node.l = ORpropagateLexpand(node);
        node.u = ORpropagateU(node);
        node.h_b = this.expH.h_b(node);
        node.h_ba = this.expH.h_ba(node);
        node.aStar = this.expH.aStar(node);
        node.hStar = this.expH.hORStar(node);
        node.bStar = node.getChild(node.aStar).bStar;

        node.oneStepDeltaLower = node.l - lowerBefore;
        node.oneStepDeltaUpper = node.u - upperBefore;
        if (node.oneStepDeltaLower < 0.0d) {
            // ORpropagateLexpand never returns less than the previous lower bound,
            // so a negative delta is unexpected.
            System.err.println("Hmmmmmmmmmmm");
        }

        // The freshly expanded node is, so far, the only backup candidate in its
        // subtree, filed under the lower-bound vector index of its own belief.
        node.bakHeuristic = this.bakH.h_b(node);
        node.bakHeuristicStar = new CustomVector(lbSize);
        node.bakHeuristicStar.zero();
        node.bakCandidate = new HybridValueIterationOrNode[lbSize];
        int ownIndex = node.getBeliefState().getAlphaVectorIndex();
        node.bakHeuristicStar.set(ownIndex, node.bakHeuristic);
        node.bakCandidate[ownIndex] = node;
    }

    /**
     * Walks from the given node up to the root, recomputing bounds, expansion
     * heuristics and backup candidates of every and-node and or-node on the
     * path, and adding the node's subtree size to every or-node ancestor.
     *
     * <p>Does nothing if the node has no children.
     *
     * @param hybridValueIterationOrNode the node that was just expanded
     */
    public void updateAncestors(HybridValueIterationOrNode hybridValueIterationOrNode) {
        if (hybridValueIterationOrNode.getChildren() == null) {
            return;
        }
        final int addedSize = hybridValueIterationOrNode.getSubTreeSize();
        final int lbSize = getLB().size();

        HybridValueIterationOrNode current = hybridValueIterationOrNode;
        while (current != getRoot()) {
            // And-node directly above the current or-node.
            HybridValueIterationAndNode andParent = current.getParent();
            andParent.l = ANDpropagateL(andParent);
            andParent.u = ANDpropagateU(andParent);
            andParent.oStar = this.expH.oStar(andParent);
            andParent.hStar = this.expH.hANDStar(andParent);
            andParent.bStar = andParent.getChild(andParent.oStar).bStar;
            for (int vec = 0; vec < lbSize; vec++) {
                andParent.bakCandidate[vec] = this.bakH.updateBakStar(andParent, current.getObs(), vec);
            }

            // Or-node above that and-node.
            HybridValueIterationOrNode orParent = andParent.getParent();
            orParent.l = ORpropagateL(orParent);
            orParent.u = ORpropagateU(orParent);
            orParent.h_ba = this.expH.h_ba(orParent);
            orParent.aStar = this.expH.aStar(orParent);
            orParent.hStar = orParent.h_ba[orParent.aStar] * orParent.getChild(orParent.aStar).hStar;
            orParent.bStar = orParent.getChild(orParent.aStar).bStar;
            for (int vec = 0; vec < lbSize; vec++) {
                orParent.bakCandidate[vec] = this.bakH.updateBakStar(orParent, andParent.getAct(), vec);
            }
            orParent.setSubTreeSize(orParent.getSubTreeSize() + addedSize);

            current = orParent;
        }
    }

    /** Delegates to the superclass and then clears {@link #treeSupportSetSize}. */
    @Override // libpomdp.solve.online.AbstractAndOrTree
    public void moveTree(OrNode orNode) {
        super.moveTree(orNode);
        this.treeSupportSetSize.zero();
    }

    /** Returns the problem held by the superclass, cast to {@link PomdpAdd}. */
    @Override // libpomdp.solve.online.AbstractAndOrTree
    public PomdpAdd getProblem() {
        return (PomdpAdd) super.getProblem();
    }

    /**
     * Lower-bound propagation used right after an expansion. Records the action
     * whose and-node has the largest lower bound in {@code oneStepBestAction}
     * and returns the larger of that bound and the node's current lower bound.
     */
    protected double ORpropagateLexpand(HybridValueIterationOrNode hybridValueIterationOrNode) {
        double[] lowerPerAction = new double[getProblem().nrActions()];
        for (HybridValueIterationAndNode andNode : hybridValueIterationOrNode.getChildren()) {
            lowerPerAction[andNode.getAct()] = andNode.l;
        }
        int best = Utils.argmax(lowerPerAction);
        hybridValueIterationOrNode.oneStepBestAction = best;
        return Math.max(lowerPerAction[best], hybridValueIterationOrNode.l);
    }

    /**
     * Computes a backed-up lower-bound vector at the root for the current best
     * action: {@code R[a] + gamma * sum_o gao(alpha_o, a, o)}, where
     * {@code alpha_o} is the lower-bound vector indexed by the belief of the
     * child for observation {@code o}. The lower bound itself is not modified.
     *
     * <p>A {@code null} child is handled as in {@link #backupLowerAtNode}: the
     * and-node's {@code validPlanid} and the child's position are used instead.
     *
     * @return the new vector converted to an array over the problem's state ids
     */
    public double[] backupLowerAtRoot() {
        DD gamma = DDleaf.myNew(getProblem().getGamma());
        DD gaoSum = DD.zero;
        int bestAction = currentBestAction();
        DD[] lowerVectors = ((ValueFunctionAdd) getLB()).getvAdd();
        HybridValueIterationAndNode bestAnd = getRoot().getChild(bestAction);
        HybridValueIterationOrNode[] successors = bestAnd.getChildren();
        for (int pos = 0; pos < successors.length; pos++) {
            HybridValueIterationOrNode successor = successors[pos];
            DD gao;
            if (successor == null) {
                gao = getProblem().gao(lowerVectors[bestAnd.validPlanid], bestAction, pos);
            } else {
                gao = getProblem().gao(lowerVectors[successor.getBeliefState().getAlphaVectorIndex()], bestAction, successor.getObs());
            }
            gaoSum = OP.add(gaoSum, gao);
        }
        DD backedUp = OP.add(getProblem().R[bestAction], OP.mult(gamma, gaoSum));
        return OP.convert2array(backedUp, getProblem().getstaIds());
    }

    /**
     * Backs up the lower bound at an expanded node using its
     * {@code oneStepBestAction}, appends the resulting vector (and that action)
     * to the current lower bound, installs the enlarged value function via
     * {@code setLB} and returns it.
     *
     * <p>For a {@code null} child the and-node's {@code validPlanid} is used in
     * place of the child's own vector index.
     *
     * @param hybridValueIterationOrNode node at which to back up
     * @return the new lower bound, or {@code null} (after a message on
     *         {@code System.err}) if the node has no children
     */
    public ValueFunction backupLowerAtNode(HybridValueIterationOrNode hybridValueIterationOrNode) {
        if (hybridValueIterationOrNode.getChildren() == null) {
            System.err.println("Attempted to backup a fringe node");
            return null;
        }
        DD gamma = DDleaf.myNew(getProblem().getGamma());
        DD gaoSum = DD.zero;
        DD[] lowerVectors = ((ValueFunctionAdd) getLB()).getvAdd();
        int bestAction = hybridValueIterationOrNode.oneStepBestAction;
        HybridValueIterationAndNode bestAnd = hybridValueIterationOrNode.getChild(bestAction);
        HybridValueIterationOrNode[] successors = bestAnd.getChildren();
        // The position in the children array doubles as the observation index.
        for (int obs = 0; obs < successors.length; obs++) {
            HybridValueIterationOrNode successor = successors[obs];
            int vectorIndex = (successor == null)
                    ? bestAnd.validPlanid
                    : successor.getBeliefState().getAlphaVectorIndex();
            gaoSum = OP.add(gaoSum, getProblem().gao(lowerVectors[vectorIndex], bestAction, obs));
        }
        DD backedUp = OP.add(getProblem().R[bestAction], OP.mult(gamma, gaoSum));
        ValueFunctionAdd enlarged = new ValueFunctionAdd(
                Utils.append(lowerVectors, backedUp),
                getProblem().getstaIds(),
                Utils.horzCat(getLB().getActions(), bestAction));
        setLB(enlarged);
        return enlarged;
    }

    /** Returns the root held by the superclass, cast to the hybrid or-node type. */
    @Override // libpomdp.solve.online.AndOrTree, libpomdp.solve.online.AbstractAndOrTree
    public HybridValueIterationOrNode getRoot() {
        return (HybridValueIterationOrNode) super.getRoot();
    }

    /**
     * Sums, over the non-null children of the root's current best action, the
     * child's stored observation probability times its subtree size.
     */
    public double expectedReuse() {
        double expected = 0.0d;
        HybridValueIterationOrNode[] successors = getRoot().getChild(currentBestAction()).getChildren();
        for (HybridValueIterationOrNode successor : successors) {
            if (successor == null) {
                continue;
            }
            expected += successor.getBeliefState().getPoba() * successor.getSubTreeSize();
        }
        return expected;
    }

    /**
     * Returns {@link #expectedReuse()} divided by the root's subtree size minus
     * {@code nrObservations * nrActions}. No guard is applied when that
     * denominator is zero.
     */
    public double expectedReuseRatio() {
        int firstLevel = getProblem().nrObservations() * getProblem().nrActions();
        return expectedReuse() / (getRoot().getSubTreeSize() - firstLevel);
    }

    /**
     * Writes the tree as a Graphviz {@code digraph} to the named file.
     *
     * <p>If the file cannot be opened the exception is reported on
     * {@code System.err} and nothing is written. The stream is always closed
     * before this method returns.
     *
     * @param str path of the output file
     */
    @Override // libpomdp.solve.online.AndOrTree
    public void printdot(String str) {
        final PrintStream out;
        try {
            out = new PrintStream(str);
        } catch (Exception e) {
            System.err.println(e.toString());
            return;
        }
        try {
            out.println("digraph T {");
            orprint(getRoot(), out);
            out.println("}");
        } finally {
            out.close();
        }
    }

    /**
     * Emits the dot description of an or-node, its "b*" edge, its outgoing
     * action edges and, recursively, its and-node children. For the root it
     * also writes diagnostic values to {@code System.err} and, when the best
     * weighted backup score is positive, an edge to the chosen backup candidate.
     */
    private void orprint(HybridValueIterationOrNode node, PrintStream out) {
        out.format(node.hashCode() + "[label=\"U(b)= %.2f\\nL(b)= %.2f\\nexpH(b)= %.2f\\nexpH*(b)= %.2f\\nbakH(b)= %.2f\"];\n", node.u, node.l, node.h_b, node.hStar, node.bakHeuristic);
        out.println(node.hashCode() + "->" + node.bStar.hashCode() + "[label=\"b*\",weight=0,color=blue];");

        if (node == getRoot()) {
            final int vectors = this.treeSupportSetSize.size();
            System.err.println("lenght is" + vectors);

            // Fraction of the root's subtree attributed to each lower-bound vector.
            double[] supportShare = new double[vectors];
            for (int v = 0; v < vectors; v++) {
                supportShare[v] = this.treeSupportSetSize.get(v) / node.getSubTreeSize();
                System.err.println(supportShare[v]);
            }

            // Best backup heuristic per vector, weighted by that fraction.
            double[] weighted = new double[vectors];
            for (int v = 0; v < vectors; v++) {
                weighted[v] = node.bakHeuristicStar.get(v) * supportShare[v];
                System.err.println(weighted[v]);
            }

            int bestVector = Utils.argmax(weighted);
            System.err.println(bestVector);
            if (weighted[bestVector] > 0.0d) {
                out.println(node.hashCode() + "->" + node.bakCandidate[bestVector].hashCode() + "[label=\"bakCandidate\",weight=0,color=orange];");
            }
        }

        HybridValueIterationAndNode[] andChildren = node.getChildren();
        if (andChildren == null) {
            return;
        }
        // All outgoing edges first, on a single line, then the subtrees.
        for (HybridValueIterationAndNode andChild : andChildren) {
            out.print(node.hashCode() + "->" + andChild.hashCode() + "[label=\"H(b,a)=" + node.h_ba[andChild.getAct()] + "\"];");
        }
        out.println();
        for (HybridValueIterationAndNode andChild : andChildren) {
            andprint(andChild, out);
        }
    }

    /**
     * Emits the dot description of an and-node, the edges to its non-null
     * or-node children and, recursively, those children.
     */
    protected void andprint(HybridValueIterationAndNode hybridValueIterationAndNode, PrintStream printStream) {
        final int andId = hybridValueIterationAndNode.hashCode();
        printStream.format(andId + "[label=\"a=" + getProblem().getActionString(hybridValueIterationAndNode.getAct()) + "\\nU(b,a)= %.2f\\nL(b,a)= %.2f\"];\n", hybridValueIterationAndNode.u, hybridValueIterationAndNode.l);

        HybridValueIterationOrNode[] orChildren = hybridValueIterationAndNode.getChildren();
        for (HybridValueIterationOrNode orChild : orChildren) {
            if (orChild == null) {
                continue;
            }
            printStream.format(andId + "->" + orChild.hashCode() + "[label=\"obs: " + getProblem().getObservationString(orChild.getObs()) + "\\nP(o|b,a)= %.2f\\nH(b,a,o)= %.2f\"];", orChild.getBeliefState().getPoba(), orChild.h_bao);
        }
        printStream.println();
        for (HybridValueIterationOrNode orChild : orChildren) {
            if (orChild != null) {
                orprint(orChild, printStream);
            }
        }
    }
}
