/*
 * Decompiled with CFR 0.152.
 */
package org.opensearch.knn.plugin.transport;

import org.opensearch.action.ActionListener;
import org.opensearch.action.ActionListenerResponseHandler;
import org.opensearch.action.ActionRequest;
import org.opensearch.action.ActionType;
import org.opensearch.action.search.SearchRequest;
import org.opensearch.action.support.ActionFilters;
import org.opensearch.action.support.HandledTransportAction;
import org.opensearch.client.Client;
import org.opensearch.cluster.node.DiscoveryNode;
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.common.Strings;
import org.opensearch.common.ValidationException;
import org.opensearch.common.collect.ImmutableOpenMap;
import org.opensearch.common.inject.Inject;
import org.opensearch.knn.common.KNNConstants;
import org.opensearch.knn.plugin.transport.TrainingJobRouteDecisionInfoAction;
import org.opensearch.knn.plugin.transport.TrainingJobRouteDecisionInfoNodeResponse;
import org.opensearch.knn.plugin.transport.TrainingJobRouteDecisionInfoRequest;
import org.opensearch.knn.plugin.transport.TrainingJobRouteDecisionInfoResponse;
import org.opensearch.knn.plugin.transport.TrainingModelRequest;
import org.opensearch.knn.plugin.transport.TrainingModelResponse;
import org.opensearch.search.builder.SearchSourceBuilder;
import org.opensearch.tasks.Task;
import org.opensearch.transport.TransportRequest;
import org.opensearch.transport.TransportRequestOptions;
import org.opensearch.transport.TransportResponseHandler;
import org.opensearch.transport.TransportService;

public class TrainingJobRouterTransportAction
extends HandledTransportAction<TrainingModelRequest, TrainingModelResponse> {
    private final TransportService transportService;
    private final ClusterService clusterService;
    private final Client client;

    @Inject
    public TrainingJobRouterTransportAction(TransportService transportService, ActionFilters actionFilters, ClusterService clusterService, Client client) {
        super("cluster:admin/knn_training_job_router_action", transportService, actionFilters, TrainingModelRequest::new);
        this.clusterService = clusterService;
        this.client = client;
        this.transportService = transportService;
    }

    protected void doExecute(Task task, TrainingModelRequest request, ActionListener<TrainingModelResponse> listener) {
        this.getTrainingIndexSizeInKB(request, (ActionListener<Integer>)ActionListener.wrap(size -> {
            request.setTrainingDataSizeInKB((int)size);
            this.routeRequest(request, listener);
        }, arg_0 -> listener.onFailure(arg_0)));
    }

    protected void routeRequest(TrainingModelRequest request, ActionListener<TrainingModelResponse> listener) {
        this.client.execute((ActionType)TrainingJobRouteDecisionInfoAction.INSTANCE, (ActionRequest)new TrainingJobRouteDecisionInfoRequest(new String[0]), ActionListener.wrap(response -> {
            DiscoveryNode node = this.selectNode(request.getPreferredNodeId(), (TrainingJobRouteDecisionInfoResponse)((Object)response));
            if (node == null) {
                ValidationException exception = new ValidationException();
                exception.addValidationError("Cluster does not have capacity to train");
                listener.onFailure((Exception)exception);
                return;
            }
            this.transportService.sendRequest(node, "cluster:admin/knn_training_model_action", (TransportRequest)request, TransportRequestOptions.EMPTY, (TransportResponseHandler)new ActionListenerResponseHandler(listener, TrainingModelResponse::new));
        }, arg_0 -> listener.onFailure(arg_0)));
    }

    protected DiscoveryNode selectNode(String preferredNode, TrainingJobRouteDecisionInfoResponse jobInfo) {
        DiscoveryNode selectedNode = null;
        ImmutableOpenMap eligibleNodes = this.clusterService.state().nodes().getDataNodes();
        for (TrainingJobRouteDecisionInfoNodeResponse response : jobInfo.getNodes()) {
            DiscoveryNode currentNode = response.getNode();
            if (!eligibleNodes.containsKey((Object)currentNode.getId()) || response.getTrainingJobCount() >= 1) continue;
            selectedNode = currentNode;
            if (!Strings.isEmpty((CharSequence)preferredNode) && !selectedNode.getId().equals(preferredNode)) continue;
            return selectedNode;
        }
        return selectedNode;
    }

    protected void getTrainingIndexSizeInKB(TrainingModelRequest trainingModelRequest, ActionListener<Integer> listener) {
        SearchRequest countRequest = new SearchRequest(new String[]{trainingModelRequest.getTrainingIndex()});
        SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder().size(0).trackTotalHits(true);
        countRequest.source(searchSourceBuilder);
        searchSourceBuilder.terminateAfter(0);
        this.client.search(countRequest, ActionListener.wrap(searchResponse -> {
            long trainingVectors = searchResponse.getHits().getTotalHits().value;
            if ((long)trainingModelRequest.getMaximumVectorCount() < trainingVectors) {
                trainingVectors = trainingModelRequest.getMaximumVectorCount();
            }
            listener.onResponse((Object)TrainingJobRouterTransportAction.estimateVectorSetSizeInKB(trainingVectors, trainingModelRequest.getDimension()));
        }, arg_0 -> listener.onFailure(arg_0)));
    }

    public static int estimateVectorSetSizeInKB(long vectorCount, int dimension) {
        return Math.toIntExact((long)(4 * dimension) * vectorCount / (long)KNNConstants.BYTES_PER_KILOBYTES.intValue() + 1L);
    }
}

