/*
 * Decompiled with CFR 0.152.
 */
package ai.djl.nn.transformer;

import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.types.DataType;
import ai.djl.training.loss.Loss;

/**
 * Loss for BERT's next-sentence-prediction (NSP) task.
 *
 * <p>The loss is the mean, over the batch, of the negated log probability that the model assigns
 * to each example's true class. The class count is fixed at two (see {@code oneHot(2)}).
 */
public class BertNextSentenceLoss extends Loss {

    /** Index of the NSP labels inside the label {@link NDList}. */
    private final int labelIdx;

    /** Index of the NSP per-class predictions inside the prediction {@link NDList}. */
    private final int nextSentencePredictionIdx;

    /**
     * Creates a next-sentence-prediction loss named {@code "BertNSLoss"}.
     *
     * @param labelIdx index of the NSP labels within the label list
     * @param nextSentencePredictionIdx index of the NSP predictions within the prediction list
     */
    public BertNextSentenceLoss(int labelIdx, int nextSentencePredictionIdx) {
        super("BertNSLoss");
        this.labelIdx = labelIdx;
        this.nextSentencePredictionIdx = nextSentencePredictionIdx;
    }

    /**
     * Computes the mean negative log-likelihood of the true NSP class.
     *
     * @param labels the label list; entry {@code labelIdx} holds one class id (0 or 1) per example
     * @param predictions the prediction list; entry {@code nextSentencePredictionIdx} holds the
     *     per-class scores (presumably log probabilities, per the variable naming — TODO confirm)
     * @return a scalar array with the batch-mean loss, owned by the caller's manager
     */
    @Override
    public NDArray evaluate(NDList labels, NDList predictions) {
        // Intermediates are created under a sub-manager and freed when it closes; only the
        // result is handed back via scope.ret.
        try (NDManager scope = NDManager.subManagerOf(labels)) {
            scope.tempAttachAll(labels, predictions);
            // oneHot + mul need a floating-point label.
            NDArray label = labels.get(labelIdx).toType(DataType.FLOAT32, false);
            NDArray logPredictions = predictions.get(nextSentencePredictionIdx);
            // The one-hot mask zeroes every class except the true one, so summing over axis 1
            // picks out the score of the true class for each example.
            NDArray oneHotLabels = label.oneHot(2);
            NDArray logPredictionForLabels = oneHotLabels.mul(logPredictions);
            NDArray summedPredictions = logPredictionForLabels.sum(new int[] {1});
            NDArray perExampleLoss = summedPredictions.neg();
            return scope.ret(perExampleLoss.mean());
        }
    }

    /**
     * Computes the fraction of examples whose highest-scoring NSP class equals the label.
     *
     * @param labels the label list; entry {@code labelIdx} holds one class id per example
     * @param predictions the prediction list; entry {@code nextSentencePredictionIdx} holds the
     *     per-class scores, with classes along axis 1
     * @return a scalar FLOAT32-derived array with the accuracy, owned by the caller's manager
     */
    public NDArray accuracy(NDList labels, NDList predictions) {
        try (NDManager scope = NDManager.subManagerOf(labels)) {
            scope.tempAttachAll(labels, predictions);
            // Cast the label to the same dtype as the prediction so eq compares like with like;
            // with copy=false this is a no-op when the label is already INT32.
            NDArray label = labels.get(labelIdx).toType(DataType.INT32, false);
            NDArray predictionLogProbs = predictions.get(nextSentencePredictionIdx);
            // argMax over axis 1 gives the predicted class per example.
            NDArray prediction = predictionLogProbs.argMax(1).toType(DataType.INT32, false);
            NDArray equalCount = label.eq(prediction).sum().toType(DataType.FLOAT32, false);
            // NOTE(review): an empty label array makes this a division by zero (NaN result).
            NDArray result = equalCount.div(label.getShape().size());
            return scope.ret(result);
        }
    }
}

