package libpomdp.common.add;

import com.jgoodies.forms.layout.FormSpec;
import java.util.ArrayList;
import java.util.List;
import libpomdp.common.BeliefState;
import libpomdp.common.CustomMatrix;
import libpomdp.common.CustomVector;
import libpomdp.common.Pomdp;
import libpomdp.common.Utils;
import libpomdp.common.add.symbolic.DD;
import libpomdp.common.add.symbolic.OP;
import libpomdp.parser.ParseSPUDD;
import org.math.array.DoubleArray;
import org.math.array.IntegerArray;

/* loaded from: input_file:libpomdp/common/add/PomdpAdd.class */
/**
 * A POMDP whose transition, observation and reward functions are stored as
 * algebraic decision diagrams (ADDs), parsed from a SPUDD-format problem via
 * {@link ParseSPUDD}.
 *
 * <p>ADD variable ids are assigned in the constructor as follows: state
 * variables get ids {@code 1..nrStaV}, observation variables get ids
 * {@code nrStaV+1..nrStaV+nrObsV}, and the primed (next-step) copy of every
 * variable has its unprimed id plus {@code nrTotV = nrStaV + nrObsV}.
 *
 * <p>Flat state / observation indices are converted to factored value
 * assignments with {@link Utils#sdecode}; the decoded values are used as
 * 1-based indices into the parser's value-name lists.
 */
public class PomdpAdd implements Pomdp {
    /** Normalizer below which {@link #regulartao} falls back to the initial belief. */
    private static final double MIN_NORMALIZER = 1.0E-5d;

    // number of state variables
    private final int nrStaV;
    // unprimed ADD ids of the state variables (1..nrStaV)
    private final int[] staIds;
    // primed ADD ids of the state variables (unprimed id + nrTotV)
    private final int[] staIdsPr;
    // number of values of each state variable
    public int[] staArity;
    // size of the flat state space: product of staArity
    private final int totnrSta;
    // number of observation variables
    private final int nrObsV;
    // unprimed ADD ids of the observation variables (nrStaV+1..nrTotV)
    public int[] obsIds;
    // primed ADD ids of the observation variables (unprimed id + nrTotV)
    private final int[] obsIdsPr;
    // number of values of each observation variable
    private final int[] obsArity;
    // size of the flat observation space: product of obsArity
    private final int totnrObs;
    // total number of (unprimed) variables: nrStaV + nrObsV
    private final int nrTotV;
    // number of actions
    private final int nrAct;
    // T[a]: transition factors of action a
    public DD[][] T;
    // O[a]: observation factors of action a
    public DD[][] O;
    // R[a]: reward minus the cost of action a
    public DD[] R;
    // discount factor
    private final double gamma;
    // action names, indexed by action
    private final String[] actStr;
    // initial belief parsed from the problem file
    private final BeliefStateAdd initBelief;
    // the underlying parser, kept for variable / value names
    public ParseSPUDD problemAdd;

    /**
     * Parses a SPUDD problem and builds the ADD model from it.
     *
     * @param str the argument handed to {@link ParseSPUDD}'s constructor
     *     (presumably the path of the SPUDD file — confirm against ParseSPUDD)
     */
    public PomdpAdd(String str) {
        this.problemAdd = new ParseSPUDD(str);
        this.problemAdd.parsePOMDP(false);
        this.nrStaV = problemAdd.nStateVars;
        this.nrObsV = problemAdd.nObsVars;
        this.nrTotV = nrStaV + nrObsV;
        this.nrAct = problemAdd.actTransitions.size();
        this.gamma = problemAdd.discount.getVal();

        // state variables: ids 1..nrStaV, primed ids shifted by nrTotV
        this.staIds = new int[nrStaV];
        this.staIdsPr = new int[nrStaV];
        this.staArity = new int[nrStaV];
        for (int v = 0; v < nrStaV; v++) {
            staIds[v] = v + 1;
            staIdsPr[v] = v + 1 + nrTotV;
            staArity[v] = problemAdd.valNames.get(v).size();
        }

        // observation variables follow the state variables
        this.obsIds = new int[nrObsV];
        this.obsIdsPr = new int[nrObsV];
        this.obsArity = new int[nrObsV];
        for (int v = 0; v < nrObsV; v++) {
            obsIds[v] = nrStaV + v + 1;
            obsIdsPr[v] = nrStaV + v + 1 + nrTotV;
            obsArity[v] = problemAdd.valNames.get(nrStaV + v).size();
        }

        // per-action factors; the outer dimension is the action, the inner
        // arrays come straight from the parser
        this.T = new DD[nrAct][];
        this.O = new DD[nrAct][];
        this.R = new DD[nrAct];
        this.actStr = new String[nrAct];
        for (int a = 0; a < nrAct; a++) {
            T[a] = problemAdd.actTransitions.get(a);
            O[a] = problemAdd.actObserve.get(a);
            R[a] = OP.sub(problemAdd.reward, problemAdd.actCosts.get(a));
            actStr[a] = problemAdd.actNames.get(a);
        }

        // third argument is the same slot regulartao fills with the belief
        // normalizer; 0.0 for the prior
        this.initBelief = new BeliefStateAdd(problemAdd.init, staIds, 0.0);
        this.totnrSta = IntegerArray.product(staArity);
        this.totnrObs = IntegerArray.product(obsArity);
    }

    /**
     * Updates a belief after taking action {@code i} and observing {@code i2}.
     * Joint beliefs ({@link BeliefStateAdd}) get an exact update; anything
     * else is treated as a {@link BeliefStateFactoredAdd}.
     */
    @Override // libpomdp.common.Pomdp
    public BeliefState nextBeliefState(BeliefState beliefState, int i, int i2) {
        if (beliefState instanceof BeliefStateAdd) {
            return regulartao((BeliefStateAdd) beliefState, i, i2);
        }
        return factoredtao((BeliefStateFactoredAdd) beliefState, i, i2);
    }

    /**
     * Builds the factor groups shared by all belief updates: the transition
     * factors of action {@code a} and its observation factors restricted to
     * the (decoded) observation {@code o}.
     */
    private DD[][] actionObservationFactors(int a, int o) {
        int[][] obsAssignment =
                IntegerArray.mergeRows(new int[][] {obsIdsPr, Utils.sdecode(o, nrObsV, obsArity)});
        return new DD[][] {T[a], OP.restrictN(O[a], obsAssignment)};
    }

    /**
     * Belief update for a joint ADD belief.
     *
     * <p>Multiplies the belief with the factors of action {@code i} and
     * observation {@code i2}, sums out the current state variables, renames the
     * primed result back to unprimed ids and divides by its total mass.
     *
     * @return the normalized posterior, or the initial belief when the
     *     normalizing constant is below {@value #MIN_NORMALIZER}
     */
    public BeliefState regulartao(BeliefStateAdd beliefStateAdd, int i, int i2) {
        DD[] factors = Utils.concat(beliefStateAdd.bAdd, actionObservationFactors(i, i2));
        DD unnormalized = OP.primeVars(OP.addMultVarElim(factors, staIds), -nrTotV);
        DD normalizer = OP.addMultVarElim(unnormalized, staIds);
        double mass = normalizer.getVal();
        if (mass < MIN_NORMALIZER) {
            // (near-)impossible observation: reset to the prior
            return initBelief;
        }
        return new BeliefStateAdd(OP.div(unnormalized, normalizer), staIds, mass);
    }

    /**
     * Belief update for a factored belief: computes one marginal per state
     * variable with {@code OP.marginals} over the product of the current
     * marginals and the factors of action {@code i} / observation {@code i2},
     * then renames each marginal back to unprimed ids.
     *
     * <p>NOTE(review): no explicit normalization happens here; presumably
     * {@code OP.marginals} returns normalized marginals — confirm.
     */
    public BeliefState factoredtao(BeliefStateFactoredAdd beliefStateFactoredAdd, int i, int i2) {
        DD[] factors = Utils.concat(beliefStateFactoredAdd.marginals, actionObservationFactors(i, i2));
        DD[] primedMarginals = OP.marginals(factors, staIdsPr, staIds);
        DD[] marginals = new DD[nrStaV];
        for (int v = 0; v < nrStaV; v++) {
            marginals[v] = OP.primeVars(primedMarginals[v], -nrTotV);
        }
        return new BeliefStateFactoredAdd(marginals, staIds);
    }

    /** Expected value of {@code R[i]} under the given belief. */
    @Override // libpomdp.common.Pomdp
    public double expectedImmediateReward(BeliefState beliefState, int i) {
        if (beliefState instanceof BeliefStateAdd) {
            return OP.dotProductNoMem(((BeliefStateAdd) beliefState).bAdd, R[i], staIds);
        }
        return OP.factoredExpectationSparseNoMem(((BeliefStateFactoredAdd) beliefState).marginals, R[i]);
    }

    /**
     * Observation distribution after taking action {@code i} in the given
     * belief: the belief is multiplied with {@code T[i]} and {@code O[i]}, and
     * both current and primed state variables are summed out, leaving a table
     * over the primed observation variables.
     */
    @Override // libpomdp.common.Pomdp
    public CustomVector observationProbabilities(BeliefState beliefState, int i) {
        DD[] beliefFactors = beliefState instanceof BeliefStateAdd
                ? new DD[] {((BeliefStateAdd) beliefState).bAdd}
                : ((BeliefStateFactoredAdd) beliefState).marginals;
        DD[] factors = Utils.concat(beliefFactors, new DD[][] {T[i], O[i]});
        int[] summedOut = IntegerArray.merge(new int[][] {staIds, staIdsPr});
        DD obsDistribution = OP.addMultVarElim(factors, summedOut);
        return new CustomVector(OP.convert2array(obsDistribution, obsIdsPr));
    }

    /**
     * Flat {@code totnrSta x totnrSta} transition matrix of action {@code i}.
     * The array produced by {@code OP.convert2array} is read in consecutive
     * blocks of {@code totnrSta} entries; block {@code k} becomes column
     * {@code k} of the matrix.
     */
    @Override // libpomdp.common.Pomdp
    public CustomMatrix getTransitionTable(int i) {
        double[] flat = OP.convert2array(OP.multN(T[i]), IntegerArray.merge(new int[][] {staIds, staIdsPr}));
        double[][] table = new double[totnrSta][totnrSta];
        for (int block = 0; block < totnrSta; block++) {
            for (int offset = 0; offset < totnrSta; offset++) {
                table[offset][block] = flat[block * totnrSta + offset];
            }
        }
        return new CustomMatrix(table);
    }

    /**
     * Flat {@code totnrSta x totnrObs} observation matrix of action {@code i}.
     * The array produced by {@code OP.convert2array} is read in consecutive
     * blocks of {@code totnrSta} entries; block {@code k} becomes column
     * {@code k} of the matrix.
     */
    @Override // libpomdp.common.Pomdp
    public CustomMatrix getObservationTable(int i) {
        double[] flat = OP.convert2array(OP.multN(O[i]), IntegerArray.merge(new int[][] {staIdsPr, obsIdsPr}));
        double[][] table = new double[totnrSta][totnrObs];
        for (int block = 0; block < totnrObs; block++) {
            for (int offset = 0; offset < totnrSta; offset++) {
                table[offset][block] = flat[block * totnrSta + offset];
            }
        }
        return new CustomMatrix(table);
    }

    /** Flat reward vector of action {@code i} (reward minus action cost, i.e. {@code R[i]}). */
    @Override // libpomdp.common.Pomdp
    public CustomVector getRewardTable(int i) {
        return new CustomVector(OP.convert2array(R[i], staIds));
    }

    @Override // libpomdp.common.Pomdp
    public BeliefState getInitialBeliefState() {
        return initBelief;
    }

    /** Size of the flat state space. */
    @Override // libpomdp.common.Pomdp
    public int nrStates() {
        return totnrSta;
    }

    @Override // libpomdp.common.Pomdp
    public int nrActions() {
        return nrAct;
    }

    /** Size of the flat observation space. */
    @Override // libpomdp.common.Pomdp
    public int nrObservations() {
        return totnrObs;
    }

    @Override // libpomdp.common.Pomdp
    public double getGamma() {
        return gamma;
    }

    @Override // libpomdp.common.Pomdp
    public String getActionString(int i) {
        return actStr[i];
    }

    /**
     * Human-readable form of flat observation {@code i}, formatted as
     * {@code "var=value, var=value, "} (with a trailing separator).
     */
    @Override // libpomdp.common.Pomdp
    public String getObservationString(int i) {
        int[] values = Utils.sdecode(i, nrObsV, obsArity);
        StringBuilder sb = new StringBuilder();
        for (int v = 0; v < nrObsV; v++) {
            sb.append(problemAdd.varNames.get(nrStaV + v)).append("=")
                    .append(problemAdd.valNames.get(nrStaV + v).get(values[v] - 1)).append(", ");
        }
        return sb.toString();
    }

    /**
     * Human-readable form of flat state {@code i}, formatted like
     * {@link #getObservationString}: {@code "var=value, var=value, "}.
     */
    @Override // libpomdp.common.Pomdp
    public String getStateString(int i) {
        int[] values = Utils.sdecode(i, nrStaV, staArity);
        StringBuilder sb = new StringBuilder();
        for (int v = 0; v < nrStaV; v++) {
            sb.append(problemAdd.varNames.get(v)).append("=")
                    .append(problemAdd.valNames.get(v).get(values[v] - 1)).append(", ");
        }
        return sb.toString();
    }

    /**
     * Samples a next state given the current factored state values
     * {@code iArr} (one value per state variable, aligned with {@code staIds})
     * and action {@code i}.
     *
     * @return row 1 of {@code OP.sampleMultinomial}'s result (presumably the
     *     sampled values of the primed state variables, row 0 holding their
     *     ids — confirm against OP.sampleMultinomial)
     */
    public int[] sampleNextState(int[] iArr, int i) {
        int[][] current = IntegerArray.mergeRows(new int[][] {staIds, iArr});
        int[][] sample = OP.sampleMultinomial(OP.restrictN(T[i], current), staIdsPr);
        return sample[1];
    }

    /**
     * Samples an observation given current state values {@code iArr}, next
     * state values {@code iArr2} and action {@code i}.
     *
     * @return row 1 of {@code OP.sampleMultinomial}'s result over the primed
     *     observation variables
     */
    public int[] sampleObservation(int[] iArr, int[] iArr2, int i) {
        int[] ids = IntegerArray.merge(new int[][] {staIds, staIdsPr});
        int[] values = IntegerArray.merge(new int[][] {iArr, iArr2});
        int[][] assignment = IntegerArray.mergeRows(new int[][] {ids, values});
        return OP.sampleMultinomial(OP.restrictN(O[i], assignment), obsIdsPr)[1];
    }

    /** Flat indices of all states with strictly positive initial belief, in increasing order. */
    public int[] getListofInitStates() {
        List<Integer> support = new ArrayList<Integer>();
        for (int s = 0; s < totnrSta; s++) {
            int[][] assignment = IntegerArray.mergeRows(new int[][] {staIds, Utils.sdecode(s, nrStaV, staArity)});
            if (OP.eval(initBelief.bAdd, assignment) > 0.0) {
                support.add(s);
            }
        }
        int[] result = new int[support.size()];
        for (int k = 0; k < result.length; k++) {
            result[k] = support.get(k);
        }
        return result;
    }

    public int getnrTotV() {
        return nrTotV;
    }

    public int getnrStaV() {
        return nrStaV;
    }

    public int getnrObsV() {
        return nrObsV;
    }

    /** Returns the internal array (not a copy). */
    public int[] getobsIdsPr() {
        return obsIdsPr;
    }

    /** Returns the internal array (not a copy). */
    public int[] getstaIds() {
        return staIds;
    }

    /** Returns the internal array (not a copy). */
    public int[] getstaIdsPr() {
        return staIdsPr;
    }

    /** Returns the internal array (not a copy). */
    public int[] getstaArity() {
        return staArity;
    }

    /** Returns the internal array (not a copy). */
    public int[] getobsArity() {
        return obsArity;
    }

    /**
     * Back-projects {@code dd} through action {@code i} and observation
     * {@code i2}: primes {@code dd}, multiplies it with the transition factors
     * and the restricted observation factors, and sums out the primed state
     * variables.
     */
    public DD gao(DD dd, int i, int i2) {
        DD[] factors = Utils.concat(OP.primeVars(dd, nrTotV), actionObservationFactors(i, i2));
        return OP.addMultVarElim(factors, staIdsPr);
    }

    /**
     * Formats a factored state given as a 2-row matrix (row 0 ids, row 1
     * 1-based values, one column per state variable).
     *
     * @return {@code "var=value, ..."} or {@code null} (after printing to
     *     stderr) if the matrix does not have that shape
     */
    public String printS(int[][] iArr) {
        if (iArr.length != 2 || iArr[0].length != nrStaV) {
            System.err.println("Unexpected factored state matrix");
            return null;
        }
        StringBuilder sb = new StringBuilder();
        for (int v = 0; v < nrStaV; v++) {
            sb.append(problemAdd.varNames.get(v)).append("=")
                    .append(problemAdd.valNames.get(v).get(iArr[1][v] - 1)).append(", ");
        }
        return sb.toString();
    }

    /**
     * Formats a factored observation given as a 2-row matrix (row 0 ids,
     * row 1 1-based values, one column per observation variable).
     *
     * @return {@code "var=value, ..."} or {@code null} (after printing to
     *     stderr) if the matrix does not have that shape
     */
    public String printO(int[][] iArr) {
        if (iArr.length != 2 || iArr[0].length != nrObsV) {
            System.err.println("Unexpected factored observation matrix");
            return null;
        }
        StringBuilder sb = new StringBuilder();
        for (int v = 0; v < nrObsV; v++) {
            sb.append(problemAdd.varNames.get(nrStaV + v)).append("=")
                    .append(problemAdd.valNames.get(nrStaV + v).get(iArr[1][v] - 1)).append(", ");
        }
        return sb.toString();
    }
}
