package mgjpomdp.solve;

import mgjpomdp.common.MDPUtils;
import mgjpomdp.common.POMDPFlatMTJ;
import mgjpomdp.common.PolicyMTJ;
import mgjpomdp.common.PolicyType;
import no.uib.cipr.matrix.sparse.SparseVector;

/* loaded from: input_file:mgjpomdp/solve/ValueIterationMTJ.class */
/**
 * Value iteration over the underlying MDP of a flat POMDP model.
 *
 * <p>After {@link #solve(POMDPFlatMTJ, double, int)} returns, {@link #_V} holds one value per
 * state and {@link #_Q} holds one row per action with one value per state.
 *
 * <p>This class keeps its working arrays in public mutable fields and is therefore not safe to
 * share between threads while a solve is running.
 */
public class ValueIterationMTJ {
    /** State values produced by the last call to {@code solve}; indexed by state. */
    public double[] _V;

    /**
     * Optional starting point for {@code _V}. When non-null it is cloned at the start of
     * {@code solve}; when null every state starts at {@code computeVmax()}.
     * NOTE(review): expected to contain {@code _numS} entries - this is not validated here.
     */
    public double[] _initV;

    /** Action values produced by the last call to {@code solve}; indexed {@code [action][state]}. */
    public double[][] _Q;

    /** Not read or written by this class; kept so that external code can still access it. */
    public double[][] _initQ;

    /**
     * Runs synchronous value-iteration sweeps until the largest absolute change of any state
     * value in one sweep is no longer greater than {@code epsilon}. At least one sweep is always
     * performed. A summary line is always printed to standard output when the loop ends.
     *
     * @param pomdp   model providing {@code _numS}, {@code _numA}, {@code _gamma}, {@code _TRow},
     *                {@code _R}, {@code _nonZeroR} and {@code computeVmax()}
     * @param epsilon stopping threshold on the per-sweep maximum absolute change of {@code _V}
     * @param verbose when greater than zero, the error, the value function and every Q-value are
     *                printed after each sweep
     * @throws Exception propagated unchanged from the model or matrix helpers
     */
    public void solve(POMDPFlatMTJ pomdp, double epsilon, int verbose) throws Exception {
        final int numS = pomdp._numS;
        final int numA = pomdp._numA;

        // Evaluated unconditionally, exactly as before, even when _initV overrides it.
        final double vMax = pomdp.computeVmax();
        if (this._initV != null) {
            this._V = this._initV.clone();
        } else {
            this._V = new double[numS];
            for (int s = 0; s < numS; s++) {
                this._V[s] = vMax;
            }
        }

        // Second buffer: each sweep swaps it with _V so no array is allocated inside the loop.
        double[] previousV = new double[numS];

        // Fix: the decompiled code assigned a one-dimensional double[] to this double[][] field,
        // which does not compile. Allocate one row of numS values per action.
        this._Q = new double[numA][numS];

        int iterations = 0;
        double bellmanError;
        do {
            // Swap buffers: the values from the last sweep become the input of this sweep.
            double[] recycled = previousV;
            previousV = this._V;
            this._V = recycled;

            updateQ(pomdp, previousV);
            maximiseOverActions(numS, numA);
            bellmanError = maxAbsDifference(previousV, this._V, numS);

            if (verbose > 0) {
                printSweep(pomdp, bellmanError, iterations);
            }
            iterations++;
        } while (bellmanError > epsilon);

        System.out.println("VI execution summary:\nnumber of iterations: " + iterations + "\nfinal error: " + bellmanError);
    }

    /**
     * Recomputes every row of {@code _Q} from the given state values.
     * NOTE(review): presumably {@code mult} stores {@code gamma * T_a * values} in the row and
     * {@code MDPUtils.add} then adds the reward vector of action a - confirm against those helpers.
     */
    private void updateQ(POMDPFlatMTJ pomdp, double[] values) throws Exception {
        for (int a = 0; a < pomdp._numA; a++) {
            double[] qRow = this._Q[a];
            pomdp._TRow[a].mult(pomdp._gamma, values, qRow);
            MDPUtils.add(pomdp._R[a], pomdp._nonZeroR[a], qRow);
        }
    }

    /** Sets {@code _V[s]} to the largest {@code _Q[a][s]} over all actions. */
    private void maximiseOverActions(int numS, int numA) {
        for (int s = 0; s < numS; s++) {
            double best = Double.NEGATIVE_INFINITY;
            for (int a = 0; a < numA; a++) {
                double q = this._Q[a][s];
                if (q > best) {
                    best = q;
                }
            }
            this._V[s] = best;
        }
    }

    /**
     * Returns the largest {@code |before[s] - after[s]|} over the first {@code numS} entries, or
     * negative infinity when {@code numS} is zero.
     */
    private static double maxAbsDifference(double[] before, double[] after, int numS) {
        double largest = Double.NEGATIVE_INFINITY;
        for (int s = 0; s < numS; s++) {
            double diff = Math.abs(before[s] - after[s]);
            if (diff > largest) {
                largest = diff;
            }
        }
        return largest;
    }

    /** Prints the diagnostics of one sweep: error, value function and all Q-values. */
    private void printSweep(POMDPFlatMTJ pomdp, double bellmanError, int iterations) {
        System.out.println("Bellman error: " + bellmanError + " after " + iterations + " iterations.");
        System.out.println("V-function: " + POMDPFlatMTJ.toString(this._V));
        for (int a = 0; a < pomdp._numA; a++) {
            for (int s = 0; s < pomdp._numS; s++) {
                System.out.println("Q[a=" + a + ",s=" + s + "]: " + this._Q[a][s]);
            }
        }
    }

    /**
     * Demo entry point: loads a POMDP from a hard-coded path, solves it verbosely, builds a QMDP
     * policy and prints the MDP and QMDP values of the uniform belief.
     *
     * @param strArr ignored
     * @throws Exception propagated from model loading, solving or policy evaluation
     */
    public static void main(String[] strArr) throws Exception {
        System.out.println("\nStarted...\n");
        POMDPFlatMTJ pomdp = new POMDPFlatMTJ("/home/mgrzes/_data/Cassandra_POMDPs/1d.POMDP", 0);
        System.out.println(pomdp.toString());

        ValueIterationMTJ solver = new ValueIterationMTJ();
        // Stopping threshold: 1e-6 * (1 - gamma) / (2 * gamma).
        double epsilon = (1.0E-6d * (1.0d - pomdp._gamma)) / (2.0d * pomdp._gamma);
        solver.solve(pomdp, epsilon, 1);

        PolicyMTJ policy = new PolicyMTJ(solver, pomdp, PolicyType.QMDP);

        // Uniform belief: every state receives probability 1 / numS.
        SparseVector belief = new SparseVector(pomdp._numS);
        double uniform = 1.0d / pomdp._numS;
        for (int s = 0; s < pomdp._numS; s++) {
            belief.set(s, uniform);
        }
        System.out.println("MDP value: " + policy.getV_MDP(belief));
        System.out.println("QMDP value: " + policy.getV_QMDP(belief));
    }
}
