/*
 * Decompiled with CFR 0.152.
 */
package ai.djl.training.loss;

import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.types.DataType;
import ai.djl.training.loss.Loss;

/**
 * Quantile (pinball) L1 loss.
 *
 * <p>For a prediction {@code pred}, a label {@code label} and the configured {@code quantile},
 * the element-wise value is
 *
 * <pre>
 *   2 * |(pred - label) * (1[label &lt;= pred] - quantile)|
 * </pre>
 *
 * <p>The returned loss is the mean of that value over all elements.
 */
public class QuantileL1Loss
extends Loss {
    /** Quantile weight subtracted from the over-prediction indicator; fixed at construction. */
    private final float quantile;

    /**
     * Creates a quantile L1 loss named {@code "QuantileL1Loss"}.
     *
     * @param quantile the quantile weight used in the loss formula
     */
    public QuantileL1Loss(float quantile) {
        this("QuantileL1Loss", quantile);
    }

    /**
     * Creates a quantile L1 loss with a custom name.
     *
     * @param name the display name of this loss
     * @param quantile the quantile weight used in the loss formula
     */
    public QuantileL1Loss(String name, float quantile) {
        super(name);
        this.quantile = quantile;
    }

    /**
     * Computes the mean quantile L1 loss.
     *
     * <p>Each list is expected to contain a single array; {@code singletonOrThrow} enforces this.
     * The labels are reshaped to the shape of the predictions before comparison.
     *
     * @param labels the list containing the label array
     * @param predictions the list containing the prediction array
     * @return the mean of the element-wise quantile L1 loss
     */
    @Override
    public NDArray evaluate(NDList labels, NDList predictions) {
        NDArray pred = predictions.singletonOrThrow();
        NDArray labelReshaped = labels.singletonOrThrow().reshape(pred.getShape());
        // Indicator mask: 1 where the prediction is at or above the label, else 0.
        // NOTE(review): the mask is always cast to FLOAT32 regardless of pred's dtype — confirm
        // this is intended for non-FLOAT32 predictions.
        NDArray overPrediction = labelReshaped.lte(pred).toType(DataType.FLOAT32, false);
        // For quantile in [0, 1] the product is already non-negative; abs() keeps the loss
        // non-negative for any other quantile value as well.
        NDArray loss = pred.sub(labelReshaped)
                .mul(overPrediction.sub(quantile))
                .abs()
                .mul(2);
        return loss.mean();
    }
}

