/*
 * Decompiled with CFR 0.152.
 */
package ai.djl.modality.nlp.generate;

import ai.djl.inference.Predictor;
import ai.djl.modality.nlp.generate.CausalLMOutput;
import ai.djl.modality.nlp.generate.ContrastiveBatchTensorList;
import ai.djl.modality.nlp.generate.SearchConfig;
import ai.djl.modality.nlp.generate.SeqBatchScheduler;
import ai.djl.modality.nlp.generate.SeqBatcher;
import ai.djl.modality.nlp.generate.StepGeneration;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDScope;
import ai.djl.ndarray.index.NDIndex;
import ai.djl.ndarray.types.DataType;
import ai.djl.ndarray.types.Shape;
import ai.djl.translate.TranslateException;
import java.util.Arrays;
import java.util.function.Function;
import java.util.stream.Collectors;

public class ContrastiveSeqBatchScheduler
extends SeqBatchScheduler {
    public ContrastiveSeqBatchScheduler(Predictor<NDList, CausalLMOutput> lmBlock, SearchConfig config) {
        super(lmBlock, config);
    }

    @Override
    public SeqBatcher initForward(NDArray inputIds, NDArray batchUids) throws TranslateException {
        try (NDScope scope = new NDScope();){
            scope.suppressNotUsedWarning();
            this.manager = inputIds.getManager();
            NDArray initOffSets = ContrastiveSeqBatchScheduler.computeOffSets(inputIds, this.config);
            NDArray attentionMask = ContrastiveSeqBatchScheduler.computeAttentionMask(inputIds, this.config);
            NDArray positionIds = ContrastiveSeqBatchScheduler.computePositionIds(inputIds, initOffSets, 0L, 1);
            CausalLMOutput output = (CausalLMOutput)this.predictor.predict(new NDList(inputIds, positionIds, attentionMask));
            NDArray lastLogits = output.getLogits().get(":, -1, :", new Object[0]);
            long[] seqDimOrder = new long[28];
            Arrays.fill(seqDimOrder, 0, 3, 1L);
            seqDimOrder[3] = -1L;
            Arrays.fill(seqDimOrder, 4, seqDimOrder.length, 2L);
            ContrastiveBatchTensorList batchTensorList = new ContrastiveBatchTensorList(inputIds, attentionMask, output.getHiddenState(), lastLogits, output.getPastKeyValuesList(), seqDimOrder);
            SeqBatcher ret = new SeqBatcher(batchTensorList, batchUids, initOffSets, this.manager);
            NDScope.unregister(output.getPastKeyValuesList());
            NDScope.unregister(output.getHiddenState(), attentionMask, lastLogits);
            NDScope.unregister(ret.offSets, ret.batchUid);
            SeqBatcher seqBatcher = ret;
            return seqBatcher;
        }
    }

    @Override
    public NDArray inferenceCall() throws TranslateException {
        NDArray outputIds;
        try (NDScope scope = new NDScope();){
            scope.suppressNotUsedWarning();
            NDArray logits = ((ContrastiveBatchTensorList)this.seqBatcher.getData()).getLogits();
            NDArray topKIds = (NDArray)logits.topK(this.config.getK(), -1, true, false).get(1);
            ContrastiveBatchTensorList searchState = (ContrastiveBatchTensorList)this.seqBatcher.data;
            NDArray candidateInputIds = topKIds.flatten().reshape(-1L, 1L);
            assert (candidateInputIds.getDataType() == DataType.INT64) : "inputIds datatype should be int64";
            assert (candidateInputIds.getShape().getShape().length == 2) : "shape not right";
            NDList kCopyPastKeyValues = new NDList(searchState.getPastKeyValues().stream().map(ndarray -> ndarray.repeat(0, this.config.getK())).collect(Collectors.toList()));
            assert (((NDArray)kCopyPastKeyValues.get(0)).getDataType() == DataType.FLOAT32) : "inputIds datatype should be Float32";
            long numBatch = topKIds.getShape().get(0);
            NDArray kCopyPastAttentionMask = searchState.getPastAttentionMask().repeat(0, this.config.getK());
            kCopyPastAttentionMask = kCopyPastAttentionMask.concat(this.manager.ones(new Shape(numBatch * (long)this.config.getK(), 1L), DataType.INT64), 1);
            assert (((NDArray)kCopyPastKeyValues.get(0)).getShape().get(2) + 1L == kCopyPastAttentionMask.getShape().getLastDimension()) : "attentionMask_seq = past_seq + new_input_seq";
            NDArray candidatePositionIds = ContrastiveSeqBatchScheduler.computePositionIds(candidateInputIds, this.seqBatcher.offSets, searchState.getPastOutputIds().getShape().getLastDimension(), this.config.getK());
            NDList modelInputs = new NDList(candidateInputIds, candidatePositionIds, kCopyPastAttentionMask);
            modelInputs.addAll(kCopyPastKeyValues);
            CausalLMOutput candidateOutput = (CausalLMOutput)this.predictor.predict(modelInputs);
            NDList generatedOutput = StepGeneration.constrastiveStepGeneration(topKIds, logits, searchState.getPastHiddenStates(), candidateOutput.getHiddenState(), this.seqBatcher.offSets, this.config.getAlpha());
            long logitsDim = logits.getShape().get(1);
            long numHeads = ((NDArray)searchState.getPastKeyValues().get(0)).getShape().get(1);
            long kvDim = ((NDArray)searchState.getPastKeyValues().get(0)).getShape().get(3);
            long currentSeqLength = searchState.getPastOutputIds().getShape().get(1);
            long hiddenDim = searchState.getPastHiddenStates().getShape().get(2);
            NDArray select = (NDArray)generatedOutput.get(1);
            NDIndex selectIndex = new NDIndex("{}, {}, ...", this.manager.arange(0.0f, numBatch, 1.0f, DataType.INT64), select.flatten());
            NDArray nextLogits = candidateOutput.getLogits().reshape(numBatch, this.config.getK(), logitsDim).get(selectIndex);
            Function<NDArray, NDArray> fn = ndarray -> ndarray.reshape(numBatch, this.config.getK(), numHeads, currentSeqLength + 1L, kvDim).get(selectIndex);
            NDList nextPastKeyValue = new NDList(candidateOutput.getPastKeyValuesList().stream().map(fn).collect(Collectors.toList()));
            NDArray newHiddenState = candidateOutput.getHiddenState();
            assert (newHiddenState.getManager() == this.manager) : "possible leaky memory";
            NDArray nextPastHiddenStates = searchState.getPastHiddenStates().concat(newHiddenState.reshape(numBatch, this.config.getK(), 1L, hiddenDim).get(selectIndex), 1);
            outputIds = (NDArray)generatedOutput.get(0);
            NDArray nextOutputIds = searchState.getPastOutputIds().concat(outputIds, 1);
            NDArray nextPastAttentionMask = searchState.getPastAttentionMask().concat(this.manager.ones(new Shape(numBatch, 1L), DataType.INT64), 1);
            ++this.seqBatcher.seqLength;
            this.seqBatcher.data = new ContrastiveBatchTensorList(nextOutputIds, nextPastAttentionMask, nextPastHiddenStates, nextLogits, nextPastKeyValue, searchState.getSeqDimOrder());
            this.seqBatcher.exitCriteria(outputIds, this.config.getMaxSeqLength(), this.config.getEosTokenId());
            NDScope.unregister(nextOutputIds);
            NDScope.unregister(nextPastAttentionMask);
            NDScope.unregister(nextPastHiddenStates);
            NDScope.unregister(nextLogits);
            NDScope.unregister(nextPastKeyValue);
            NDScope.unregister(outputIds);
        }
        return outputIds;
    }
}

