/*
 * Decompiled with CFR 0.152.
 */
package org.opensearch.knn.plugin.script;

import java.math.BigInteger;
import java.util.List;
import java.util.Objects;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.apache.lucene.util.VectorUtil;
import org.opensearch.knn.common.KNNValidationUtil;
import org.opensearch.knn.index.KNNVectorScriptDocValues;
import org.opensearch.knn.index.SpaceType;
import org.opensearch.knn.index.VectorDataType;

/**
 * Static helpers that compute distances and similarities between a query vector and a document vector
 * for k-NN scoring scripts.
 *
 * <p>Most metrics come in two overloads:
 * <ul>
 *   <li>a primitive {@code float[]} form;</li>
 *   <li>a form that takes the query as a {@code List<Number>} plus the document's
 *       {@link KNNVectorScriptDocValues}.</li>
 * </ul>
 * The list form converts the query to {@code float[]} using the document field's {@link VectorDataType}
 * and then delegates to the array form.
 */
public class KNNScoringUtil {
    private static final Logger logger = LogManager.getLogger(KNNScoringUtil.class);

    /**
     * Ensures both vectors are non-null and have the same number of dimensions.
     *
     * @param queryVector the query vector
     * @param inputVector the document vector; its length is reported as the "Expected" dimension
     * @throws NullPointerException if either vector is {@code null}
     * @throws IllegalArgumentException if the two lengths differ
     */
    private static void requireEqualDimension(float[] queryVector, float[] inputVector) {
        Objects.requireNonNull(queryVector);
        Objects.requireNonNull(inputVector);
        if (queryVector.length != inputVector.length) {
            // Plain concatenation (rather than String.format without a Locale) keeps the digits in the
            // message independent of the JVM's default locale.
            throw new IllegalArgumentException(
                "query vector dimension mismatch. Expected: " + inputVector.length + ", Given: " + queryVector.length
            );
        }
    }

    /**
     * Computes the squared Euclidean (L2) distance between two vectors using Lucene's
     * {@code VectorUtil.squareDistance}.
     *
     * <p>No dimension check is done here; mismatched lengths are left to {@code VectorUtil}
     * (which, per its documentation, rejects them with an {@link IllegalArgumentException}).
     *
     * @param queryVector the query vector
     * @param inputVector the document vector
     * @return the sum of squared per-dimension differences
     */
    public static float l2Squared(float[] queryVector, float[] inputVector) {
        return VectorUtil.squareDistance(queryVector, inputVector);
    }

    /**
     * Converts a list of numbers into a {@code float[]} via {@link Number#floatValue()}.
     *
     * <p>When {@code vectorDataType} is {@link VectorDataType#BYTE}, each converted value is passed to
     * {@code KNNValidationUtil.validateByteVectorValue}. NOTE(review): presumably that rejects values
     * outside the byte range; confirm.
     *
     * @param inputVector the numbers to convert; must not be {@code null} (elements must not be {@code null} either)
     * @param vectorDataType the data type of the field the vector will be compared against
     * @return a new array with one float per list element, in list order
     * @throws NullPointerException if {@code inputVector} or any element is {@code null}
     */
    private static float[] toFloat(List<Number> inputVector, VectorDataType vectorDataType) {
        Objects.requireNonNull(inputVector);
        float[] value = new float[inputVector.size()];
        int index = 0;
        for (Number val : inputVector) {
            float floatValue = val.floatValue();
            if (VectorDataType.BYTE == vectorDataType) {
                KNNValidationUtil.validateByteVectorValue(floatValue);
            }
            value[index++] = floatValue;
        }
        return value;
    }

    /**
     * Computes the squared L2 distance between a script-supplied query and a document's vector.
     *
     * @param queryVector the query as a list of numbers
     * @param docValues the document's vector doc values
     * @return the squared L2 distance
     */
    public static float l2Squared(List<Number> queryVector, KNNVectorScriptDocValues docValues) {
        return l2Squared(toFloat(queryVector, docValues.getVectorDataType()), docValues.getValue());
    }

    /**
     * Computes cosine similarity using a precomputed query norm, so the query's norm is not
     * recomputed for every document.
     *
     * <p>The result is {@code dot(q, d) / sqrt(normQueryVector * sum(d[i]^2))}. NOTE(review): this
     * equals the cosine only if {@code normQueryVector} is the query's <em>squared</em> L2 norm, because
     * it is multiplied with the document's sum of squares before the square root. Confirm against
     * callers.
     *
     * <p>If the denominator is zero, or the result is otherwise not finite (e.g. a negative or
     * non-finite {@code normQueryVector}), {@code 0} is returned so the document sorts to the end.
     *
     * @param queryVector the query vector
     * @param inputVector the document vector
     * @param normQueryVector the precomputed query norm term (see note above)
     * @return the cosine similarity, or {@code 0} for invalid input
     * @throws NullPointerException if either vector is {@code null}
     * @throws IllegalArgumentException if the vectors differ in dimension
     */
    public static float cosinesimilOptimized(float[] queryVector, float[] inputVector, float normQueryVector) {
        requireEqualDimension(queryVector, inputVector);
        float dotProduct = 0.0f;
        float inputSquaredNorm = 0.0f;
        for (int i = 0; i < queryVector.length; ++i) {
            dotProduct += queryVector[i] * inputVector[i];
            inputSquaredNorm += inputVector[i] * inputVector[i];
        }
        float normalizedProduct = normQueryVector * inputSquaredNorm;
        if (normalizedProduct == 0.0f) {
            logger.debug("Invalid vectors for cosine. Returning minimum score to put this result to end");
            return 0.0f;
        }
        float similarity = (float) ((double) dotProduct / Math.sqrt(normalizedProduct));
        // A negative or non-finite norm term makes sqrt/division produce NaN or Infinity; never leak
        // that as a score.
        if (!Float.isFinite(similarity)) {
            logger.debug("Invalid vectors for cosine. Returning minimum score to put this result to end");
            return 0.0f;
        }
        return similarity;
    }

    /**
     * Computes cosine similarity between a script-supplied query and a document's vector, using a
     * caller-supplied query norm term.
     *
     * <p>The query is first checked with {@code SpaceType.COSINESIMIL.validateVector}. NOTE(review):
     * the exact rules of that check are defined in {@code SpaceType}.
     *
     * @param queryVector the query as a list of numbers
     * @param docValues the document's vector doc values
     * @param queryVectorMagnitude the precomputed query norm term passed to {@link #cosinesimilOptimized}
     * @return the cosine similarity, or {@code 0} for invalid input
     */
    public static float cosineSimilarity(List<Number> queryVector, KNNVectorScriptDocValues docValues, Number queryVectorMagnitude) {
        float[] inputVector = toFloat(queryVector, docValues.getVectorDataType());
        SpaceType.COSINESIMIL.validateVector(inputVector);
        return cosinesimilOptimized(inputVector, docValues.getValue(), queryVectorMagnitude.floatValue());
    }

    /**
     * Computes cosine similarity between two vectors using Lucene's {@code VectorUtil.cosine}.
     *
     * <p>Returns {@code 0} (the minimum score, pushing the document to the end) when the similarity
     * cannot be computed. This covers {@code VectorUtil} throwing, and also a non-finite result.
     *
     * @param queryVector the query vector
     * @param inputVector the document vector
     * @return the cosine similarity, or {@code 0} for invalid input
     * @throws NullPointerException if either vector is {@code null}
     * @throws IllegalArgumentException if the vectors differ in dimension
     */
    public static float cosinesimil(float[] queryVector, float[] inputVector) {
        requireEqualDimension(queryVector, inputVector);
        float similarity;
        try {
            similarity = VectorUtil.cosine(queryVector, inputVector);
        } catch (AssertionError | IllegalArgumentException e) {
            logger.debug("Invalid vectors for cosine. Returning minimum score to put this result to end");
            return 0.0f;
        }
        // An AssertionError is only raised when JVM assertions are enabled (-ea). With assertions off
        // (the default), an invalid input such as a zero-magnitude vector can presumably surface as a
        // NaN result instead, so check the value explicitly rather than relying on the catch above.
        if (!Float.isFinite(similarity)) {
            logger.debug("Invalid vectors for cosine. Returning minimum score to put this result to end");
            return 0.0f;
        }
        return similarity;
    }

    /**
     * Computes cosine similarity between a script-supplied query and a document's vector.
     *
     * <p>The query is first checked with {@code SpaceType.COSINESIMIL.validateVector}.
     *
     * @param queryVector the query as a list of numbers
     * @param docValues the document's vector doc values
     * @return the cosine similarity, or {@code 0} for invalid input
     */
    public static float cosineSimilarity(List<Number> queryVector, KNNVectorScriptDocValues docValues) {
        float[] inputVector = toFloat(queryVector, docValues.getVectorDataType());
        SpaceType.COSINESIMIL.validateVector(inputVector);
        return cosinesimil(inputVector, docValues.getValue());
    }

    /**
     * Computes the Hamming distance between two bit strings held in {@link BigInteger}s.
     *
     * @param queryBigInteger the query bits
     * @param inputBigInteger the document bits
     * @return the number of bits set in {@code query XOR input}
     */
    public static float calculateHammingBit(BigInteger queryBigInteger, BigInteger inputBigInteger) {
        return inputBigInteger.xor(queryBigInteger).bitCount();
    }

    /**
     * Computes the Hamming distance between two 64-bit bit strings.
     *
     * @param queryLong the query bits (unboxed; {@code null} throws {@link NullPointerException})
     * @param inputLong the document bits (unboxed; {@code null} throws {@link NullPointerException})
     * @return the number of bits set in {@code query XOR input}
     */
    public static float calculateHammingBit(Long queryLong, Long inputLong) {
        return Long.bitCount(queryLong ^ inputLong);
    }

    /**
     * Computes the L1 (Manhattan) distance: the sum of absolute per-dimension differences.
     *
     * @param queryVector the query vector
     * @param inputVector the document vector
     * @return the L1 distance
     * @throws NullPointerException if either vector is {@code null}
     * @throws IllegalArgumentException if the vectors differ in dimension
     */
    public static float l1Norm(float[] queryVector, float[] inputVector) {
        // Without this check a longer document vector overruns queryVector, and a longer query
        // silently has its extra dimensions ignored.
        requireEqualDimension(queryVector, inputVector);
        float distance = 0.0f;
        for (int i = 0; i < inputVector.length; ++i) {
            distance += Math.abs(queryVector[i] - inputVector[i]);
        }
        return distance;
    }

    /**
     * Computes the L1 distance between a script-supplied query and a document's vector.
     *
     * @param queryVector the query as a list of numbers
     * @param docValues the document's vector doc values
     * @return the L1 distance
     */
    public static float l1Norm(List<Number> queryVector, KNNVectorScriptDocValues docValues) {
        return l1Norm(toFloat(queryVector, docValues.getVectorDataType()), docValues.getValue());
    }

    /**
     * Computes the L-infinity (Chebyshev) distance: the largest absolute per-dimension difference.
     *
     * @param queryVector the query vector
     * @param inputVector the document vector
     * @return the L-infinity distance ({@code 0} for zero-dimensional vectors)
     * @throws NullPointerException if either vector is {@code null}
     * @throws IllegalArgumentException if the vectors differ in dimension
     */
    public static float lInfNorm(float[] queryVector, float[] inputVector) {
        // Same reasoning as l1Norm: reject mismatched dimensions instead of overrunning or truncating.
        requireEqualDimension(queryVector, inputVector);
        float distance = 0.0f;
        for (int i = 0; i < inputVector.length; ++i) {
            distance = Math.max(Math.abs(queryVector[i] - inputVector[i]), distance);
        }
        return distance;
    }

    /**
     * Computes the L-infinity distance between a script-supplied query and a document's vector.
     *
     * @param queryVector the query as a list of numbers
     * @param docValues the document's vector doc values
     * @return the L-infinity distance
     */
    public static float lInfNorm(List<Number> queryVector, KNNVectorScriptDocValues docValues) {
        return lInfNorm(toFloat(queryVector, docValues.getVectorDataType()), docValues.getValue());
    }

    /**
     * Computes the inner (dot) product of two vectors using Lucene's {@code VectorUtil.dotProduct}.
     *
     * <p>No dimension check is done here; mismatched lengths are left to {@code VectorUtil}
     * (which, per its documentation, rejects them with an {@link IllegalArgumentException}).
     *
     * @param queryVector the query vector
     * @param inputVector the document vector
     * @return the dot product
     */
    public static float innerProduct(float[] queryVector, float[] inputVector) {
        return VectorUtil.dotProduct(queryVector, inputVector);
    }

    /**
     * Computes the inner product between a script-supplied query and a document's vector.
     *
     * @param queryVector the query as a list of numbers
     * @param docValues the document's vector doc values
     * @return the dot product
     */
    public static float innerProduct(List<Number> queryVector, KNNVectorScriptDocValues docValues) {
        return innerProduct(toFloat(queryVector, docValues.getVectorDataType()), docValues.getValue());
    }
}

