package libpomdp.solve.offline.exact;

import java.util.ArrayList;
import libpomdp.common.AlphaVector;
import libpomdp.common.std.BeliefMdpStd;
import libpomdp.common.std.PomdpStd;
import libpomdp.common.std.ValueFunctionStd;
import libpomdp.solve.offline.IterationStats;
import libpomdp.solve.offline.vi.ValueIterationStats;
import libpomdp.solve.offline.vi.ValueIterationStd;

/* loaded from: input_file:libpomdp/solve/offline/exact/IncrementalPruningStd.class */
/**
 * Value-iteration solver that rebuilds the value function one action at a
 * time: for every action it projects the previous alpha vectors through each
 * observation, prunes each projected set, and then combines the sets (plus the
 * action's reward value function) by repeated cross-sum-and-prune steps.
 *
 * <p>Every call to {@code prune} is reported to the iteration statistics via
 * {@code registerLp}.
 */
public class IncrementalPruningStd extends ValueIterationStd {
    /** Belief-MDP view of the problem; built from the POMDP given to the constructor. */
    BeliefMdpStd bmdp;

    /** Value handed to every {@code prune} call (presumably the pruning tolerance - confirm in ValueFunctionStd). */
    private final double delta;

    /**
     * Sets up the solver and seeds the current value function with a single
     * freshly constructed alpha vector.
     *
     * @param pomdpStd the problem to solve
     * @param d        value passed to every {@code prune} call made by {@link #iterate()}
     */
    public IncrementalPruningStd(PomdpStd pomdpStd, double d) {
        startTimer();
        initValueIteration(pomdpStd);
        this.delta = d;
        this.bmdp = new BeliefMdpStd(pomdpStd);
        this.current = new ValueFunctionStd(pomdpStd.nrStates());
        this.current.push(new AlphaVector(this.bmdp.nrStates()));
        registerInitTime();
    }

    /**
     * Performs one backup: the current value function becomes {@code old} and a
     * new {@code current} is assembled from the per-action value functions.
     *
     * @return the iteration statistics object (the inherited {@code iterationStats},
     *         viewed as {@link ValueIterationStats}), updated with this iteration's data
     */
    @Override // libpomdp.solve.offline.vi.ValueIterationStd, libpomdp.solve.offline.Iteration
    public IterationStats iterate() {
        startTimer();
        this.old = this.current;
        ValueIterationStats stats = (ValueIterationStats) this.iterationStats;
        this.current = new ValueFunctionStd(this.bmdp.nrStates());

        for (int action = 0; action < this.bmdp.nrActions(); action++) {
            // Value functions still waiting to be cross-summed for this action.
            ArrayList<ValueFunctionStd> pending = new ArrayList<ValueFunctionStd>();

            // One pruned set of projected vectors per observation.
            for (int obs = 0; obs < this.bmdp.nrObservations(); obs++) {
                ValueFunctionStd projected = new ValueFunctionStd(this.bmdp.nrStates());
                for (int v = 0; v < this.old.size(); v++) {
                    projected.push(this.bmdp.projection(this.old.getAlphaVector(v), action, obs));
                }
                stats.registerLp(projected.prune(this.delta));
                pending.add(projected);
            }

            // The action's reward value function takes part in the cross-sums as well.
            pending.add(this.bmdp.getRewardValueFunction(action));

            // Take the two oldest entries, cross-sum them, prune the result and
            // append it at the back, until a single value function remains.
            while (pending.size() > 1) {
                ValueFunctionStd combined = pending.remove(0);
                combined.crossSum(pending.remove(0));
                stats.registerLp(combined.prune(this.delta));
                pending.add(combined);
            }

            this.current.merge(pending.remove(0));
        }

        // Final prune over the merged result of all actions.
        stats.registerLp(this.current.prune(this.delta));
        // Prints the size of the new value function to stdout after each iteration.
        System.out.println(this.current.size());
        registerValueIterationStats();
        return stats;
    }
}
