/*
 * Decompiled with CFR 0.152.
 */
package org.opensearch.knn.quantization.quantizer;

import java.io.IOException;
import java.util.Objects;
import org.opensearch.knn.quantization.enums.ScalarQuantizationType;
import org.opensearch.knn.quantization.models.quantizationOutput.QuantizationOutput;
import org.opensearch.knn.quantization.models.quantizationParams.ScalarQuantizationParams;
import org.opensearch.knn.quantization.models.quantizationState.OneBitScalarQuantizationState;
import org.opensearch.knn.quantization.models.quantizationState.QuantizationState;
import org.opensearch.knn.quantization.models.requests.TrainingRequest;
import org.opensearch.knn.quantization.quantizer.BitPacker;
import org.opensearch.knn.quantization.quantizer.Quantizer;
import org.opensearch.knn.quantization.quantizer.QuantizerHelper;
import org.opensearch.knn.quantization.sampler.Sampler;
import org.opensearch.knn.quantization.sampler.SamplerType;
import org.opensearch.knn.quantization.sampler.SamplingFactory;

/**
 * Quantizer that maps a {@code float[]} vector to a packed {@code byte[]} of one bit per
 * dimension, using per-dimension mean thresholds learned during training.
 *
 * <p>Training draws document IDs with the configured {@link Sampler} and computes one mean
 * threshold per dimension via {@link QuantizerHelper#calculateMeanThresholds}. Quantization
 * delegates to {@link BitPacker#quantizeAndPackBits}, which presumably sets each bit by comparing
 * the component against its threshold (NOTE(review): confirm against BitPacker).
 */
public class OneBitScalarQuantizer implements Quantizer<float[], byte[]> {
    /** Number of vectors requested from the sampler when the no-arg constructor is used. */
    private static final int DEFAULT_SAMPLE_SIZE = 25000;
    /** Thresholds come from {@link #train}, so this quantizer requires training. */
    private static final boolean IS_TRAINING_REQUIRED = true;

    /** Number of vectors requested from {@link #sampler} during training. */
    private final int samplingSize;
    /** Strategy used to choose which training vectors contribute to the thresholds. */
    private final Sampler sampler;

    /**
     * Creates a quantizer that samples {@value #DEFAULT_SAMPLE_SIZE} vectors with a reservoir
     * sampler during training.
     */
    public OneBitScalarQuantizer() {
        this(DEFAULT_SAMPLE_SIZE, SamplingFactory.getSampler(SamplerType.RESERVOIR));
    }

    /**
     * Creates a quantizer with an explicit sampling configuration.
     *
     * @param samplingSize number of vectors requested from {@code sampler} during training
     * @param sampler strategy used to pick the document IDs that feed threshold computation
     * @throws NullPointerException if {@code sampler} is {@code null}
     */
    public OneBitScalarQuantizer(int samplingSize, Sampler sampler) {
        this.samplingSize = samplingSize;
        // Fail fast here instead of with an anonymous NPE deep inside train().
        this.sampler = Objects.requireNonNull(sampler, "sampler must not be null");
    }

    /**
     * Learns per-dimension mean thresholds from a sample of the training vectors.
     *
     * @param trainingRequest source of the training vectors and their total count
     * @return a {@link OneBitScalarQuantizationState} carrying {@code ONE_BIT} parameters and the
     *     computed mean thresholds
     * @throws IOException propagated from the threshold calculation
     */
    @Override
    public QuantizationState train(TrainingRequest<float[]> trainingRequest) throws IOException {
        int[] sampledDocIds = sampler.sample(trainingRequest.getTotalNumberOfVectors(), samplingSize);
        float[] meanThresholds = QuantizerHelper.calculateMeanThresholds(trainingRequest, sampledDocIds);
        return new OneBitScalarQuantizationState(
            new ScalarQuantizationParams(ScalarQuantizationType.ONE_BIT),
            meanThresholds
        );
    }

    /**
     * Quantizes {@code vector} into {@code output} using the thresholds held by {@code state}.
     *
     * @param vector the vector to quantize; must not be {@code null}
     * @param state must be a {@link OneBitScalarQuantizationState} whose thresholds have the same
     *     length as {@code vector}
     * @param output receives the packed bits; it is first sized via
     *     {@code prepareQuantizedVector(vector.length)}
     * @throws IllegalArgumentException if {@code vector} is {@code null}, {@code state} has the
     *     wrong type, or the thresholds are missing or of a different length
     * @throws NullPointerException if {@code output} is {@code null}
     */
    @Override
    public void quantize(float[] vector, QuantizationState state, QuantizationOutput<byte[]> output) {
        if (vector == null) {
            throw new IllegalArgumentException("Vector to quantize must not be null.");
        }
        validateState(state);
        int vectorLength = vector.length;
        OneBitScalarQuantizationState binaryState = (OneBitScalarQuantizationState) state;
        float[] thresholds = binaryState.getMeanThresholds();
        if (thresholds == null || thresholds.length != vectorLength) {
            throw new IllegalArgumentException("Thresholds must not be null and must match the dimension of the vector.");
        }
        // Checked at its point of first use so validation order (and exception types) is unchanged.
        Objects.requireNonNull(output, "output must not be null");
        output.prepareQuantizedVector(vectorLength);
        BitPacker.quantizeAndPackBits(vector, thresholds, output.getQuantizedVector());
    }

    /**
     * Ensures {@code state} is a {@link OneBitScalarQuantizationState}; {@code null} is rejected
     * as well because {@code instanceof} is false for {@code null}.
     *
     * @throws IllegalArgumentException if the state has the wrong type or is {@code null}
     */
    private void validateState(QuantizationState state) {
        if (!(state instanceof OneBitScalarQuantizationState)) {
            throw new IllegalArgumentException("Quantization state must be of type OneBitScalarQuantizationState.");
        }
    }
}

