/*
 * Decompiled with CFR 0.152.
 */
package org.opensearch.ml.engine.httpclient;

import com.google.common.annotations.VisibleForTesting;
import java.net.Inet4Address;
import java.net.InetAddress;
import java.net.UnknownHostException;
import java.util.Arrays;
import lombok.Generated;
import org.apache.commons.lang3.math.NumberUtils;
import org.apache.http.HttpHost;
import org.apache.http.HttpRequest;
import org.apache.http.HttpResponse;
import org.apache.http.client.RedirectStrategy;
import org.apache.http.client.config.RequestConfig;
import org.apache.http.conn.SchemePortResolver;
import org.apache.http.conn.UnsupportedSchemeException;
import org.apache.http.impl.client.CloseableHttpClient;
import org.apache.http.impl.client.HttpClientBuilder;
import org.apache.http.impl.client.LaxRedirectStrategy;
import org.apache.http.impl.conn.DefaultSchemePortResolver;
import org.apache.http.protocol.HttpContext;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;

public class MLHttpClientFactory {
    @Generated
    private static final Logger log = LogManager.getLogger(MLHttpClientFactory.class);

    public static CloseableHttpClient getCloseableHttpClient(Integer connectionTimeout, Integer readTimeout, Integer maxConnections) {
        return MLHttpClientFactory.createHttpClient(connectionTimeout, readTimeout, maxConnections);
    }

    private static CloseableHttpClient createHttpClient(Integer connectionTimeout, Integer readTimeout, Integer maxConnections) {
        HttpClientBuilder builder = HttpClientBuilder.create();
        builder.setSchemePortResolver((SchemePortResolver)new DefaultSchemePortResolver(){

            public int resolve(HttpHost host) throws UnsupportedSchemeException {
                MLHttpClientFactory.validateSchemaAndPort(host);
                return super.resolve(host);
            }
        });
        builder.setDnsResolver(MLHttpClientFactory::validateIp);
        builder.setRedirectStrategy((RedirectStrategy)new LaxRedirectStrategy(){

            public boolean isRedirected(HttpRequest request, HttpResponse response, HttpContext context) {
                return false;
            }
        });
        builder.setMaxConnTotal(maxConnections.intValue());
        builder.setMaxConnPerRoute(maxConnections.intValue());
        RequestConfig requestConfig = RequestConfig.custom().setConnectTimeout(connectionTimeout.intValue()).setSocketTimeout(readTimeout.intValue()).build();
        builder.setDefaultRequestConfig(requestConfig);
        return builder.build();
    }

    @VisibleForTesting
    protected static void validateSchemaAndPort(HttpHost host) {
        String scheme = host.getSchemeName();
        if ("http".equalsIgnoreCase(scheme) || "https".equalsIgnoreCase(scheme)) {
            int port;
            String[] hostNamePort = host.getHostName().split(":");
            if (hostNamePort.length > 1 && NumberUtils.isDigits((String)hostNamePort[1]) && ((port = Integer.parseInt(hostNamePort[1])) < 0 || port > 65536)) {
                log.error("Remote inference port out of range: " + port);
                throw new IllegalArgumentException("Port out of range: " + port);
            }
        } else {
            log.error("Remote inference scheme not supported: " + scheme);
            throw new IllegalArgumentException("Unsupported scheme: " + scheme);
        }
    }

    protected static InetAddress[] validateIp(String hostName) throws UnknownHostException {
        InetAddress[] addresses = InetAddress.getAllByName(hostName);
        if (MLHttpClientFactory.hasPrivateIpAddress(addresses)) {
            log.error("Remote inference host name has private ip address: " + hostName);
            throw new IllegalArgumentException(hostName);
        }
        return addresses;
    }

    private static boolean hasPrivateIpAddress(InetAddress[] ipAddress) {
        for (InetAddress ip : ipAddress) {
            if (!(ip instanceof Inet4Address)) continue;
            byte[] bytes = ip.getAddress();
            if (bytes.length != 4) {
                return true;
            }
            int firstOctets = bytes[0] & 0xFF;
            int firstInOctal = MLHttpClientFactory.parseWithOctal(String.valueOf(firstOctets));
            int firstInHex = Integer.parseInt(String.valueOf(firstOctets), 16);
            if (firstInOctal == 127 || firstInHex == 127) {
                return bytes[1] == 0 && bytes[2] == 0 && bytes[3] == 1;
            }
            if (firstInOctal == 10 || firstInHex == 10) {
                return true;
            }
            if (firstInOctal == 172 || firstInHex == 172) {
                int secondOctets = bytes[1] & 0xFF;
                int secondInOctal = MLHttpClientFactory.parseWithOctal(String.valueOf(secondOctets));
                int secondInHex = Integer.parseInt(String.valueOf(secondOctets), 16);
                return secondInOctal >= 16 && secondInOctal <= 32 || secondInHex >= 16 && secondInHex <= 32;
            }
            if (firstInOctal != 192 && firstInHex != 192) continue;
            int secondOctets = bytes[1] & 0xFF;
            int secondInOctal = MLHttpClientFactory.parseWithOctal(String.valueOf(secondOctets));
            int secondInHex = Integer.parseInt(String.valueOf(secondOctets), 16);
            return secondInOctal == 168 || secondInHex == 168;
        }
        return Arrays.stream(ipAddress).anyMatch(x -> x.isSiteLocalAddress() || x.isLoopbackAddress() || x.isAnyLocalAddress());
    }

    private static int parseWithOctal(String input) {
        try {
            return Integer.parseInt(input, 8);
        }
        catch (NumberFormatException e) {
            return Integer.parseInt(input);
        }
    }
}

