package com.ibm.dltj.crf;

import com.ibm.dltj.DLTException;
import com.ibm.dltj.Messages;
import com.ibm.dltj.gloss.CRFLabelSet;
import com.ibm.dltj.gloss.CRFLearningRate;
import com.ibm.dltj.gloss.CRFStateFeatureGloss;
import com.ibm.dltj.gloss.CRFTransitionFeatureGloss;
import com.ibm.dltj.netgeneric.NetGeneric;
import com.ibm.dltj.util.ArrayUtils;
import com.ibm.dltj.util.ExpUtils;
import com.ibm.dltj.util.MathUtils;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Iterator;
import java.util.List;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import java.util.concurrent.TimeUnit;
import java.util.logging.Level;
import java.util.logging.Logger;

/**
 * Runs forward-backward inference for a linear-chain CRF on one sequence and accumulates the
 * resulting gradient into the state and transition feature glosses.
 *
 * <p>Sequence positions are indexed from 0. Position 0 carries the implicit start label
 * ({@link CRFLabelSet#getStartId()}) and position {@code length() - 1} the final label
 * ({@link CRFLabelSet#getFinalId()}). All scores are kept in log space. The calling sequence
 * suggested by the public API is: {@link #clear()}, register features with
 * {@link #addStateFeature} / {@link #addTransitionFeature}, {@link #process()}, then
 * {@link #update} and {@link #commit()}.
 *
 * <p>Work is spread over an internal fixed-size thread pool, which must be released with
 * {@link #close(long)}. Apart from that pool, no synchronization is performed, so an instance
 * must only be driven by one caller thread at a time.
 */
public class ForwardBackward {
    /** Log partition function of the current sequence, set by {@link #weight()}. */
    private double _Z;
    private final CRFLearningRate _learningRate;
    private final CRFLabelSet _labelSet;
    private final CRFStateFeatureStore _store;
    private static final Logger _logger = Logger.getLogger(ForwardBackward.class.getName());
    /**
     * Number of positions in the current sequence: one more than the highest position passed to
     * {@link #addStateFeature} since the last {@link #clear()}. Tracked separately from
     * {@code _theta.size()} because {@code _theta} keeps its emptied per-position lists across
     * sequences for reuse and therefore never shrinks.
     */
    private int _length;
    /** Forward log scores, {@code _alpha[position][label]}. */
    private double[][] _alpha = ArrayUtils.EMPTY_DOUBLE_DARRAY;
    /** Backward log scores, {@code _beta[position][label]}. */
    private double[][] _beta = ArrayUtils.EMPTY_DOUBLE_DARRAY;
    /** State features registered per position (entries may be null). */
    private final ArrayList<ArrayList<CRFStateFeatureGloss>> _theta = new ArrayList<>();
    /** Transition feature per position (entries may be null). */
    private final ArrayList<CRFTransitionFeatureGloss> _psi = new ArrayList<>();
    /** Label marginals, {@code _pTheta[position][label]}. */
    private double[][] _pTheta = ArrayUtils.EMPTY_DOUBLE_DARRAY;
    /** Transition marginals, {@code _pPsi[position][previousLabel][label]}. */
    private double[][][] _pPsi = ArrayUtils.EMPTY_DOUBLE_TARRAY;
    /** Pool running the forward/backward passes and the per-position marginal computations. */
    private final ExecutorService _executor =
            Executors.newFixedThreadPool(Runtime.getRuntime().availableProcessors() + 1);

    static String getCopyright() {
        return "\n\n(C) Copyright IBM Corp. 2003, 2010.\n\n";
    }

    /**
     * Creates an engine bound to the learning rate, label set and feature store of a dictionary.
     *
     * @param cRFDictionary source of the learning rate, label set and state-feature store
     * @throws IllegalArgumentException if {@code cRFDictionary} is null
     * @throws DLTException propagated from creating the state-feature store
     */
    public ForwardBackward(CRFDictionary cRFDictionary) throws DLTException {
        if (cRFDictionary == null) {
            throw new IllegalArgumentException();
        }
        this._learningRate = cRFDictionary.getLearningRate();
        this._labelSet = cRFDictionary.getLabelSet();
        this._store = new CRFStateFeatureStore(cRFDictionary);
    }

    /**
     * Resets the per-sequence state so that a new sequence can be registered.
     *
     * <p>Zeroes the forward/backward scores and marginals of the current sequence, resets the
     * partition function, drops all transition features and empties every per-position
     * state-feature list. The lists themselves are kept for reuse. The sequence length returns
     * to 0.
     */
    public void clear() {
        int n = length();
        zeroRows(this._alpha, n);
        zeroRows(this._beta, n);
        this._Z = 0.0d;
        zeroRows(this._pTheta, n);
        for (int t = 0, m = Math.min(this._pPsi.length, n); t < m; t++) {
            double[][] matrix = this._pPsi[t];
            if (matrix != null) {
                zeroRows(matrix, matrix.length);
            }
        }
        this._psi.clear();
        for (ArrayList<CRFStateFeatureGloss> features : this._theta) {
            if (features != null) {
                features.clear();
            }
        }
        this._length = 0;
    }

    /** Fills the first {@code count} non-null rows of {@code rows} with zeros. */
    private static void zeroRows(double[][] rows, int count) {
        int n = Math.min(rows.length, count);
        for (int i = 0; i < n; i++) {
            if (rows[i] != null) {
                Arrays.fill(rows[i], 0.0d);
            }
        }
    }

    /**
     * Shuts down the worker pool and closes the feature store.
     *
     * <p>Waits up to {@code timeoutMillis} milliseconds for running tasks. The store is closed
     * even if that wait is interrupted.
     *
     * @param timeoutMillis maximum time to wait for the pool to terminate, in milliseconds
     * @throws InterruptedException if interrupted while waiting for the pool to terminate
     * @throws DLTException propagated from closing the feature store
     */
    public void close(long timeoutMillis) throws InterruptedException, DLTException {
        try {
            this._executor.shutdown();
            this._executor.awaitTermination(timeoutMillis, TimeUnit.MILLISECONDS);
        } finally {
            this._store.close();
        }
    }

    /**
     * Registers a state feature at sequence position {@code i}.
     *
     * <p>The feature's gloss is looked up in the feature store. The current sequence length is
     * extended so that it covers {@code i}.
     *
     * @param i sequence position of the feature
     * @param indexIterator key of the feature in the store
     * @throws IllegalArgumentException if {@code i} is negative or {@code indexIterator} is null
     * @throws DLTException propagated from the feature-store lookup
     */
    public void addStateFeature(int i, NetGeneric.IndexIterator indexIterator) throws DLTException {
        if (i < 0 || indexIterator == null) {
            throw new IllegalArgumentException();
        }
        while (i >= this._theta.size()) {
            this._theta.add(null);
        }
        if (i >= this._length) {
            this._length = i + 1;
        }
        ArrayList<CRFStateFeatureGloss> features = this._theta.get(i);
        if (features == null) {
            features = new ArrayList<>();
            this._theta.set(i, features);
        }
        features.add(this._store.getValue(indexIterator));
    }

    /**
     * Sets the transition feature used at sequence position {@code i}.
     *
     * <p>This replaces any previously set feature at that position. A null feature is ignored.
     *
     * @param i sequence position of the feature
     * @param cRFTransitionFeatureGloss transition feature, or null to do nothing
     * @throws IllegalArgumentException if {@code i} is negative
     */
    public void addTransitionFeature(int i, CRFTransitionFeatureGloss cRFTransitionFeatureGloss) {
        if (i < 0) {
            throw new IllegalArgumentException();
        }
        if (cRFTransitionFeatureGloss == null) {
            return;
        }
        while (i >= this._psi.size()) {
            this._psi.add(null);
        }
        this._psi.set(i, cRFTransitionFeatureGloss);
    }

    /**
     * Runs the forward and backward passes and, if they agree, computes the label and transition
     * marginals of every position.
     *
     * @return {@code true} if the partition-function check passed and every position's label
     *     marginals sum to 1 (within 1e-4). Returns {@code false} otherwise, including when the
     *     sequence has fewer than three positions (start, one or more interior positions, final)
     *     and when the calling thread is interrupted.
     * @throws Exception wrapping the cause of any failure raised inside a worker task
     */
    public boolean process() throws Exception {
        // With two positions, probability1() and probabilityN() would both write row 1
        // concurrently and read a _beta row backward() never computes. With fewer, the passes
        // index past the arrays.
        if (length() < 3) {
            return false;
        }
        return weight() && probability();
    }

    /** Number of positions in the current sequence. */
    private int length() {
        return this._length;
    }

    /**
     * Sums the weights for {@code label} of all state features at {@code position}.
     * Returns 0 when the position has no features.
     */
    private double theta(int position, int label) {
        if (position < 0 || position >= this._theta.size()) {
            return 0.0d;
        }
        ArrayList<CRFStateFeatureGloss> features = this._theta.get(position);
        if (features == null || features.isEmpty()) {
            return 0.0d;
        }
        double total = 0.0d;
        for (CRFStateFeatureGloss feature : features) {
            total += feature.w(label);
        }
        return total;
    }

    /**
     * Returns the weight of the transition {@code previousLabel -> label} at {@code position}.
     * Returns 0 when the position has no transition feature.
     */
    private double psi(int position, int previousLabel, int label) {
        if (position < 0 || position >= this._psi.size()) {
            return 0.0d;
        }
        CRFTransitionFeatureGloss transition = this._psi.get(position);
        return transition == null ? 0.0d : transition.w(previousLabel, label);
    }

    /**
     * Runs {@link #forward()} and {@link #backward()} in parallel and stores the larger result
     * in {@link #_Z}.
     *
     * @return {@code true} if Z is non-negative and the two estimates differ by at most 10%
     *     relative to the larger one. Returns {@code false} otherwise, or if interrupted.
     * @throws Exception wrapping the cause of a failure inside either pass
     */
    private boolean weight() throws Exception {
        Future<Double> forwardTask = null;
        Future<Double> backwardTask = null;
        try {
            forwardTask = this._executor.submit(new Callable<Double>() {
                @Override
                public Double call() {
                    return ForwardBackward.this.forward();
                }
            });
            backwardTask = this._executor.submit(new Callable<Double>() {
                @Override
                public Double call() {
                    return ForwardBackward.this.backward();
                }
            });
            double zForward = forwardTask.get();
            double zBackward = backwardTask.get();
            this._Z = Math.max(zForward, zBackward);
            if (this._Z >= 0.0d && 1.0d - (Math.min(zForward, zBackward) / this._Z) <= 0.1d) {
                return true;
            }
            if (_logger.isLoggable(Level.INFO)) {
                _logger.log(Level.INFO, Messages.format("info.fb.Z: ", Double.toString(zForward), Double.toString(zBackward)));
            }
            return false;
        } catch (ExecutionException e) {
            throw new Exception(e.getCause());
        } catch (InterruptedException e) {
            // The passes did not finish, so _Z is not valid for this sequence.
            Thread.currentThread().interrupt();
            return false;
        } finally {
            cancel(forwardTask);
            cancel(backwardTask);
        }
    }

    /** Cancels {@code task} if it was submitted. Per the Future contract this is a no-op once done. */
    private static void cancel(Future<?> task) {
        if (task != null) {
            task.cancel(true);
        }
    }

    /**
     * Forward pass: fills {@code _alpha[1 .. length()-1]}.
     *
     * @return the log score of reaching the final label at the last position, floored at
     *     {@code _learningRate.threshold()}
     */
    double forward() {
        final int n = length();
        final int labels = this._labelSet.size();
        this._alpha = ArrayUtils.ensureCapacity(this._alpha, n);
        for (int t = 0; t < n; t++) {
            if (this._alpha[t] == null || this._alpha[t].length < labels) {
                this._alpha[t] = new double[labels];
            }
        }
        int startId = this._labelSet.getStartId();
        for (int y = 0; y < labels; y++) {
            this._alpha[1][y] = psi(1, startId, y) + theta(1, y);
        }
        double[] terms = new double[labels];
        for (int t = 2; t < n; t++) {
            double[] previous = this._alpha[t - 1];
            double[] current = this._alpha[t];
            for (int y = 0; y < labels; y++) {
                double state = theta(t, y);
                for (int x = 0; x < labels; x++) {
                    terms[x] = previous[x] + psi(t, x, y) + state;
                }
                current[y] = MathUtils.logsumexp(terms, this._learningRate.threshold());
            }
        }
        return Math.max(this._alpha[n - 1][this._labelSet.getFinalId()], this._learningRate.threshold());
    }

    /**
     * Backward pass: fills {@code _beta[0 .. length()-2]}.
     *
     * <p>{@code _beta[t][x]} scores everything after position {@code t} given label {@code x}
     * at {@code t}.
     *
     * @return the log score seen from the start label at position 0, floored at
     *     {@code _learningRate.threshold()}
     */
    double backward() {
        final int n = length();
        final int labels = this._labelSet.size();
        this._beta = ArrayUtils.ensureCapacity(this._beta, n);
        for (int t = 0; t < n; t++) {
            if (this._beta[t] == null || this._beta[t].length < labels) {
                this._beta[t] = new double[labels];
            }
        }
        int finalId = this._labelSet.getFinalId();
        double finalState = theta(n - 1, finalId);
        for (int x = 0; x < labels; x++) {
            this._beta[n - 2][x] = psi(n - 1, x, finalId) + finalState;
        }
        double[] terms = new double[labels];
        double[] nextState = new double[labels];
        for (int t = n - 3; t >= 0; t--) {
            double[] current = this._beta[t];
            double[] next = this._beta[t + 1];
            for (int y = 0; y < labels; y++) {
                nextState[y] = theta(t + 1, y);
            }
            for (int x = 0; x < labels; x++) {
                for (int y = 0; y < labels; y++) {
                    terms[y] = psi(t + 1, x, y) + nextState[y] + next[y];
                }
                current[x] = MathUtils.logsumexp(terms, this._learningRate.threshold());
            }
        }
        return Math.max(this._beta[0][this._labelSet.getStartId()], this._learningRate.threshold());
    }

    /**
     * Computes the label and transition marginals of positions {@code 1 .. length()-1} in
     * parallel, one task per position.
     *
     * <p>Requires {@link #weight()} to have succeeded.
     *
     * @return {@code true} if every position's label marginals sum to 1 (within 1e-4).
     *     Returns {@code false} otherwise, or if interrupted.
     * @throws Exception wrapping the cause of a failure inside a worker task
     */
    private boolean probability() throws Exception {
        final int n = length();
        final int labels = this._labelSet.size();
        if (this._pTheta.length < n) {
            this._pTheta = ArrayUtils.ensureCapacity(this._pTheta, n);
        }
        if (this._pPsi.length < n) {
            this._pPsi = (double[][][]) ArrayUtils.ensureCapacity(this._pPsi, n);
        }
        for (int t = 1; t < n; t++) {
            // Reused rows are zeroed here so results never depend on whether clear() was called:
            // label marginals are accumulated with += and transition rows are only partly written.
            if (this._pTheta[t] == null || this._pTheta[t].length < labels) {
                this._pTheta[t] = new double[labels];
            } else {
                Arrays.fill(this._pTheta[t], 0.0d);
            }
            if (this._pPsi[t] == null || this._pPsi[t].length < labels) {
                this._pPsi[t] = new double[labels][labels];
            } else {
                zeroRows(this._pPsi[t], this._pPsi[t].length);
            }
        }
        List<Future<Boolean>> tasks = new ArrayList<>(n);
        try {
            tasks.add(this._executor.submit(new Callable<Boolean>() {
                @Override
                public Boolean call() {
                    return ForwardBackward.this.probability1();
                }
            }));
            tasks.add(this._executor.submit(new Callable<Boolean>() {
                @Override
                public Boolean call() {
                    return ForwardBackward.this.probabilityN();
                }
            }));
            for (int t = 2; t < n - 1; t++) {
                final int position = t;
                tasks.add(this._executor.submit(new Callable<Boolean>() {
                    @Override
                    public Boolean call() {
                        return ForwardBackward.this.probability(position);
                    }
                }));
            }
            // Non-short-circuiting: wait for every task before returning.
            boolean ok = true;
            for (Future<Boolean> task : tasks) {
                ok &= task.get();
            }
            return ok;
        } catch (ExecutionException e) {
            throw new Exception(e.getCause());
        } catch (InterruptedException e) {
            // Some marginals may be missing, so the result must not be reported as valid.
            Thread.currentThread().interrupt();
            return false;
        } finally {
            for (Future<Boolean> task : tasks) {
                task.cancel(true);
            }
        }
    }

    /**
     * Computes the marginals at position 1, where the previous label is always the start label.
     *
     * @return whether the label marginals at position 1 sum to 1
     */
    boolean probability1() {
        double[] marginal = this._pTheta[1];
        double[][] pairMarginal = this._pPsi[1];
        ExpUtils expUtils = new ExpUtils();
        int startId = this._labelSet.getStartId();
        for (int y = 0; y < this._labelSet.size(); y++) {
            double p = expUtils.exp(psi(1, startId, y) + theta(1, y) + this._beta[1][y] - this._Z);
            marginal[y] += p;
            pairMarginal[startId][y] = p;
        }
        return verify(marginal);
    }

    /**
     * Computes the marginals at the last position, where the label is always the final label.
     *
     * @return whether the label marginals at the last position sum to 1
     */
    boolean probabilityN() {
        int last = length() - 1;
        double[] marginal = this._pTheta[last];
        double[][] pairMarginal = this._pPsi[last];
        ExpUtils expUtils = new ExpUtils();
        int finalId = this._labelSet.getFinalId();
        double state = theta(last, finalId);
        for (int x = 0; x < this._labelSet.size(); x++) {
            double p = expUtils.exp(this._alpha[last - 1][x] + psi(last, x, finalId) + state - this._Z);
            marginal[finalId] += p;
            pairMarginal[x][finalId] = p;
        }
        return verify(marginal);
    }

    /**
     * Computes the label and transition marginals at an interior position {@code i}.
     *
     * @return whether the label marginals at {@code i} sum to 1
     * @throws IllegalArgumentException if {@code i} is outside {@code [1, length())}
     */
    boolean probability(int i) {
        if (i < 1 || i >= length()) {
            throw new IllegalArgumentException();
        }
        double[] marginal = this._pTheta[i];
        double[][] pairMarginal = this._pPsi[i];
        ExpUtils expUtils = new ExpUtils();
        for (int y = 0; y < this._labelSet.size(); y++) {
            double state = theta(i, y);
            for (int x = 0; x < this._labelSet.size(); x++) {
                double p = expUtils.exp(this._alpha[i - 1][x] + psi(i, x, y) + state + this._beta[i][y] - this._Z);
                marginal[y] += p;
                pairMarginal[x][y] = p;
            }
        }
        return verify(marginal);
    }

    /**
     * Checks that a marginal distribution sums to 1 within 1e-4.
     * Failures are logged at FINE instead of being printed to stdout.
     */
    private static boolean verify(double[] dArr) {
        double sum = ArrayUtils.sum(dArr);
        if (Math.abs(1.0d - sum) <= 1.0E-4d) {
            return true;
        }
        if (_logger.isLoggable(Level.FINE)) {
            _logger.log(Level.FINE, "marginal probabilities sum to " + sum + " instead of 1");
        }
        return false;
    }

    /**
     * Accumulates the gradient (observed minus expected counts) for the observed transition
     * {@code previousLabel -> label} at {@code position}.
     *
     * @throws IllegalArgumentException if {@code position} is outside {@code [1, length())} or
     *     either label is outside the label set
     */
    public void update(int position, int previousLabel, int label) {
        if (position < 1 || position >= length()) {
            throw new IllegalArgumentException();
        }
        if (previousLabel < 0 || previousLabel >= this._labelSet.size()) {
            throw new IllegalArgumentException();
        }
        if (label < 0 || label >= this._labelSet.size()) {
            throw new IllegalArgumentException();
        }
        updatePsi(position, previousLabel, label);
        updateTheta(position, label);
    }

    /**
     * Adds {@code observed - expected} to the transition gradient at {@code position}.
     * Does nothing if there is no transition feature or no computed marginals there.
     */
    private void updatePsi(int position, int previousLabel, int label) {
        if (position >= this._psi.size() || position >= this._pPsi.length) {
            return;
        }
        CRFTransitionFeatureGloss transition = this._psi.get(position);
        double[][] expected = this._pPsi[position];
        if (transition == null || expected == null || expected.length == 0) {
            return;
        }
        transition.ensureCapacity(expected.length);
        for (int x = 0; x < expected.length; x++) {
            for (int y = 0; y < expected.length; y++) {
                double observed = (x == previousLabel && y == label) ? 1.0d : 0.0d;
                transition.g(x, y, transition.g(x, y) + (observed - expected[x][y]), this._learningRate.threshold());
            }
        }
    }

    /**
     * Adds {@code observed - expected} to the gradient of every state feature at
     * {@code position}. Does nothing if there are no features or no computed marginals there.
     */
    private void updateTheta(int position, int label) {
        if (position >= this._theta.size() || position >= this._pTheta.length) {
            return;
        }
        ArrayList<CRFStateFeatureGloss> features = this._theta.get(position);
        double[] expected = this._pTheta[position];
        if (features == null || features.isEmpty() || expected == null || expected.length == 0) {
            return;
        }
        for (CRFStateFeatureGloss feature : features) {
            feature.ensureCapacity(expected.length);
            for (int y = 0; y < expected.length; y++) {
                double observed = (y == label) ? 1.0d : 0.0d;
                feature.g(y, feature.g(y) + (observed - expected[y]), this._learningRate.threshold());
            }
        }
    }

    /** Applies the accumulated gradients to all transition and state features, then clears them. */
    public void commit() {
        commitPsi();
        commitTheta();
    }

    /**
     * Applies an {@code eta}-scaled gradient step with an L1 penalty to every transition weight,
     * then clears each feature's gradient.
     *
     * <p>The penalty bookkeeping ({@code u()} and the per-weight {@code q}) looks like the
     * cumulative-penalty scheme of Tsuruoka et al. (2009); confirm against
     * {@code CRFLearningRate}.
     */
    private void commitPsi() {
        for (CRFTransitionFeatureGloss transition : this._psi) {
            if (transition == null) {
                continue;
            }
            for (int x = 0; x < this._labelSet.size(); x++) {
                for (int y = 0; y < this._labelSet.size(); y++) {
                    double eta = this._learningRate.eta() * transition.g(x, y);
                    if (Math.abs(eta) >= this._learningRate.threshold()) {
                        double w = transition.w(x, y) + eta;
                        if (w > 0.0d) {
                            transition.w(x, y, (float) Math.max(0.0d, w - (this._learningRate.u() + transition.q(x, y))), this._learningRate.threshold());
                        } else if (w < 0.0d) {
                            transition.w(x, y, (float) Math.min(0.0d, w + (this._learningRate.u() - transition.q(x, y))), this._learningRate.threshold());
                        } else if (w == 0.0d) {
                            // The step landed exactly on 0. Store it; otherwise the stale weight
                            // would stay and corrupt q below.
                            transition.w(x, y, 0.0f, this._learningRate.threshold());
                        }
                        transition.q(x, y, (float) (transition.q(x, y) + (transition.w(x, y) - w)), this._learningRate.threshold());
                    }
                }
            }
            transition.clear();
        }
    }

    /**
     * Applies an {@code eta}-scaled gradient step with an L1 penalty to every state-feature
     * weight, then clears each feature's gradient. The update mirrors {@link #commitPsi()}.
     */
    private void commitTheta() {
        for (ArrayList<CRFStateFeatureGloss> features : this._theta) {
            if (features == null || features.isEmpty()) {
                continue;
            }
            for (CRFStateFeatureGloss feature : features) {
                for (int y = 0; y < this._labelSet.size(); y++) {
                    double eta = this._learningRate.eta() * feature.g(y);
                    if (Math.abs(eta) >= this._learningRate.threshold()) {
                        double w = feature.w(y) + eta;
                        if (w > 0.0d) {
                            feature.w(y, Math.max(0.0d, w - (this._learningRate.u() + feature.q(y))), this._learningRate.threshold());
                        } else if (w < 0.0d) {
                            feature.w(y, Math.min(0.0d, w + (this._learningRate.u() - feature.q(y))), this._learningRate.threshold());
                        } else if (w == 0.0d) {
                            // The step landed exactly on 0. Store it; otherwise the stale weight
                            // would stay and corrupt q below.
                            feature.w(y, 0.0d, this._learningRate.threshold());
                        }
                        feature.q(y, feature.q(y) + (feature.w(y) - w), this._learningRate.threshold());
                    }
                }
                feature.clear();
            }
        }
    }
}
