/*
 * Decompiled with CFR 0.152.
 */
package org.opensearch.cluster.routing;

import java.util.List;
import java.util.Map;
import java.util.stream.Stream;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.apache.logging.log4j.message.ParameterizedMessage;
import org.opensearch.OpenSearchException;
import org.opensearch.action.search.SearchShardIterator;
import org.opensearch.cluster.ClusterState;
import org.opensearch.cluster.metadata.WeightedRoutingMetadata;
import org.opensearch.cluster.node.DiscoveryNode;
import org.opensearch.cluster.routing.ShardRouting;
import org.opensearch.cluster.routing.ShardsIterator;
import org.opensearch.cluster.routing.WeightedRouting;
import org.opensearch.index.shard.ShardId;
import org.opensearch.rest.RestStatus;
import org.opensearch.search.SearchShardTarget;

/**
 * Picks the next shard copy to try from a shard iterator while weighted routing is in effect.
 *
 * <p>A shard copy is considered "weighed away" when the value of its node's weighted-routing
 * awareness attribute maps to a weight whose {@code intValue()} is zero. {@code findNext} skips
 * such copies. There is one exception, called "fail open": a weighed-away copy is still returned
 * when either of these holds:
 * <ul>
 *   <li>the supplied exception is an {@link OpenSearchException} with status 500, 502, 503 or 504;</li>
 *   <li>the shard has at least one inactive copy in the routing table.</li>
 * </ul>
 */
public class FailAwareWeightedRouting {
    public static final FailAwareWeightedRouting INSTANCE = new FailAwareWeightedRouting();
    private static final Logger logger = LogManager.getLogger(FailAwareWeightedRouting.class);
    /** REST statuses treated as internal (server-side) failures that permit failing open. */
    private static final List<RestStatus> internalErrorRestStatusList = List.of(
        RestStatus.INTERNAL_SERVER_ERROR,
        RestStatus.BAD_GATEWAY,
        RestStatus.SERVICE_UNAVAILABLE,
        RestStatus.GATEWAY_TIMEOUT
    );

    /**
     * Returns the shared instance ({@link #INSTANCE}).
     */
    public static FailAwareWeightedRouting getInstance() {
        return INSTANCE;
    }

    /**
     * Returns {@code true} if {@code exception} is an {@link OpenSearchException} whose status is one
     * of the internal-error statuses; {@code false} otherwise (including for {@code null}).
     */
    private boolean isInternalFailure(Exception exception) {
        if (exception instanceof OpenSearchException) {
            return internalErrorRestStatusList.contains(((OpenSearchException) exception).status());
        }
        return false;
    }

    /**
     * Returns {@code true} if weighted routing is configured and the given node's awareness attribute
     * value has a weight that truncates to zero.
     *
     * <p>Returns {@code false} in the following cases rather than throwing:
     * <ul>
     *   <li>the node id is {@code null};</li>
     *   <li>the node is not present in {@code clusterState};</li>
     *   <li>the node does not carry the awareness attribute.</li>
     * </ul>
     */
    private boolean isWeighedAway(String nodeId, ClusterState clusterState) {
        WeightedRoutingMetadata weightedRoutingMetadata = clusterState.metadata().weightedRoutingMetadata();
        if (weightedRoutingMetadata == null) {
            return false;
        }
        WeightedRouting weightedRouting = weightedRoutingMetadata.getWeightedRouting();
        if (weightedRouting == null || !weightedRouting.isSet()) {
            return false;
        }
        if (nodeId == null) {
            return false;
        }
        DiscoveryNode node = clusterState.nodes().get(nodeId);
        if (node == null) {
            // The node is not in this cluster state, so there is no attribute to weigh away.
            return false;
        }
        String attributeValue = node.getAttributes().get(weightedRouting.attributeName());
        if (attributeValue == null) {
            // The node lacks the awareness attribute, so it cannot match any weighed-away value.
            return false;
        }
        // intValue() truncates, so any weight in [0, 1) counts as weighed away.
        return weightedRouting.weights()
            .entrySet()
            .stream()
            .filter(entry -> entry.getValue().intValue() == 0)
            .anyMatch(entry -> attributeValue.equals(entry.getKey()));
    }

    /**
     * Advances {@code shardIt} to the next target that is not weighed away.
     *
     * <p>A weighed-away target is returned anyway (fail open) when {@code exception} is an internal
     * failure or the shard has an inactive copy.
     *
     * @param shardIt      iterator over the candidate shard targets; it is advanced by this call
     * @param clusterState cluster state used to resolve nodes, weights and the routing table
     * @param exception    the failure that prompted the retry; may be {@code null}
     * @return the next usable target, or {@code null} if the iterator is exhausted
     */
    public SearchShardTarget findNext(SearchShardIterator shardIt, ClusterState clusterState, Exception exception) {
        SearchShardTarget next = shardIt.nextOrNull();
        while (next != null && isWeighedAway(next.getNodeId(), clusterState)) {
            if (canFailOpen(next.getShardId(), exception, clusterState)) {
                logFailOpen(next.getShardId(), exception);
                break;
            }
            next = shardIt.nextOrNull();
        }
        return next;
    }

    /**
     * Advances {@code shardsIt} to the next shard routing that is not weighed away.
     *
     * <p>A weighed-away shard routing is returned anyway (fail open) when {@code exception} is an
     * internal failure or the shard has an inactive copy.
     *
     * @param shardsIt     iterator over the candidate shard routings; it is advanced by this call
     * @param clusterState cluster state used to resolve nodes, weights and the routing table
     * @param exception    the failure that prompted the retry; may be {@code null}
     * @return the next usable shard routing, or {@code null} if the iterator is exhausted
     */
    public ShardRouting findNext(ShardsIterator shardsIt, ClusterState clusterState, Exception exception) {
        ShardRouting next = shardsIt.nextOrNull();
        while (next != null && isWeighedAway(next.currentNodeId(), clusterState)) {
            if (canFailOpen(next.shardId(), exception, clusterState)) {
                logFailOpen(next.shardId(), exception);
                break;
            }
            next = shardsIt.nextOrNull();
        }
        return next;
    }

    /** Logs, at INFO level, that a weighed-away copy of {@code shardId} is being used. */
    private void logFailOpen(ShardId shardId, Exception exception) {
        logger.info(() -> new ParameterizedMessage("{}: Fail open executed due to exception", shardId), exception);
    }

    /**
     * Returns {@code true} if a weighed-away copy of {@code shardId} may still be used: the failure
     * was internal, or the shard has at least one inactive copy.
     */
    private boolean canFailOpen(ShardId shardId, Exception exception, ClusterState clusterState) {
        return isInternalFailure(exception) || hasInActiveShardCopies(clusterState, shardId);
    }

    /**
     * Returns {@code true} if any copy of {@code shardId} in the routing table is not active.
     */
    private boolean hasInActiveShardCopies(ClusterState clusterState, ShardId shardId) {
        List<ShardRouting> shards = clusterState.routingTable().shardRoutingTable(shardId).shards();
        return shards.stream().anyMatch(shardRouting -> !shardRouting.active());
    }
}

