/*
 * Decompiled with CFR 0.152.
 */
package ai.djl.onnxruntime.engine;

import ai.djl.Device;
import ai.djl.Model;
import ai.djl.engine.Engine;
import ai.djl.ndarray.NDManager;
import ai.djl.onnxruntime.engine.OrtModel;
import ai.djl.onnxruntime.engine.OrtNDManager;
import ai.onnxruntime.OrtEnvironment;
import ai.onnxruntime.OrtException;
import ai.onnxruntime.OrtLoggingLevel;
import ai.onnxruntime.OrtSession;

public final class OrtEngine
extends Engine {
    public static final String ENGINE_NAME = "OnnxRuntime";
    static final int RANK = 10;
    private OrtEnvironment env;
    private Engine alternativeEngine;
    private boolean initialized;

    private OrtEngine() {
        OrtEnvironment.ThreadingOptions options = new OrtEnvironment.ThreadingOptions();
        try {
            Integer interOpThreads = Integer.getInteger("ai.djl.onnxruntime.num_interop_threads");
            Integer intraOpsThreads = Integer.getInteger("ai.djl.onnxruntime.num_threads");
            if (interOpThreads != null) {
                options.setGlobalInterOpNumThreads(interOpThreads.intValue());
            }
            if (intraOpsThreads != null) {
                options.setGlobalIntraOpNumThreads(intraOpsThreads.intValue());
            }
            OrtLoggingLevel logging = OrtLoggingLevel.ORT_LOGGING_LEVEL_WARNING;
            String name = "ort-java";
            this.env = OrtEnvironment.getEnvironment((OrtLoggingLevel)logging, (String)name, (OrtEnvironment.ThreadingOptions)options);
        }
        catch (OrtException e) {
            options.close();
            throw new AssertionError("Failed to config OrtEnvironment", e);
        }
    }

    static Engine newInstance() {
        return new OrtEngine();
    }

    OrtEnvironment getEnv() {
        return this.env;
    }

    public Engine getAlternativeEngine() {
        if (!this.initialized && !Boolean.getBoolean("ai.djl.onnx.disable_alternative")) {
            Engine engine = Engine.getInstance();
            if (engine.getRank() < this.getRank()) {
                this.alternativeEngine = engine;
            }
            this.initialized = true;
        }
        return this.alternativeEngine;
    }

    public String getEngineName() {
        return ENGINE_NAME;
    }

    public int getRank() {
        return 10;
    }

    public String getVersion() {
        return "1.17.1";
    }

    public boolean hasCapability(String capability) {
        if ("MKL".equals(capability)) {
            return true;
        }
        if ("CUDA".equals(capability)) {
            boolean bl;
            OrtSession.SessionOptions sessionOptions = new OrtSession.SessionOptions();
            try {
                sessionOptions.addCUDA();
                bl = true;
            }
            catch (Throwable throwable) {
                try {
                    try {
                        sessionOptions.close();
                    }
                    catch (Throwable throwable2) {
                        throwable.addSuppressed(throwable2);
                    }
                    throw throwable;
                }
                catch (OrtException e) {
                    return false;
                }
            }
            sessionOptions.close();
            return bl;
        }
        return false;
    }

    public Model newModel(String name, Device device) {
        return new OrtModel(name, this.newBaseManager(device), this.env);
    }

    public NDManager newBaseManager() {
        return this.newBaseManager(null);
    }

    public NDManager newBaseManager(Device device) {
        return OrtNDManager.getSystemManager().newSubManager(device);
    }

    public String toString() {
        StringBuilder sb = new StringBuilder(200);
        sb.append(this.getEngineName()).append(':').append(this.getVersion()).append(", ");
        sb.append(this.getEngineName()).append(':').append(this.getVersion()).append(", capabilities: [\n\tMKL");
        if (this.hasCapability("CUDA")) {
            sb.append(",\n\t").append("CUDA");
        }
        sb.append(']');
        return sb.toString();
    }
}

