/*
 * Decompiled with CFR 0.152.
 */
package ai.djl.engine.rust;

import ai.djl.engine.rust.RsNDArray;
import ai.djl.engine.rust.RsNDManager;
import ai.djl.engine.rust.RustLibrary;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDScope;
import ai.djl.ndarray.index.NDArrayIndexer;
import ai.djl.ndarray.index.full.NDIndexFullPick;
import ai.djl.ndarray.index.full.NDIndexFullSlice;
import ai.djl.ndarray.index.full.NDIndexFullTake;
import ai.djl.ndarray.types.Shape;
import java.util.Arrays;

/**
 * {@link NDArrayIndexer} for the Rust engine.
 *
 * <p>Pick, take and full-slice reads are forwarded to native calls in {@link RustLibrary}.
 * Slice assignment is not implemented.
 */
public class RsNDArrayIndexer extends NDArrayIndexer {
    private final RsNDManager manager;

    /**
     * Creates an indexer bound to the given manager.
     *
     * @param manager the manager used to convert inputs and to own the resulting arrays
     */
    RsNDArrayIndexer(RsNDManager manager) {
        this.manager = manager;
    }

    /**
     * Picks elements of {@code array} at the indices held by {@code fullPick}, along
     * {@code fullPick.getAxis()}, using {@link RustLibrary#pick}.
     *
     * @param array the source array
     * @param fullPick the pick index (indices array and axis)
     * @return a new array owned by this indexer's manager
     */
    public NDArray get(NDArray array, NDIndexFullPick fullPick) {
        // Arrays produced by manager.from(...) are tracked by this temporary scope; the result
        // is explicitly unregistered so that closing the scope does not close it.
        try (NDScope ignore = new NDScope()) {
            long srcHandle = (Long) manager.from(array).getHandle();
            long indexHandle = (Long) manager.from(fullPick.getIndices()).getHandle();
            long resultHandle = RustLibrary.pick(srcHandle, indexHandle, fullPick.getAxis());
            RsNDArray result = new RsNDArray(manager, resultHandle);
            NDScope.unregister((NDArray) result);
            return result;
        }
    }

    /**
     * Takes elements of {@code array} at the indices held by {@code fullTake}, using
     * {@link RustLibrary#take}.
     *
     * @param array the source array
     * @param fullTake the take index (indices array)
     * @return a new array owned by this indexer's manager
     */
    public NDArray get(NDArray array, NDIndexFullTake fullTake) {
        // Same scope handling as the pick overload: temporaries are scoped, the result is not.
        try (NDScope ignore = new NDScope()) {
            long srcHandle = (Long) manager.from(array).getHandle();
            long indexHandle = (Long) manager.from(fullTake.getIndices()).getHandle();
            RsNDArray result = new RsNDArray(manager, RustLibrary.take(srcHandle, indexHandle));
            NDScope.unregister((NDArray) result);
            return result;
        }
    }

    /**
     * Slices {@code array} with the per-axis {@code [min, max)} bounds of {@code fullSlice},
     * then reshapes the slice to {@code fullSlice.getSqueezedShape()}.
     *
     * <p>If on any axis {@code min >= max} or {@code min} is not below that axis' length, no
     * native slicing is done; a new array of the squeezed shape (same data type and device as
     * {@code array}) is created through the manager and returned instead.
     *
     * @param array the source array
     * @param fullSlice the slice bounds, steps and squeezed result shape
     * @return a new array owned by this indexer's manager
     * @throws UnsupportedOperationException if any step is not {@code 1}
     */
    public NDArray get(NDArray array, NDIndexFullSlice fullSlice) {
        long[] min = fullSlice.getMin();
        long[] max = fullSlice.getMax();
        long[] step = fullSlice.getStep();
        if (Arrays.stream(step).anyMatch(st -> st != 1L)) {
            throw new UnsupportedOperationException("only step 1 is supported");
        }
        // Only read, never modified, so no defensive copy is needed.
        long[] dims = array.getShape().getShape();
        for (int axis = 0; axis < min.length; ++axis) {
            if (min[axis] >= max[axis] || min[axis] >= dims[axis]) {
                // The slice selects nothing along this axis; skip the native call.
                Shape squeezed = fullSlice.getSqueezedShape();
                return manager.create(squeezed, array.getDataType(), array.getDevice());
            }
        }
        try (NDScope ignore = new NDScope()) {
            long srcHandle = (Long) manager.from(array).getHandle();
            long sliced = RustLibrary.fullSlice(srcHandle, min, max, step);
            long resultHandle;
            try {
                resultHandle = RustLibrary.reshape(sliced, fullSlice.getSqueezedShape().getShape());
            } finally {
                // The intermediate slice is only an input to reshape. Free it even when reshape
                // throws, so the native tensor is not leaked.
                RustLibrary.deleteTensor(sliced);
            }
            RsNDArray result = new RsNDArray(manager, resultHandle, array.getDataType());
            NDScope.unregister((NDArray) result);
            return result;
        }
    }

    /**
     * Slice assignment is not supported by this indexer.
     *
     * @param array the destination array
     * @param fullSlice the slice to assign into
     * @param value the values to assign
     * @throws UnsupportedOperationException always
     */
    public void set(NDArray array, NDIndexFullSlice fullSlice, NDArray value) {
        throw new UnsupportedOperationException("Not implemented");
    }

    /**
     * Wraps {@code value} in an array created by {@code array}'s manager and delegates to
     * {@link #set(NDArray, NDIndexFullSlice, NDArray)}.
     *
     * @param array the destination array
     * @param fullSlice the slice to assign into
     * @param value the scalar value to assign
     */
    public void set(NDArray array, NDIndexFullSlice fullSlice, Number value) {
        this.set(array, fullSlice, array.getManager().create(value));
    }
}

