/*
 * Decompiled with CFR 0.152.
 */
package org.opensearch.ml.engine.algorithms.remote;

import com.google.common.annotations.VisibleForTesting;
import java.util.Map;
import lombok.Generated;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.opensearch.client.Client;
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.common.util.TokenBucket;
import org.opensearch.core.xcontent.NamedXContentRegistry;
import org.opensearch.ml.common.FunctionName;
import org.opensearch.ml.common.MLModel;
import org.opensearch.ml.common.connector.Connector;
import org.opensearch.ml.common.exception.MLException;
import org.opensearch.ml.common.input.MLInput;
import org.opensearch.ml.common.model.MLGuard;
import org.opensearch.ml.common.output.MLOutput;
import org.opensearch.ml.engine.MLEngineClassLoader;
import org.opensearch.ml.engine.Predictable;
import org.opensearch.ml.engine.algorithms.remote.RemoteConnectorExecutor;
import org.opensearch.ml.engine.annotation.Function;
import org.opensearch.ml.engine.encryptor.Encryptor;
import org.opensearch.script.ScriptService;

/**
 * {@link Predictable} implementation registered for {@link FunctionName#REMOTE}.
 *
 * <p>Prediction is delegated to a {@link RemoteConnectorExecutor} that is created from the
 * model's connector in {@link #initModel(MLModel, Map, Encryptor)}. Until that call succeeds
 * (or after {@link #close()}), the model reports itself as not ready and refuses to predict.
 */
@Function(value=FunctionName.REMOTE)
public class RemoteModel
implements Predictable {
    @Generated
    private static final Logger log = LogManager.getLogger(RemoteModel.class);

    /** Key in the {@code initModel} params map holding the {@link ClusterService}. */
    public static final String CLUSTER_SERVICE = "cluster_service";
    /** Key in the {@code initModel} params map holding the {@link ScriptService}. */
    public static final String SCRIPT_SERVICE = "script_service";
    /** Key in the {@code initModel} params map holding the {@link Client}. */
    public static final String CLIENT = "client";
    /** Key in the {@code initModel} params map holding the {@link NamedXContentRegistry}. */
    public static final String XCONTENT_REGISTRY = "xcontent_registry";
    /** Key in the {@code initModel} params map holding the {@link TokenBucket} rate limiter. */
    public static final String RATE_LIMITER = "rate_limiter";
    /** Key in the {@code initModel} params map holding the per-user rate limiter map. */
    public static final String USER_RATE_LIMITER_MAP = "user_rate_limiter_map";
    /** Key in the {@code initModel} params map holding the {@link MLGuard}. */
    public static final String GUARDRAILS = "guardrails";

    /** Executor used for predictions; {@code null} while the model is not initialized. */
    private RemoteConnectorExecutor connectorExecutor;

    /**
     * Returns the current executor.
     *
     * @return the executor, or {@code null} if the model is not initialized or was closed
     */
    @VisibleForTesting
    RemoteConnectorExecutor getConnectorExecutor() {
        return this.connectorExecutor;
    }

    /**
     * Always rejects the call: this overload never performs a prediction and instead tells the
     * caller to deploy the model first.
     *
     * @param mlInput ignored
     * @param model   model whose id is embedded in the error message
     * @return never returns normally
     * @throws IllegalArgumentException always
     */
    @Override
    public MLOutput predict(MLInput mlInput, MLModel model) {
        throw new IllegalArgumentException("Model not ready yet. Please run this first: POST /_plugins/_ml/models/" + model.getModelId() + "/_deploy");
    }

    /**
     * Runs a prediction through the connector executor.
     *
     * @param mlInput input forwarded to {@link RemoteConnectorExecutor#executePredict}
     * @return the executor's output
     * @throws IllegalArgumentException if the model is not ready (see {@link #isModelReady()})
     * @throws RuntimeException         any runtime exception raised by the executor, rethrown as is
     * @throws MLException              wrapping any non-runtime {@link Throwable} raised by the executor
     */
    @Override
    public MLOutput predict(MLInput mlInput) {
        if (!this.isModelReady()) {
            throw new IllegalArgumentException("Model not ready yet. Please run this first: POST /_plugins/_ml/models/<model_id>/_deploy");
        }
        try {
            return this.connectorExecutor.executePredict(mlInput);
        }
        catch (RuntimeException e) {
            // Runtime exceptions keep their original type for callers.
            log.error("Failed to call remote model.", e);
            throw e;
        }
        catch (Throwable e) {
            log.error("Failed to call remote model.", e);
            throw new MLException(e);
        }
    }

    /** Drops the executor reference so that {@link #isModelReady()} returns {@code false}. */
    @Override
    public void close() {
        this.connectorExecutor = null;
    }

    /**
     * Reports whether an executor is currently available.
     *
     * @return {@code true} if an executor has been set and not cleared by {@link #close()}
     */
    @Override
    public boolean isModelReady() {
        return this.connectorExecutor != null;
    }

    /**
     * Creates and wires the connector executor for the given model.
     *
     * <p>The model's connector is cloned, its credentials are decrypted with {@code encryptor},
     * and an executor is instantiated for the connector's protocol. Collaborators are then read
     * from {@code params} using the key constants declared on this class.
     *
     * <p>The executor is assembled in a local variable and stored in the field only once every
     * step has succeeded. If any step throws, the previous state is left untouched, so a failed
     * initialization can never make {@link #isModelReady()} report {@code true} for a
     * half-configured executor.
     *
     * @param model     model providing the connector
     * @param params    collaborators keyed by the constants on this class
     * @param encryptor used to decrypt the connector's credentials
     * @throws RuntimeException any runtime exception raised during initialization, rethrown as is
     * @throws MLException      wrapping any non-runtime {@link Throwable} raised during initialization
     */
    @Override
    public void initModel(MLModel model, Map<String, Object> params, Encryptor encryptor) {
        try {
            Connector connector = model.getConnector().cloneConnector();
            connector.decrypt(credential -> encryptor.decrypt((String)credential));
            RemoteConnectorExecutor executor = (RemoteConnectorExecutor)MLEngineClassLoader.initInstance(connector.getProtocol(), connector, Connector.class);
            executor.setScriptService((ScriptService)params.get(SCRIPT_SERVICE));
            executor.setClusterService((ClusterService)params.get(CLUSTER_SERVICE));
            executor.setClient((Client)params.get(CLIENT));
            executor.setXContentRegistry((NamedXContentRegistry)params.get(XCONTENT_REGISTRY));
            executor.setRateLimiter((TokenBucket)params.get(RATE_LIMITER));
            executor.setUserRateLimiterMap((Map)params.get(USER_RATE_LIMITER_MAP));
            executor.setMlGuard((MLGuard)params.get(GUARDRAILS));
            // Publish only a fully configured executor.
            this.connectorExecutor = executor;
        }
        catch (RuntimeException e) {
            // Runtime exceptions keep their original type for callers.
            log.error("Failed to init remote model.", e);
            throw e;
        }
        catch (Throwable e) {
            log.error("Failed to init remote model.", e);
            throw new MLException(e);
        }
    }
}

