/*
 * Decompiled with CFR 0.152.
 */
package org.opensearch.sql.legacy.executor.format;

import com.alibaba.druid.sql.ast.SQLExpr;
import com.alibaba.druid.sql.ast.expr.SQLCaseExpr;
import com.alibaba.druid.sql.ast.expr.SQLCastExpr;
import com.alibaba.druid.sql.ast.expr.SQLIdentifierExpr;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.TreeMap;
import java.util.stream.Collectors;
import java.util.stream.StreamSupport;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.opensearch.action.admin.indices.mapping.get.GetFieldMappingsRequest;
import org.opensearch.action.admin.indices.mapping.get.GetFieldMappingsResponse;
import org.opensearch.action.search.ClearScrollResponse;
import org.opensearch.client.Client;
import org.opensearch.common.Strings;
import org.opensearch.common.document.DocumentField;
import org.opensearch.search.SearchHit;
import org.opensearch.search.SearchHits;
import org.opensearch.search.aggregations.Aggregation;
import org.opensearch.search.aggregations.Aggregations;
import org.opensearch.search.aggregations.bucket.terms.Terms;
import org.opensearch.search.aggregations.metrics.NumericMetricsAggregation;
import org.opensearch.search.aggregations.metrics.Percentile;
import org.opensearch.search.aggregations.metrics.Percentiles;
import org.opensearch.sql.legacy.cursor.Cursor;
import org.opensearch.sql.legacy.cursor.DefaultCursor;
import org.opensearch.sql.legacy.domain.ColumnTypeProvider;
import org.opensearch.sql.legacy.domain.Field;
import org.opensearch.sql.legacy.domain.JoinSelect;
import org.opensearch.sql.legacy.domain.MethodField;
import org.opensearch.sql.legacy.domain.Query;
import org.opensearch.sql.legacy.domain.Select;
import org.opensearch.sql.legacy.domain.TableOnJoinSelect;
import org.opensearch.sql.legacy.esdomain.mapping.FieldMapping;
import org.opensearch.sql.legacy.exception.SqlFeatureNotImplementedException;
import org.opensearch.sql.legacy.executor.Format;
import org.opensearch.sql.legacy.executor.format.DataRows;
import org.opensearch.sql.legacy.executor.format.DateFieldFormatter;
import org.opensearch.sql.legacy.executor.format.ResultSet;
import org.opensearch.sql.legacy.executor.format.Schema;
import org.opensearch.sql.legacy.metrics.MetricName;
import org.opensearch.sql.legacy.metrics.Metrics;
import org.opensearch.sql.legacy.utils.SQLFunctions;

public class SelectResultSet
extends ResultSet {
    private static final Logger LOG = LogManager.getLogger(SelectResultSet.class);
    public static final String SCORE = "_score";
    private final String formatType;
    private Query query;
    private Object queryResult;
    private boolean selectAll;
    private String indexName;
    private List<Schema.Column> columns = new ArrayList<Schema.Column>();
    private ColumnTypeProvider outputColumnType;
    private List<String> head;
    private long size;
    private long totalHits;
    private long internalTotalHits;
    private List<DataRows.Row> rows;
    private Cursor cursor;
    private DateFieldFormatter dateFieldFormatter;
    private Map<String, String> fieldAliasMap = new HashMap<String, String>();

    public SelectResultSet(Client client, Query query, Object queryResult, ColumnTypeProvider outputColumnType, String formatType, Cursor cursor) {
        this.client = client;
        this.query = query;
        this.queryResult = queryResult;
        this.selectAll = false;
        this.formatType = formatType;
        this.outputColumnType = outputColumnType;
        this.cursor = cursor;
        if (this.isJoinQuery()) {
            JoinSelect joinQuery = (JoinSelect)query;
            this.loadFromEsState(joinQuery.getFirstTable());
            this.loadFromEsState(joinQuery.getSecondTable());
        } else {
            this.loadFromEsState(query);
        }
        this.schema = new Schema(this.indexName, this.columns);
        this.head = this.schema.getHeaders();
        this.dateFieldFormatter = new DateFieldFormatter(this.indexName, this.columns, this.fieldAliasMap);
        this.extractData();
        this.populateCursor();
        this.dataRows = new DataRows(this.size, this.totalHits, this.rows);
    }

    public SelectResultSet(Client client, Object queryResult, String formatType, Cursor cursor) {
        this.cursor = cursor;
        this.client = client;
        this.queryResult = queryResult;
        this.selectAll = false;
        this.formatType = formatType;
        this.populateResultSetFromCursor(cursor);
    }

    public String indexName() {
        return this.indexName;
    }

    public Map<String, String> fieldAliasMap() {
        return Collections.unmodifiableMap(this.fieldAliasMap);
    }

    public void populateResultSetFromCursor(Cursor cursor) {
        switch (cursor.getType()) {
            case DEFAULT: {
                this.populateResultSetFromDefaultCursor((DefaultCursor)cursor);
            }
        }
    }

    private void populateResultSetFromDefaultCursor(DefaultCursor cursor) {
        this.columns = cursor.getColumns();
        this.schema = new Schema(this.columns);
        this.head = this.schema.getHeaders();
        this.dateFieldFormatter = new DateFieldFormatter(cursor.getIndexPattern(), this.columns, cursor.getFieldAliasMap());
        this.extractData();
        this.dataRows = new DataRows(this.size, this.totalHits, this.rows);
    }

    private void loadFromEsState(Query query) {
        String indexName = this.fetchIndexName(query);
        String[] fieldNames = this.fetchFieldsAsArray(query);
        this.selectAll = this.isSimpleQuerySelectAll(query) || this.isJoinQuerySelectAll(query, fieldNames);
        GetFieldMappingsRequest request = new GetFieldMappingsRequest().indices(new String[]{indexName}).fields(this.selectAllFieldsIfEmpty(fieldNames)).local(true);
        GetFieldMappingsResponse response = (GetFieldMappingsResponse)this.client.admin().indices().getFieldMappings(request).actionGet();
        Map mappings = response.mappings();
        if (mappings.isEmpty() || !mappings.containsKey(indexName)) {
            throw new IllegalArgumentException(String.format("Index type %s does not exist", query.getFrom()));
        }
        Map typeMappings = (Map)mappings.get(indexName);
        this.indexName = this.indexName == null ? indexName : this.indexName + "|" + indexName;
        this.columns.addAll(this.renameColumnWithTableAlias(query, this.populateColumns(query, fieldNames, typeMappings)));
    }

    private List<Schema.Column> renameColumnWithTableAlias(Query query, List<Schema.Column> columns) {
        List<Schema.Column> renamedCols;
        if (query instanceof TableOnJoinSelect && !Strings.isNullOrEmpty((String)((TableOnJoinSelect)query).getAlias())) {
            TableOnJoinSelect joinQuery = (TableOnJoinSelect)query;
            renamedCols = new ArrayList<Schema.Column>();
            for (Schema.Column column : columns) {
                renamedCols.add(new Schema.Column(joinQuery.getAlias() + "." + column.getName(), column.getAlias(), Schema.Type.valueOf(column.getType().toUpperCase()), true));
            }
        } else {
            renamedCols = columns;
        }
        return renamedCols;
    }

    private boolean isSelectAll() {
        return this.selectAll;
    }

    private boolean isSimpleQuerySelectAll(Query query) {
        return query instanceof Select && ((Select)query).isSelectAll();
    }

    private boolean isJoinQuerySelectAll(Query query, String[] fieldNames) {
        return fieldNames.length == 0 && !this.fieldsSelectedOnAnotherTable(query);
    }

    private boolean fieldsSelectedOnAnotherTable(Query query) {
        if (this.isJoinQuery()) {
            TableOnJoinSelect otherTable = this.getOtherTable(query);
            return otherTable.getSelectedFields().size() > 0;
        }
        return false;
    }

    private TableOnJoinSelect getOtherTable(Query currJoinSelect) {
        JoinSelect joinQuery = (JoinSelect)this.query;
        if (joinQuery.getFirstTable() == currJoinSelect) {
            return joinQuery.getSecondTable();
        }
        return joinQuery.getFirstTable();
    }

    private boolean containsWildcard(Query query) {
        for (Field field : this.fetchFields(query)) {
            if (field instanceof MethodField || !field.getName().contains("*")) continue;
            return true;
        }
        return false;
    }

    private String fetchIndexName(Query query) {
        return query.getFrom().get(0).getIndex();
    }

    private List<Field> fetchFields(Query query) {
        Select select = (Select)query;
        if (this.queryResult instanceof Aggregations) {
            ArrayList<Field> groupByFields = select.getGroupBys().isEmpty() ? new ArrayList<Field>() : select.getGroupBys().get(0);
            for (Field selectField : select.getFields()) {
                if (selectField instanceof MethodField && !selectField.isScriptField()) {
                    groupByFields.add(selectField);
                    continue;
                }
                if (!selectField.isScriptField() || !selectField.getAlias().equals(((Field)groupByFields.get(0)).getName())) continue;
                return select.getFields();
            }
            return groupByFields;
        }
        if (query instanceof TableOnJoinSelect) {
            return ((TableOnJoinSelect)query).getSelectedFields();
        }
        return select.getFields();
    }

    private String[] fetchFieldsAsArray(Query query) {
        List<Field> fields = this.fetchFields(query);
        return (String[])fields.stream().map(this::getFieldName).toArray(String[]::new);
    }

    private String getFieldName(Field field) {
        if (field instanceof MethodField) {
            return field.getAlias();
        }
        return field.getName();
    }

    private Map<String, Field> fetchFieldMap(Query query) {
        HashMap<String, Field> fieldMap = new HashMap<String, Field>();
        for (Field field : this.fetchFields(query)) {
            fieldMap.put(this.getFieldName(field), field);
        }
        return fieldMap;
    }

    private String[] selectAllFieldsIfEmpty(String[] fields) {
        if (this.isSelectAll()) {
            return new String[]{"*"};
        }
        return fields;
    }

    private String[] emptyArrayIfNull(String typeName) {
        if (typeName != null) {
            return new String[]{typeName};
        }
        return Strings.EMPTY_ARRAY;
    }

    private Schema.Type fetchMethodReturnType(int fieldIndex, MethodField field) {
        switch (field.getName().toLowerCase()) {
            case "count": {
                return Schema.Type.LONG;
            }
            case "sum": 
            case "avg": 
            case "min": 
            case "max": 
            case "percentiles": {
                return Schema.Type.DOUBLE;
            }
            case "script": {
                if (field.getExpression() instanceof SQLCaseExpr) {
                    return Schema.Type.TEXT;
                }
                Schema.Type resolvedType = this.outputColumnType.get(fieldIndex);
                return SQLFunctions.getScriptFunctionReturnType(field, resolvedType);
            }
        }
        throw new UnsupportedOperationException(String.format("The following method is not supported in Schema: %s", field.getName()));
    }

    private List<Schema.Column> populateColumns(Query query, String[] fieldNames, Map<String, GetFieldMappingsResponse.FieldMappingMetadata> typeMappings) {
        List<String> fieldNameList = this.isSelectAll() || this.containsWildcard(query) ? new ArrayList<String>(typeMappings.keySet()) : Arrays.asList(fieldNames);
        Map<String, Field> fieldMap = this.fetchFieldMap(query);
        ArrayList<Schema.Column> columns = new ArrayList<Schema.Column>();
        for (String fieldName : fieldNameList) {
            if (fieldName.equals(SCORE)) {
                columns.add(new Schema.Column(fieldName, this.fetchAlias(fieldName, fieldMap), Schema.Type.FLOAT));
                continue;
            }
            if (fieldMap.get(fieldName) instanceof MethodField) {
                MethodField methodField = (MethodField)fieldMap.get(fieldName);
                int fieldIndex = fieldNameList.indexOf(fieldName);
                SQLExpr expr = methodField.getExpression();
                if (expr instanceof SQLCastExpr) {
                    SQLIdentifierExpr castFieldIdentifier = (SQLIdentifierExpr)((SQLCastExpr)expr).getExpr();
                    this.fieldAliasMap.put(methodField.getAlias(), castFieldIdentifier.getName());
                }
                columns.add(new Schema.Column(methodField.getAlias(), null, this.fetchMethodReturnType(fieldIndex, methodField)));
                continue;
            }
            FieldMapping field = new FieldMapping(fieldName, typeMappings, fieldMap);
            if (field.isMetaField() || field.isMultiField() && !field.isSpecified() || field.isPropertyField() && !field.isSpecified() && !field.isWildcardSpecified()) continue;
            String type = field.type().toUpperCase();
            if (Schema.hasType(type)) {
                boolean isGroupKey = false;
                Select select = (Select)query;
                if (null != select.getGroupBys() && !select.getGroupBys().isEmpty() && select.getGroupBys().get(0).contains(fieldMap.get(fieldName))) {
                    isGroupKey = true;
                }
                columns.add(new Schema.Column(fieldName, this.fetchAlias(fieldName, fieldMap), Schema.Type.valueOf(type), isGroupKey));
                continue;
            }
            if (this.isSelectAll()) continue;
            throw new IllegalArgumentException(String.format("%s fieldName types are currently not supported.", type));
        }
        if (this.isSelectAllOnly(query)) {
            this.populateAllNestedFields(columns, fieldNameList);
        }
        return columns;
    }

    private boolean isSelectAllOnly(Query query) {
        return this.isSelectAll() && this.fetchFields(query).isEmpty();
    }

    private void populateAllNestedFields(List<Schema.Column> columns, List<String> fields) {
        Set nestedFieldPaths = fields.stream().map(FieldMapping::new).filter(FieldMapping::isPropertyField).filter(f -> !f.isMultiField()).map(FieldMapping::path).collect(Collectors.toSet());
        for (String nestedFieldPath : nestedFieldPaths) {
            columns.add(new Schema.Column(nestedFieldPath, "", Schema.Type.TEXT));
        }
    }

    private String fetchAlias(String fieldName, Map<String, Field> fieldMap) {
        if (fieldMap.containsKey(fieldName)) {
            return fieldMap.get(fieldName).getAlias();
        }
        return null;
    }

    private void extractData() {
        if (this.queryResult instanceof SearchHits) {
            SearchHits searchHits = (SearchHits)this.queryResult;
            this.rows = this.populateRows(searchHits);
            this.size = this.rows.size();
            this.internalTotalHits = Optional.ofNullable(searchHits.getTotalHits()).map(th -> th.value).orElse(0L);
            this.totalHits = Math.max(this.size, this.internalTotalHits);
        } else if (this.queryResult instanceof Aggregations) {
            Aggregations aggregations = (Aggregations)this.queryResult;
            this.rows = this.populateRows(aggregations);
            this.internalTotalHits = this.size = (long)this.rows.size();
            this.totalHits = this.size;
        }
    }

    private void populateCursor() {
        switch (this.cursor.getType()) {
            case DEFAULT: {
                this.populateDefaultCursor((DefaultCursor)this.cursor);
            }
        }
    }

    private void populateDefaultCursor(DefaultCursor cursor) {
        Integer limit = cursor.getLimit();
        long rowsLeft = this.rowsLeft(cursor.getFetchSize(), cursor.getLimit());
        if (rowsLeft <= 0L) {
            String scrollId = cursor.getScrollId();
            ClearScrollResponse clearScrollResponse = (ClearScrollResponse)this.client.prepareClearScroll().addScrollId(scrollId).get();
            if (!clearScrollResponse.isSucceeded()) {
                Metrics.getInstance().getNumericalMetric(MetricName.FAILED_REQ_COUNT_SYS).increment();
                LOG.error("Error closing the cursor context {} ", (Object)scrollId);
            }
            return;
        }
        cursor.setRowsLeft(rowsLeft);
        cursor.setIndexPattern(this.indexName);
        cursor.setFieldAliasMap(this.fieldAliasMap());
        cursor.setColumns(this.columns);
        this.totalHits = limit != null && (long)limit.intValue() < this.internalTotalHits ? (long)limit.intValue() : this.internalTotalHits;
    }

    private long rowsLeft(Integer fetchSize, Integer limit) {
        long rowsLeft = 0L;
        long totalHits = this.internalTotalHits;
        rowsLeft = limit != null && (long)limit.intValue() < totalHits ? (long)(limit - fetchSize) : totalHits - (long)fetchSize.intValue();
        return rowsLeft;
    }

    private List<DataRows.Row> populateRows(SearchHits searchHits) {
        ArrayList<DataRows.Row> rows = new ArrayList<DataRows.Row>();
        HashSet<String> newKeys = new HashSet<String>(this.head);
        for (SearchHit hit : searchHits) {
            List<DataRows.Row> result;
            Map<String, Object> rowSource = hit.getSourceAsMap();
            if (!this.isJoinQuery()) {
                rowSource = this.flatRow(this.head, rowSource);
                rowSource.put(SCORE, Float.valueOf(hit.getScore()));
                for (Map.Entry field : hit.getFields().entrySet()) {
                    rowSource.put((String)field.getKey(), ((DocumentField)field.getValue()).getValue());
                }
                if (this.formatType.equalsIgnoreCase(Format.JDBC.getFormatName())) {
                    this.dateFieldFormatter.applyJDBCDateFormat(rowSource);
                }
                result = this.flatNestedField(newKeys, rowSource, hit.getInnerHits());
            } else {
                if (this.formatType.equalsIgnoreCase(Format.JDBC.getFormatName())) {
                    this.dateFieldFormatter.applyJDBCDateFormat(rowSource);
                }
                result = new ArrayList<DataRows.Row>();
                result.add(new DataRows.Row(rowSource));
            }
            rows.addAll(result);
        }
        return rows;
    }

    private List<DataRows.Row> populateRows(Aggregations aggregations) {
        ArrayList<DataRows.Row> rows = new ArrayList<DataRows.Row>();
        List aggs = aggregations.asList();
        if (this.hasTermAggregations(aggs)) {
            Terms terms = (Terms)aggs.get(0);
            String field = terms.getName();
            for (Terms.Bucket bucket : terms.getBuckets()) {
                ArrayList<DataRows.Row> aggRows = new ArrayList<DataRows.Row>();
                this.getAggsData(bucket, aggRows, this.addMap(field, bucket.getKey()));
                rows.addAll(aggRows);
            }
        } else {
            rows.add(new DataRows.Row(this.addNumericAggregation(aggs, new HashMap<String, Object>())));
        }
        return rows;
    }

    private void getAggsData(Terms.Bucket bucket, List<DataRows.Row> aggRows, Map<String, Object> data) {
        List aggs = bucket.getAggregations().asList();
        if (this.hasTermAggregations(aggs)) {
            Terms terms = (Terms)aggs.get(0);
            String field = terms.getName();
            for (Terms.Bucket innerBucket : terms.getBuckets()) {
                data.put(field, innerBucket.getKey());
                this.getAggsData(innerBucket, aggRows, data);
                data.remove(field);
            }
        } else {
            data = this.addNumericAggregation(aggs, data);
            aggRows.add(new DataRows.Row(new HashMap<String, Object>(data)));
        }
    }

    private boolean hasTermAggregations(List<Aggregation> aggs) {
        return !aggs.isEmpty() && aggs.get(0) instanceof Terms;
    }

    private Map<String, Object> addNumericAggregation(List<Aggregation> aggs, Map<String, Object> data) {
        for (Aggregation aggregation : aggs) {
            if (aggregation instanceof NumericMetricsAggregation.SingleValue) {
                NumericMetricsAggregation.SingleValue singleValueAggregation = (NumericMetricsAggregation.SingleValue)aggregation;
                data.put(singleValueAggregation.getName(), !Double.isInfinite(singleValueAggregation.value()) ? singleValueAggregation.getValueAsString() : "null");
                continue;
            }
            if (aggregation instanceof Percentiles) {
                Percentiles percentiles = (Percentiles)aggregation;
                data.put(percentiles.getName(), StreamSupport.stream(percentiles.spliterator(), false).collect(Collectors.toMap(Percentile::getPercent, Percentile::getValue, (v1, v2) -> {
                    throw new IllegalArgumentException(String.format("Duplicate key for values %s and %s", v1, v2));
                }, TreeMap::new)));
                continue;
            }
            throw new SqlFeatureNotImplementedException("Aggregation type " + aggregation.getType() + " is not yet implemented");
        }
        return data;
    }

    private Map<String, Object> flatRow(List<String> keys, Map<String, Object> row) {
        HashMap<String, Object> flattenedRow = new HashMap<String, Object>();
        for (String key : keys) {
            String[] splitKeys = key.split("\\.");
            boolean found = true;
            Object currentObj = row;
            for (String splitKey : splitKeys) {
                if (!(currentObj instanceof Map)) {
                    found = false;
                    break;
                }
                Object currentMap = currentObj;
                if (!currentMap.containsKey(splitKey)) {
                    found = false;
                    break;
                }
                currentObj = currentMap.get(splitKey);
            }
            if (!found) continue;
            flattenedRow.put(key, currentObj);
        }
        return flattenedRow;
    }

    private List<DataRows.Row> flatNestedField(Set<String> newKeys, Map<String, Object> row, Map<String, SearchHits> innerHits) {
        List<DataRows.Row> result = new ArrayList<DataRows.Row>();
        result.add(new DataRows.Row(row));
        if (innerHits == null) {
            return result;
        }
        for (String colName : innerHits.keySet()) {
            SearchHit[] colValue = innerHits.get(colName).getHits();
            this.doFlatNestedFieldName(colName, colValue, newKeys);
            result = this.doFlatNestedFieldValue(colName, colValue, result);
        }
        return result;
    }

    private void doFlatNestedFieldName(String colName, SearchHit[] colValue, Set<String> keys) {
        Map innerRow = colValue[0].getSourceAsMap();
        for (String field : innerRow.keySet()) {
            String innerName = colName + "." + field;
            keys.add(innerName);
        }
        keys.remove(colName);
    }

    private List<DataRows.Row> doFlatNestedFieldValue(String colName, SearchHit[] colValue, List<DataRows.Row> rows) {
        ArrayList<DataRows.Row> result = new ArrayList<DataRows.Row>();
        for (DataRows.Row row : rows) {
            for (SearchHit hit : colValue) {
                Map innerRow = hit.getSourceAsMap();
                HashMap<String, Object> copy = new HashMap<String, Object>();
                for (String field : row.getContents().keySet()) {
                    copy.put(field, row.getData(field));
                }
                for (String field : innerRow.keySet()) {
                    copy.put(colName + "." + field, innerRow.get(field));
                }
                copy.remove(colName);
                result.add(new DataRows.Row(copy));
            }
        }
        return result;
    }

    private Map<String, Object> addMap(String field, Object term) {
        HashMap<String, Object> data = new HashMap<String, Object>();
        data.put(field, term);
        return data;
    }

    private boolean isJoinQuery() {
        return this.query instanceof JoinSelect;
    }
}

