/*
 * Decompiled with CFR 0.152.
 */
package jmarkov.jmdp.solvers;

import java.io.PrintWriter;
import java.util.Iterator;
import java.util.Map;
import jmarkov.basic.Action;
import jmarkov.basic.Actions;
import jmarkov.basic.DecisionRule;
import jmarkov.basic.Policy;
import jmarkov.basic.Solution;
import jmarkov.basic.State;
import jmarkov.basic.States;
import jmarkov.basic.StatesSet;
import jmarkov.basic.ValueFunction;
import jmarkov.basic.exceptions.NonStochasticException;
import jmarkov.basic.exceptions.SolverException;
import jmarkov.jmdp.DTMDP;
import jmarkov.jmdp.InfiniteMDP;
import jmarkov.jmdp.solvers.AbstractAverageSolver;
import no.uib.cipr.matrix.DenseVector;
import no.uib.cipr.matrix.Matrix;
import no.uib.cipr.matrix.Vector;
import no.uib.cipr.matrix.sparse.BiCG;
import no.uib.cipr.matrix.sparse.FlexCompRowMatrix;
import no.uib.cipr.matrix.sparse.IterativeSolverNotConvergedException;
import no.uib.cipr.matrix.sparse.SparseVector;

/**
 * Policy iteration solver for the long-run average criterion on a discrete-time,
 * infinite-horizon MDP ({@link DTMDP}).
 *
 * <p>Evaluation step: for the current decision rule {@code d} it solves the linear
 * system {@code (I - P_d) x = c_d}, in which the matrix column of the state with
 * index 0 is replaced by a column of ones. Because of that replacement, {@code x[0]}
 * is the gain {@code g} and {@code x[k]} (k &gt; 0) is the bias of state {@code k}.
 * The bias of the state with index 0 is fixed at zero.
 *
 * <p>Improvement step: in every state it picks the action with the smallest value of
 * {@code operation(immediateCost(i, a), sum_j p(j|i,a) * bias(j))}. The current action
 * is kept unless another action is strictly better. The iteration stops when no state
 * changes its action.
 *
 * <p>NOTE(review): the solver assumes that state indices run 0..n-1 in the iteration
 * order of {@code getAllStates()}. Matrix rows and columns, the solution vector, and
 * {@link #buildValueFunction} all rely on this.
 */
public class PolicyIterationSolverAvg<S extends State, A extends Action>
extends AbstractAverageSolver<S, A> {
    /** Immediate cost of the current action in each state, indexed by state index. */
    private DenseVector costs;
    /** Set by the improvement step once no state changes its action. */
    private boolean isOptimal = false;
    /** Number of evaluation/improvement rounds performed by the last {@link #solve()}. */
    protected long iterations;
    /** Wall-clock duration of the last {@link #solve()} in milliseconds. */
    protected long processTime = 0L;
    /** Evaluation-system solution: gain at index 0, bias of state k at index k &gt; 0. */
    private DenseVector vecValueFunction = null;
    /** Evaluation matrix {@code I - P_d} with column 0 replaced by ones. */
    private FlexCompRowMatrix matrix = null;
    private DecisionRule<S, A> currentDecisionRule = null;
    /** Bias (relative value) per state; the first state's entry is always 0. */
    private ValueFunction<S> relativeValueFunction = new ValueFunction<S>();
    /** Gain (long-run average value) of the last evaluated decision rule. */
    private double gain = 0.0;
    boolean printBias = false;
    boolean printGain = false;

    /**
     * Creates a solver for the given problem.
     *
     * @param problem the discrete-time MDP to solve
     */
    public PolicyIterationSolverAvg(DTMDP<S, A> problem) {
        super(problem);
    }

    /**
     * Runs policy iteration until the decision rule stops changing.
     *
     * @return the value function of the final rule (gain stored in the first state's
     *         entry, bias in the others) together with the optimal policy
     * @throws SolverException if the evaluation linear system cannot be solved
     */
    @Override
    public Solution<S, A> solve() throws SolverException {
        long initialTime = System.currentTimeMillis();
        // Reset so that solve() can be invoked more than once on the same instance.
        this.isOptimal = false;
        this.currentDecisionRule = this.initialDecisionRuleFirst();
        this.policy = new Policy(this.currentDecisionRule);
        this.vecValueFunction = new DenseVector(this.getDiscreteProblem().getNumStates());
        this.matrix = this.buildMatrix(this.currentDecisionRule);
        this.iterations = 0L;
        while (!this.isOptimal) {
            this.getProblem().debug(2, "Iteration " + this.iterations);
            this.getProblem().debug(3, "Current Rule = " + this.currentDecisionRule);
            this.valueFunction = this.policyEvaluation();
            this.getProblem().debug(3, "Current Value function = " + this.valueFunction);
            // The improvement step needs the gain/bias of the rule that was just evaluated.
            this.updateResults(this.valueFunction);
            this.currentDecisionRule = this.policyImprovement();
            this.policy.setDecisionRule(this.currentDecisionRule);
            ++this.iterations;
        }
        this.policy = new Policy(this.currentDecisionRule);
        this.solved = true;
        this.processTime = System.currentTimeMillis() - initialTime;
        return new Solution<S, A>(this.valueFunction, this.policy);
    }

    /**
     * Builds the starting decision rule by taking, in every state, the first action
     * produced by the feasible-action set. Also resets the value function to zero.
     *
     * @return the initial decision rule
     * @throws IllegalStateException if some state has no feasible action
     */
    private DecisionRule<S, A> initialDecisionRuleFirst() {
        this.valueFunction = new ValueFunction<S>();
        DecisionRule<S, A> rule = new DecisionRule<S, A>();
        for (S i : this.getDiscreteProblem().getAllStates()) {
            Iterator<A> it = this.getDiscreteProblem().feasibleActions(i).iterator();
            if (!it.hasNext()) {
                throw new IllegalStateException("No feasible actions for state " + i);
            }
            rule.set(i, it.next());
            this.valueFunction.set(i, 0.0);
        }
        return rule;
    }

    /**
     * Evaluates the current decision rule by solving the linear system.
     *
     * @return gain (first state's entry) and bias (other entries) of the current rule
     * @throws SolverException if the linear solver does not converge
     */
    private ValueFunction<S> policyEvaluation() throws SolverException {
        this.valueFunction = this.solveMatrix();
        return this.valueFunction;
    }

    /**
     * Improves the current decision rule state by state. The current action is kept
     * unless another feasible action is strictly better; otherwise the search could
     * alternate between equally good actions and never terminate. Rows of the
     * evaluation system are rebuilt only for states whose action changed.
     * Sets {@link #isOptimal} when no state changes its action.
     *
     * @return the improved decision rule
     * @throws SolverException kept for interface compatibility with the evaluation step
     * @throws IllegalStateException if a state has no feasible action
     */
    private DecisionRule<S, A> policyImprovement() throws SolverException {
        DecisionRule<S, A> newDecisionRule = new DecisionRule<S, A>();
        boolean changed = false;
        for (S i : this.getDiscreteProblem().getAllStates()) {
            A curAction = this.currentDecisionRule.getAction(i);
            A bestAction = curAction;
            double bestValue = curAction == null
                    ? Double.POSITIVE_INFINITY
                    : this.improvementValue(i, curAction);
            for (A a : this.getDiscreteProblem().feasibleActions(i)) {
                double val = this.improvementValue(i, a);
                if (val < bestValue) {
                    bestValue = val;
                    bestAction = a;
                }
            }
            if (bestAction == null) {
                throw new IllegalStateException("No feasible action found for state " + i);
            }
            if (!bestAction.equals(curAction)) {
                changed = true;
                int row = i.getIndex();
                this.matrix.setRow(row, this.buildRowVector(i, bestAction));
                // Column 0 carries the gain coefficient (see class comment).
                this.matrix.set(row, 0, 1.0);
                this.costs.set(row, this.getDiscreteProblem().immediateCost(i, bestAction));
            }
            newDecisionRule.set(i, bestAction);
        }
        this.isOptimal = !changed;
        return newDecisionRule;
    }

    /**
     * Value used to rank action {@code a} in state {@code i} during improvement:
     * the problem's {@code operation} applied to the immediate cost and the expected
     * bias of the next state.
     */
    private double improvementValue(S i, A a) {
        return this.getProblem().operation(
                this.getDiscreteProblem().immediateCost(i, a), this.future(i, a));
    }

    /**
     * Builds row {@code i} of {@code I - P_d} for action {@code a}: 1 on the diagonal
     * minus the transition probabilities to the reachable states.
     *
     * @throws NonStochasticException if the probabilities differ from 1 by more than 1e-5
     */
    private SparseVector buildRowVector(S i, A a) {
        int n = this.getDiscreteProblem().getNumStates();
        States<S> reachableStates = this.getDiscreteProblem().reachable(i, a);
        SparseVector vec = new SparseVector(n, reachableStates.size());
        double sum = 0.0;
        for (S j : reachableStates) {
            double probability = this.getDiscreteProblem().prob(i, j, a);
            assert probability >= 0.0;
            sum += probability;
            if (probability > 0.0) {
                // reachable() may hand back fresh state objects; the canonical instance
                // from getAllStates() carries the index used for the matrix column.
                State canonical = this.getDiscreteProblem().getAllStates().get(j);
                vec.set(canonical.getIndex(), probability);
            }
        }
        vec.scale(-1.0);
        vec.add(i.getIndex(), 1.0);
        if (Math.abs(sum - 1.0) > 1.0E-5) {
            throw new NonStochasticException("Probabilities do not add up to 1 for state " + i + ", and action " + a + ", sum = " + sum);
        }
        return vec;
    }

    /**
     * Builds the full evaluation matrix and cost vector for the given rule.
     *
     * @param rule the decision rule to evaluate
     * @return {@code I - P_rule} with column 0 replaced by ones
     */
    private FlexCompRowMatrix buildMatrix(DecisionRule<S, A> rule) {
        StatesSet<S> stts = this.getDiscreteProblem().getAllStates();
        int n = stts.size();
        FlexCompRowMatrix m = new FlexCompRowMatrix(n, n);
        this.costs = new DenseVector(n);
        for (S i : stts) {
            A a = rule.getAction(i);
            int row = i.getIndex();
            m.setRow(row, this.buildRowVector(i, a));
            // Column 0 carries the gain coefficient (see class comment).
            m.set(row, 0, 1.0);
            this.costs.set(row, this.getDiscreteProblem().immediateCost(i, a));
        }
        return m;
    }

    /**
     * Solves the evaluation system with BiCG. The previous solution is reused as the
     * starting guess.
     *
     * @return the solution mapped back to states
     * @throws SolverException if the iterative solver does not converge
     */
    protected ValueFunction<S> solveMatrix() throws SolverException {
        this.getProblem().debug(4, "Matrix to solve:\n" + this.matrix);
        try {
            BiCG solver = new BiCG(this.vecValueFunction);
            solver.solve(this.matrix, this.costs, this.vecValueFunction);
        }
        catch (IterativeSolverNotConvergedException e) {
            throw new SolverException("Policy iteration Solver: error solving linear system.", e);
        }
        return this.buildValueFunction(this.vecValueFunction);
    }

    /**
     * Expected value of {@code vf} at the next state: {@code sum_j p(j|i,a) * vf(j)}.
     *
     * @param i  current state
     * @param a  action taken in {@code i}
     * @param vf function evaluated at the reachable states
     * @return the expected next-state value
     */
    protected final double future(S i, A a, ValueFunction<S> vf) {
        double sum = 0.0;
        for (S j : this.getDiscreteProblem().reachable(i, a)) {
            sum += this.getDiscreteProblem().prob(i, j, a) * vf.get(j);
        }
        return sum;
    }

    /**
     * Expected bias of the next state under the bias of the last evaluated rule.
     * The bias is used rather than the raw solution because the raw solution stores
     * the gain in the first state's entry.
     */
    protected final double future(S i, A a) {
        return this.future(i, a, this.relativeValueFunction);
    }

    /** Maps entry k of {@code vec} to the k-th state in {@code getAllStates()} order. */
    private ValueFunction<S> buildValueFunction(DenseVector vec) {
        ValueFunction<S> vf = new ValueFunction<S>();
        int k = 0;
        for (S s : this.getDiscreteProblem().getAllStates()) {
            vf.set(s, vec.get(k));
            ++k;
        }
        return vf;
    }

    /**
     * Discounted expected next-state value: {@code discountF * sum_j p(j|i,a) * vf(j)}.
     */
    protected final double future(S i, A a, double discountF, ValueFunction<S> vf) {
        double sum = 0.0;
        for (S j : this.getDiscreteProblem().reachable(i, a)) {
            sum += this.getDiscreteProblem().prob(i, j, a) * vf.get(j);
        }
        return discountF * sum;
    }

    @Override
    public String description() {
        return "Policy Iteration Solver\n";
    }

    @Override
    public String label() {
        return "Policy Iter. Solver(avg)";
    }

    @Override
    public final long getProcessTime() {
        return this.processTime;
    }

    @Override
    public final long getIterations() {
        return this.iterations;
    }

    /**
     * Splits an evaluation result into gain (first state's entry) and bias (first
     * state fixed at 0, others copied).
     */
    private void updateResults(ValueFunction<S> valueFunction) {
        Iterator<S> it = this.getDiscreteProblem().getAllStates().iterator();
        if (!it.hasNext()) {
            return;
        }
        S first = it.next();
        this.gain = valueFunction.get(first);
        this.relativeValueFunction.set(first, 0.0);
        while (it.hasNext()) {
            S s = it.next();
            this.relativeValueFunction.set(s, valueFunction.get(s));
        }
    }

    /** @return the gain of the last evaluated decision rule */
    public final double getGain() {
        return this.gain;
    }

    /** @return the bias of each state for the last evaluated rule (first state = 0) */
    public final ValueFunction<S> getBias() {
        return this.relativeValueFunction;
    }

    /** @param val whether {@link #printSolution(PrintWriter)} prints the bias */
    public void setPrintBias(boolean val) {
        this.printBias = val;
    }

    /** @param val whether {@link #printSolution(PrintWriter)} prints the gain */
    public void setPrintGain(boolean val) {
        this.printGain = val;
    }

    /**
     * Prints the optimal policy and, depending on the print flags, the value function,
     * bias, gain and processing time.
     */
    @Override
    public void printSolution(PrintWriter pw) {
        pw.println(this);
        try {
            this.getOptimalPolicy().print(pw);
            if (this.printValueFunction) {
                this.valueFunction.print(pw);
            }
            if (this.printBias) {
                pw.println("Bias for each state:");
                this.relativeValueFunction.print(pw);
            }
            if (this.printGain) {
                pw.println("Gain = " + this.gain);
            }
            if (this.printProcessTime) {
                pw.println("Process time = " + this.getProcessTime() + " milliseconds");
            }
        }
        catch (SolverException e) {
            pw.print(" Error solving the problem :" + e);
        }
    }

    /**
     * Prints the solution to standard output. The writer is flushed (not closed, so
     * {@code System.out} stays usable) because PrintWriter buffers output.
     */
    @Override
    public void printSolution() throws Exception {
        PrintWriter pw = new PrintWriter(System.out);
        this.printSolution(pw);
        pw.flush();
    }
}

