package libpomdp.solve.offline.pointbased;

import com.jgoodies.forms.layout.FormSpec;
import java.util.Iterator;
import libpomdp.common.AlphaVector;
import libpomdp.common.BeliefState;
import libpomdp.common.CustomVector;
import libpomdp.common.Pomdp;
import libpomdp.common.rho.RhoPomdp;
import libpomdp.common.std.BeliefMdpStd;
import libpomdp.common.std.BeliefStateStd;
import libpomdp.common.std.PomdpStd;
import libpomdp.common.std.ValueFunctionStd;
import libpomdp.solve.offline.IterationStats;
import libpomdp.solve.offline.vi.ValueIterationStd;

/* loaded from: input_file:libpomdp/solve/offline/pointbased/PointBasedStd.class */
/**
 * Point-based value-iteration solver working on a {@link BeliefMdpStd}.
 *
 * <p>Each call to {@link #iterate()} grows the belief point sets via {@link #expand()},
 * performs {@code params.backupHorizon} backups of the value function over those
 * points, and prunes the result. The expansion and backup strategies are selected by
 * the integer codes returned by {@code params.getExpandMethod()} and
 * {@code params.getBackupMethod()}.
 */
public class PointBasedStd extends ValueIterationStd {
    /** Belief-MDP view of the problem being solved. */
    BeliefMdpStd bmdp;
    /** All belief points collected so far ({@code null} until the first expansion). */
    PointSet fullBset;
    /** Belief points added by the most recent call to {@link #expand()}. */
    PointSet newBset;
    /** Solver configuration (expansion/backup method, point limits, backup horizon). */
    PbParams params;

    /**
     * Builds the alpha vector used to initialise the value function: a constant vector
     * with one entry per state, each equal to
     * {@code bmdp.getRewardMaxMin() / (1 - gamma)}, tagged with action {@code -1}.
     *
     * @return the initial constant alpha vector
     */
    public AlphaVector getLowestAlpha() {
        double constantValue = this.bmdp.getRewardMaxMin() / (1.0d - this.bmdp.getGamma());
        return new AlphaVector(CustomVector.getHomogene(this.bmdp.nrStates(), constantValue), -1);
    }

    /**
     * Creates the solver and initialises the current value function with the single
     * vector returned by {@link #getLowestAlpha()}.
     *
     * @param pomdpStd the problem to solve
     * @param pbParams the point-based solver parameters
     */
    public PointBasedStd(PomdpStd pomdpStd, PbParams pbParams) {
        startTimer();
        initValueIteration(pomdpStd);
        this.params = pbParams;
        this.bmdp = new BeliefMdpStd(pomdpStd);
        this.current = new ValueFunctionStd(pomdpStd.nrStates());
        this.current.push(getLowestAlpha());
        registerInitTime();
    }

    /**
     * Runs one solver iteration: expand the belief sets, optionally refresh the
     * reward approximation of a {@link RhoPomdp}, back up the value function
     * {@code params.backupHorizon} times, and prune it.
     *
     * @return the iteration statistics maintained by the superclass
     */
    @Override // libpomdp.solve.offline.vi.ValueIterationStd, libpomdp.solve.offline.Iteration
    public IterationStats iterate() {
        startTimer();
        this.old = this.current;
        expand();
        if (this.bmdp.getPomdp() instanceof RhoPomdp) {
            ((RhoPomdp) this.bmdp.getPomdp()).approxReward(this.fullBset);
        }
        for (int step = 0; step < this.params.backupHorizon; step++) {
            backup();
        }
        this.current.prune();
        if (this.params.isNewPointsOnly()) {
            // Keep only the points generated in this iteration for the next one.
            this.fullBset = this.newBset;
        }
        registerValueIterationStats();
        return this.iterationStats;
    }

    /**
     * Replaces the current value function with a backup computed according to
     * {@code params.getBackupMethod()}. Unknown method codes leave the value
     * function unchanged.
     */
    protected void backup() {
        switch (this.params.getBackupMethod()) {
            case 1:
                this.current = syncBackup(this.fullBset);
                break;
            case 2:
                this.current = syncBackup(this.newBset);
                break;
            case 3:
            case 4:
                // NOTE(review): methods 3 and 4 both back up fullBset, whereas the
                // synchronous pair (1/2) distinguishes fullBset from newBset. Confirm
                // against PbParams whether method 4 was meant to use newBset.
                this.current = asyncBackup(this.fullBset);
                break;
            default:
                break;
        }
    }

    /**
     * Computes the backed-up alpha vector for a single belief point.
     *
     * <p>For every action, the best projection (by value at {@code beliefState}) is
     * selected per observation and summed, the best reward vector for the action is
     * added, and the action whose resulting vector evaluates highest at
     * {@code beliefState} wins.
     *
     * @param beliefState      the belief point to back up
     * @param valueFunctionStd the value function to back up from
     * @return the best backed-up vector, or {@code null} if no action produced a value
     *         greater than negative infinity
     */
    private AlphaVector backup(BeliefState beliefState, ValueFunctionStd valueFunctionStd) {
        AlphaVector bestAlpha = null;
        double bestValue = Double.NEGATIVE_INFINITY;
        for (int action = 0; action < this.bmdp.nrActions(); action++) {
            AlphaVector candidate = new AlphaVector(this.bmdp.nrStates(), action);
            for (int obs = 0; obs < this.bmdp.nrObservations(); obs++) {
                double bestProjectionValue = Double.NEGATIVE_INFINITY;
                AlphaVector bestProjection = null;
                for (int k = 0; k < valueFunctionStd.size(); k++) {
                    AlphaVector projection =
                            this.bmdp.projection(valueFunctionStd.getAlphaVector(k), action, obs);
                    double projectionValue = projection.eval(beliefState);
                    if (projectionValue > bestProjectionValue) {
                        bestProjectionValue = projectionValue;
                        bestProjection = projection;
                    }
                }
                candidate.add(bestProjection);
            }
            candidate.add(this.bmdp.getRewardValueFunction(action).getBestAlpha(beliefState));
            double candidateValue = candidate.eval(beliefState);
            if (candidateValue > bestValue) {
                bestValue = candidateValue;
                bestAlpha = candidate;
            }
        }
        return bestAlpha;
    }

    /**
     * Randomised asynchronous backup over a belief set.
     *
     * <p>Points are sampled at random from a working copy of {@code pointSet}. For each
     * sampled point the backed-up vector is kept if it does not lower the value at that
     * point; otherwise the best vector of the previous value function is kept. After
     * each step every point whose value under the new function is at least its value
     * under the previous one is dropped from the working copy. The loop ends when the
     * working copy is empty.
     *
     * @param pointSet the belief points to cover; it is not modified
     * @return the new value function
     */
    private ValueFunctionStd asyncBackup(PointSet pointSet) {
        ValueFunctionStd newValueFunction = new ValueFunctionStd(this.bmdp.nrStates());
        PointSet pending = pointSet.copy();
        while (pending.size() != 0) {
            BeliefState sampled = pending.getRandom();
            // Drop the sampled point from the working copy only: the caller's set
            // (the solver's belief set) must survive the backup, and removing it here
            // guarantees progress on every pass.
            pending.remove(sampled);
            AlphaVector backedUp = backup(sampled, this.old);
            if (backedUp.eval(sampled) >= this.old.V(sampled)) {
                newValueFunction.push(backedUp);
            } else {
                newValueFunction.push(this.old.getBestAlpha(sampled));
            }
            // Collect first, remove afterwards, to avoid mutating while iterating.
            PointSet alreadyCovered = new PointSet();
            Iterator<BeliefState> pendingIt = pending.iterator();
            while (pendingIt.hasNext()) {
                BeliefState point = pendingIt.next();
                if (newValueFunction.V(point) >= this.old.V(point)) {
                    alreadyCovered.add(point);
                }
            }
            Iterator<BeliefState> coveredIt = alreadyCovered.iterator();
            while (coveredIt.hasNext()) {
                pending.remove(coveredIt.next());
            }
        }
        return newValueFunction;
    }

    /**
     * Synchronous backup: backs up every point of {@code pointSet} against the
     * previous value function and collects the resulting vectors.
     *
     * @param pointSet the belief points to back up
     * @return a value function holding one backed-up vector per point
     */
    protected ValueFunctionStd syncBackup(PointSet pointSet) {
        ValueFunctionStd newValueFunction = new ValueFunctionStd(this.bmdp.nrStates());
        Iterator<BeliefState> it = pointSet.iterator();
        while (it.hasNext()) {
            newValueFunction.push(backup(it.next(), this.old));
        }
        return newValueFunction;
    }

    /**
     * Grows the belief point sets.
     *
     * <p>{@code newBset} is reset on every call. On the first call {@code fullBset} is
     * created and seeded (together with {@code newBset}) with the initial belief state.
     * If {@code fullBset} already holds {@code params.getMaxTotalPoints()} points
     * nothing else happens. Otherwise candidate points are generated from a working
     * copy of {@code fullBset} using the strategy selected by
     * {@code params.getExpandMethod()}:
     * <ul>
     *   <li>1 - {@link #collectGreedyErrorReduction(PointSet)}</li>
     *   <li>2 - {@link #collectExploratoryAction(PointSet)}</li>
     *   <li>3, 4 - {@link #collectRandomExplore(PointSet, Pomdp)}</li>
     * </ul>
     * Generation continues while {@code fullBset} is below the total limit or
     * {@code newBset} is below {@code params.getMaxNewPoints()}. When the working copy
     * runs out, method 3 refills it from {@code fullBset}; every other method stops.
     *
     * @throws IllegalStateException if the expand method is not one of 1-4 and points
     *         still need to be generated (no progress would be possible)
     */
    protected void expand() {
        this.newBset = new PointSet();
        if (this.fullBset == null) {
            this.fullBset = new PointSet();
            this.fullBset.add(this.bmdp.getInitialBeliefState());
            this.newBset.add(this.bmdp.getInitialBeliefState());
        }
        if (this.fullBset.size() >= this.params.getMaxTotalPoints()) {
            return;
        }
        PointSet candidates = this.fullBset.copy();
        while (this.fullBset.size() < this.params.getMaxTotalPoints()
                || this.newBset.size() < this.params.getMaxNewPoints()) {
            BeliefStateStd point;
            switch (this.params.getExpandMethod()) {
                case 1:
                    point = collectGreedyErrorReduction(candidates);
                    break;
                case 2:
                    point = collectExploratoryAction(candidates);
                    break;
                case 3:
                case 4:
                    point = collectRandomExplore(candidates, this.bmdp);
                    break;
                default:
                    // No collector runs, so neither set can change: fail instead of
                    // spinning forever.
                    throw new IllegalStateException(
                            "Unknown expand method: " + this.params.getExpandMethod());
            }
            if (point != null) {
                this.fullBset.add(point);
                this.newBset.add(point.copy());
            }
            if (candidates.size() == 0) {
                if (this.params.getExpandMethod() != 3) {
                    break;
                }
                candidates = this.fullBset.copy();
            }
        }
    }

    /**
     * Removes the first point of {@code pointSet} and, for every action, simulates one
     * step with a randomly drawn observation. Returns the successor belief that lies
     * farthest (L1 distance) from {@code fullBset}.
     *
     * @param pointSet working set of candidate points; its first element is removed
     * @return the farthest successor, or {@code null} if the largest distance is
     *         exactly zero or no successor was generated
     */
    private BeliefStateStd collectExploratoryAction(PointSet pointSet) {
        BeliefStateStd origin = (BeliefStateStd) pointSet.remove(0);
        double farthestDistance = Double.NEGATIVE_INFINITY;
        BeliefStateStd farthest = null;
        for (int action = 0; action < this.bmdp.nrActions(); action++) {
            int observation = this.bmdp.getRandomObservation(origin, action);
            BeliefStateStd successor =
                    (BeliefStateStd) this.bmdp.nextBeliefState(origin, action, observation);
            double successorDistance = distance(successor, this.fullBset);
            if (successorDistance > farthestDistance) {
                farthestDistance = successorDistance;
                farthest = successor;
            }
        }
        if (farthestDistance == 0.0d) {
            // The best successor coincides with a point that is already known.
            farthest = null;
        }
        return farthest;
    }

    /**
     * Picks the (belief, action) pair of {@code pointSet} with the largest
     * observation-weighted sum of {@link #minError(BeliefState, PointSet)} over the
     * successors, then the observation contributing the largest single term, removes
     * the chosen belief from {@code pointSet} and returns the corresponding successor.
     *
     * <p>Callers must pass a non-empty set; with an empty set no pair is selected and
     * the subsequent lookups operate on a {@code null} belief and action {@code -1}.
     *
     * @param pointSet working set of candidate points; the selected belief is removed
     * @return the successor belief for the selected belief, action and observation
     */
    private BeliefStateStd collectGreedyErrorReduction(PointSet pointSet) {
        double bestError = Double.NEGATIVE_INFINITY;
        BeliefState bestBelief = null;
        int bestAction = -1;
        int bestObservation = -1;
        Iterator<BeliefState> it = pointSet.iterator();
        while (it.hasNext()) {
            BeliefState belief = it.next();
            for (int action = 0; action < this.bmdp.nrActions(); action++) {
                double weightedError = 0.0d;
                for (int obs = 0; obs < this.bmdp.nrObservations(); obs++) {
                    weightedError += weightedSuccessorError(belief, action, obs, pointSet);
                }
                if (weightedError > bestError) {
                    bestError = weightedError;
                    bestBelief = belief;
                    bestAction = action;
                }
            }
        }
        double bestTerm = Double.NEGATIVE_INFINITY;
        for (int obs = 0; obs < this.bmdp.nrObservations(); obs++) {
            double term = weightedSuccessorError(bestBelief, bestAction, obs, pointSet);
            if (term > bestTerm) {
                bestTerm = term;
                bestObservation = obs;
            }
        }
        pointSet.remove(bestBelief);
        return (BeliefStateStd) this.bmdp.nextBeliefState(bestBelief, bestAction, bestObservation);
    }

    /**
     * One term of the greedy error sum: the 1-norm of {@code tau(action, obs) * belief}
     * multiplied by the minimum error of the successor belief.
     */
    private double weightedSuccessorError(BeliefState belief, int action, int obs, PointSet pointSet) {
        double weight = this.bmdp.getTau(action, obs).mult(belief.getPoint()).norm(1.0d);
        return weight * minError(this.bmdp.nextBeliefState(belief, action, obs), pointSet);
    }

    /**
     * Computes, over all points of {@code pointSet}, the minimum of an error estimate
     * for {@code beliefState}. For a reference point {@code b'} with best current alpha
     * vector {@code alpha}, the estimate sums, per state {@code s},
     * {@code (bound - alpha(s)) * (b(s) - b'(s))}, where {@code bound} is
     * {@code rewardMax / (1 - gamma)} when the difference is non-negative and
     * {@code rewardMin / (1 - gamma)} otherwise.
     *
     * @param beliefState the belief whose error is estimated
     * @param pointSet    the reference points
     * @return the smallest estimate, or positive infinity if {@code pointSet} is empty
     */
    private double minError(BeliefState beliefState, PointSet pointSet) {
        double upperBound = this.bmdp.getRewardMax() / (1.0d - this.bmdp.getGamma());
        double lowerBound = this.bmdp.getRewardMin() / (1.0d - this.bmdp.getGamma());
        double smallest = Double.POSITIVE_INFINITY;
        Iterator<BeliefState> it = pointSet.iterator();
        while (it.hasNext()) {
            BeliefState reference = it.next();
            AlphaVector bestAlpha = this.current.getBestAlpha(reference);
            double error = 0.0d;
            for (int state = 0; state < this.bmdp.nrStates(); state++) {
                double diff = beliefState.getPoint().get(state) - reference.getPoint().get(state);
                double bound = diff >= 0.0d ? upperBound : lowerBound;
                error += (bound - bestAlpha.getVectorRef().get(state)) * diff;
            }
            if (error < smallest) {
                smallest = error;
            }
        }
        return smallest;
    }

    /**
     * Removes the first point of {@code pointSet} and returns its successor under a
     * random action and a random observation.
     *
     * @param pointSet working set of candidate points; its first element is removed
     * @param pomdp    the model used for simulation; it is cast to {@link PomdpStd}
     * @return the sampled successor belief
     */
    public static BeliefStateStd collectRandomExplore(PointSet pointSet, Pomdp pomdp) {
        BeliefStateStd origin = (BeliefStateStd) pointSet.remove(0);
        PomdpStd model = (PomdpStd) pomdp;
        int action = model.getRandomAction();
        int observation = model.getRandomObservation(origin, action);
        return (BeliefStateStd) pomdp.nextBeliefState(origin, action, observation);
    }

    /**
     * Returns the smallest L1 distance between {@code beliefStateStd} and any point of
     * {@code pointSet}.
     *
     * @param beliefStateStd the belief to measure
     * @param pointSet       the reference points
     * @return the minimum distance, or positive infinity if {@code pointSet} is empty
     */
    private double distance(BeliefStateStd beliefStateStd, PointSet pointSet) {
        double nearest = Double.POSITIVE_INFINITY;
        Iterator<BeliefState> it = pointSet.iterator();
        while (it.hasNext()) {
            // Work on a copy so the stored point is not modified by the subtraction.
            CustomVector difference = it.next().getPoint().copy();
            difference.add(-1.0d, beliefStateStd.getPoint());
            double norm = difference.norm(1.0d);
            if (norm < nearest) {
                nearest = norm;
            }
        }
        return nearest;
    }
}
