/*
 * Decompiled with CFR 0.152.
 */
package cc.mallet.grmm.inference;

import cc.mallet.grmm.inference.Sampler;
import cc.mallet.grmm.types.Assignment;
import cc.mallet.grmm.types.AssignmentIterator;
import cc.mallet.grmm.types.DiscreteFactor;
import cc.mallet.grmm.types.Factor;
import cc.mallet.grmm.types.FactorGraph;
import cc.mallet.grmm.types.LogTableFactor;
import cc.mallet.grmm.types.Variable;
import cc.mallet.util.Randoms;
import cc.mallet.util.Timing;
import java.util.List;

/**
 * Draws samples from the distribution defined by a {@link FactorGraph} using single-site Gibbs
 * sampling: each sweep visits variables one at a time and redraws each from its distribution
 * conditioned on the current values of the others.
 */
public class GibbsSampler
implements Sampler {
    /**
     * Factor values at or below this threshold are treated as infeasible (zero) while searching
     * for a starting assignment.
     */
    private static final double MIN_FEASIBLE_VALUE = 1.0E-50;

    /** Number of sweeps run and discarded before any sample is recorded. */
    private int burnin;
    /** Random source used to draw each variable's new value (fixed seed by default). */
    private Randoms r = new Randoms(324231);

    /** Creates a sampler with zero burn-in and the default fixed-seed random source. */
    public GibbsSampler() {
    }

    /**
     * Creates a sampler with the default fixed-seed random source.
     *
     * @param burnin number of sweeps to discard before recording samples
     */
    public GibbsSampler(int burnin) {
        this.burnin = burnin;
    }

    /**
     * Creates a sampler with an explicit random source.
     *
     * @param r random source used for sampling
     * @param burnin number of sweeps to discard before recording samples
     */
    public GibbsSampler(Randoms r, int burnin) {
        this.burnin = burnin;
        this.r = r;
    }

    /**
     * Sets the number of sweeps discarded before samples are recorded.
     *
     * @param burnin number of burn-in sweeps
     */
    public void setBurnin(int burnin) {
        this.burnin = burnin;
    }

    /**
     * Replaces the random source used for sampling.
     *
     * @param r new random source
     */
    @Override
    public void setRandom(Randoms r) {
        this.r = r;
    }

    /**
     * Runs {@code burnin} discarded sweeps, then {@code N} further sweeps, appending the state
     * after each of those sweeps as a row of the returned assignment.
     *
     * @param mdl model to sample from
     * @param N number of samples to record
     * @return an assignment to which each recorded state was added via {@code addRow}
     * @throws IllegalArgumentException if no starting assignment with nonzero value is found
     */
    @Override
    public Assignment sample(FactorGraph mdl, int N) {
        Assignment assn = this.initialAssignment(mdl);
        if (assn == null) {
            throw new IllegalArgumentException("GibbsSampler: Could not find feasible assignment for model " + mdl);
        }
        Timing timing = new Timing();
        for (int i = 0; i < this.burnin; ++i) {
            assn = this.doOnePass(mdl, assn);
        }
        timing.tick("Burnin");
        Assignment ret = new Assignment();
        for (int i = 0; i < N; ++i) {
            assn = this.doOnePass(mdl, assn);
            ret.addRow(assn);
        }
        timing.tick("Sampling");
        return ret;
    }

    /**
     * Finds a starting state for the chain. The all-zeros assignment is tried first. If the
     * model's log-value there is negative infinity, a depth-first search over the factors is
     * used instead.
     *
     * @param mdl model to find a starting state for
     * @return a feasible assignment, or {@code null} if the search finds none
     */
    private Assignment initialAssignment(FactorGraph mdl) {
        Assignment allZeros = new Assignment(mdl, new int[mdl.numVariables()]);
        if (mdl.logValue(allZeros) > Double.NEGATIVE_INFINITY) {
            return allZeros;
        }
        return this.initialAssignmentRec(mdl, new Assignment(), 0);
    }

    /**
     * Backtracking search that extends {@code assn} so that factor {@code fi} and every later
     * factor has value above {@link #MIN_FEASIBLE_VALUE}.
     *
     * NOTE(review): recursion depth grows with the number of factors, so very large models
     * could overflow the stack.
     *
     * @param mdl model whose factors are checked in index order
     * @param assn partial assignment built so far
     * @param fi index of the next factor to satisfy
     * @return a complete feasible extension of {@code assn}, or {@code null} if none exists
     */
    private Assignment initialAssignmentRec(FactorGraph mdl, Assignment assn, int fi) {
        if (fi >= mdl.factors().size()) {
            return assn;
        }
        Factor f = mdl.getFactor(fi);
        Factor sliced = f.slice(assn);
        if (sliced.varSet().isEmpty()) {
            // Slicing by assn left no free variables, so f is fully determined: only check it.
            if (f.value(assn) > MIN_FEASIBLE_VALUE) {
                return this.initialAssignmentRec(mdl, assn, fi + 1);
            }
            return null;
        }
        // Try each setting of the sliced factor's remaining variables, backtracking on failure.
        for (AssignmentIterator it = sliced.assignmentIterator(); it.hasNext(); it.advance()) {
            if (sliced.value(it) > MIN_FEASIBLE_VALUE) {
                Assignment extended = Assignment.union(assn, it.assignment());
                Assignment found = this.initialAssignmentRec(mdl, extended, fi + 1);
                if (found != null) {
                    return found;
                }
            }
        }
        return null;
    }

    /**
     * Performs one Gibbs sweep on a copy of {@code initial}. Each variable is resampled in
     * turn from its conditional, and the new value is visible to later variables in the same
     * sweep.
     *
     * NOTE(review): the loop bound is {@code ret.size()} while variables are fetched by model
     * index via {@code mdl.get(vidx)}. This assumes the assignment covers the model's variables
     * in model order; confirm for assignments produced by {@code initialAssignmentRec}.
     *
     * @param mdl model being sampled
     * @param initial current state; not modified
     * @return the state after the sweep
     */
    private Assignment doOnePass(FactorGraph mdl, Assignment initial) {
        Assignment ret = (Assignment) initial.duplicate();
        for (int vidx = 0; vidx < ret.size(); ++vidx) {
            Variable var = mdl.get(vidx);
            DiscreteFactor conditional = this.constructConditionalCpt(mdl, var, ret);
            ret.setValue(var, conditional.sampleLocation(this.r));
        }
        return ret;
    }

    /**
     * Builds the normalized conditional distribution of {@code var} given the other values in
     * {@code fullAssn}. Only factors containing {@code var} are evaluated. Factors that do not
     * mention {@code var} contribute the same amount to every entry and cancel on
     * normalization.
     *
     * Side effect: {@code var} is overwritten in {@code fullAssn} while enumerating its values.
     * On return it holds the last value enumerated; the caller ({@code doOnePass}) overwrites it
     * with the sampled value.
     *
     * @param mdl model supplying the factors
     * @param var variable whose conditional is built
     * @param fullAssn current state of the chain (mutated as described above)
     * @return a normalized single-variable factor over {@code var}
     */
    private DiscreteFactor constructConditionalCpt(FactorGraph mdl, Variable var, Assignment fullAssn) {
        List<?> ptlList = mdl.allFactorsContaining(var);
        LogTableFactor ptl = new LogTableFactor(var);
        for (AssignmentIterator it = ptl.assignmentIterator(); it.hasNext(); it.advance()) {
            Assignment varAssn = it.assignment();
            fullAssn.setValue(var, varAssn.get(var));
            // The stored value is a sum of log-values. This presumes setRawValue on a
            // LogTableFactor takes log-space input; confirm against LogTableFactor.
            ptl.setRawValue(varAssn, this.sumValues(ptlList, fullAssn));
        }
        ptl.normalize();
        return ptl;
    }

    /**
     * Sums {@code logValue(assn)} over every factor in {@code ptlList}.
     *
     * @param ptlList factors to evaluate; every element is expected to be a {@link Factor}
     * @param assn assignment at which each factor is evaluated
     * @return the total of the factors' log-values
     */
    private double sumValues(List<?> ptlList, Assignment assn) {
        double sum = 0.0;
        // Iterate as Object and cast: the list's declared element type may be raw.
        for (Object element : ptlList) {
            sum += ((Factor) element).logValue(assn);
        }
        return sum;
    }
}

