/*
 * Decompiled with CFR 0.152.
 */
package org.elasticsearch.index.rankeval;

import java.io.IOException;
import java.util.List;
import java.util.Objects;
import java.util.Optional;
import org.elasticsearch.common.ParseField;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
import org.elasticsearch.common.xcontent.ToXContent;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.index.rankeval.EvalQueryQuality;
import org.elasticsearch.index.rankeval.EvaluationMetric;
import org.elasticsearch.index.rankeval.MetricDetails;
import org.elasticsearch.index.rankeval.RatedDocument;
import org.elasticsearch.index.rankeval.RatedSearchHit;
import org.elasticsearch.search.SearchHit;

/**
 * Evaluation metric that computes precision over a window of search hits.
 *
 * <p>Each hit is joined with the supplied rated documents. A hit whose rating is
 * greater than or equal to the relevant rating threshold counts as a true positive;
 * a rated hit below the threshold counts as a false positive. Hits without a rating
 * count as false positives unless {@code ignore_unlabeled} is set, in which case they
 * are left out of the calculation entirely.
 *
 * <p>The window size {@code k} is reported through {@link #forcedSearchSize()}.
 */
public class PrecisionAtK implements EvaluationMetric {
    public static final String NAME = "precision";

    private static final ParseField RELEVANT_RATING_FIELD = new ParseField("relevant_rating_threshold", new String[0]);
    private static final ParseField IGNORE_UNLABELED_FIELD = new ParseField("ignore_unlabeled", new String[0]);
    private static final ParseField K_FIELD = new ParseField("k", new String[0]);

    /** Window size used when none is specified. */
    private static final int DEFAULT_K = 10;
    /** Relevant rating threshold used when none is specified. */
    private static final int DEFAULT_RELEVANT_RATING_THRESHOLD = 1;

    private final boolean ignoreUnlabeled;
    private final int relevantRatingThreshhold;
    private final int k;

    /**
     * Parser for the metric's parameters. All three constructor arguments are optional;
     * an absent argument arrives as {@code null} and is replaced by its default.
     */
    private static final ConstructingObjectParser<PrecisionAtK, Void> PARSER = new ConstructingObjectParser<>(NAME, args -> {
        Integer threshHold = (Integer) args[0];
        Boolean ignoreUnlabeled = (Boolean) args[1];
        Integer k = (Integer) args[2];
        return new PrecisionAtK(
            threshHold == null ? DEFAULT_RELEVANT_RATING_THRESHOLD : threshHold,
            ignoreUnlabeled == null ? false : ignoreUnlabeled,
            k == null ? DEFAULT_K : k);
    });

    static {
        // Declaration order must match the argument indices read in the PARSER lambda above.
        PARSER.declareInt(ConstructingObjectParser.optionalConstructorArg(), RELEVANT_RATING_FIELD);
        PARSER.declareBoolean(ConstructingObjectParser.optionalConstructorArg(), IGNORE_UNLABELED_FIELD);
        PARSER.declareInt(ConstructingObjectParser.optionalConstructorArg(), K_FIELD);
    }

    /**
     * Creates the metric with explicit parameters.
     *
     * @param threshold       ratings greater than or equal to this value are considered relevant;
     *                        must not be negative (note that zero is accepted, although the error
     *                        message speaks of a "positive integer")
     * @param ignoreUnlabeled if {@code true}, hits without a rating are excluded from the
     *                        calculation; otherwise they count as false positives
     * @param k               window size; must be greater than zero
     * @throws IllegalArgumentException if {@code threshold < 0} or {@code k <= 0}
     */
    public PrecisionAtK(int threshold, boolean ignoreUnlabeled, int k) {
        if (threshold < 0) {
            throw new IllegalArgumentException("Relevant rating threshold for precision must be positive integer.");
        }
        if (k <= 0) {
            throw new IllegalArgumentException("Window size k must be positive.");
        }
        this.relevantRatingThreshhold = threshold;
        this.ignoreUnlabeled = ignoreUnlabeled;
        this.k = k;
    }

    /**
     * Creates the metric with the default threshold (1), unlabeled documents
     * counted as false positives, and the default window size (10).
     */
    public PrecisionAtK() {
        this(DEFAULT_RELEVANT_RATING_THRESHOLD, false, DEFAULT_K);
    }

    /**
     * Reads the metric from a stream, in the same field order used by {@link #writeTo(StreamOutput)}.
     * No range validation is applied to the values read.
     *
     * @param in the stream to read from
     * @throws IOException if reading from the stream fails
     */
    PrecisionAtK(StreamInput in) throws IOException {
        this.relevantRatingThreshhold = in.readVInt();
        this.ignoreUnlabeled = in.readBoolean();
        this.k = in.readVInt();
    }

    /** @return the window size {@code k} */
    int getK() {
        return this.k;
    }

    /**
     * Writes threshold, ignore-unlabeled flag and window size, in that order.
     *
     * @param out the stream to write to
     * @throws IOException if writing to the stream fails
     */
    public void writeTo(StreamOutput out) throws IOException {
        out.writeVInt(this.relevantRatingThreshhold);
        out.writeBoolean(this.ignoreUnlabeled);
        out.writeVInt(this.k);
    }

    /** @return {@link #NAME} */
    public String getWriteableName() {
        return NAME;
    }

    /** @return the rating at or above which a document is considered relevant */
    public int getRelevantRatingThreshold() {
        return this.relevantRatingThreshhold;
    }

    /** @return whether hits without a rating are excluded from the precision calculation */
    public boolean getIgnoreUnlabeled() {
        return this.ignoreUnlabeled;
    }

    /** @return an {@link Optional} holding the window size {@code k} */
    @Override
    public Optional<Integer> forcedSearchSize() {
        return Optional.of(this.k);
    }

    /**
     * Parses a {@code PrecisionAtK} from the given parser using {@link #PARSER}.
     *
     * @param parser the parser positioned at the metric's parameters
     * @return the parsed metric
     */
    public static PrecisionAtK fromXContent(XContentParser parser) {
        return PARSER.apply(parser, null);
    }

    /**
     * Computes precision for the given hits.
     *
     * <p>The result is {@code truePositives / (truePositives + falsePositives)}, or {@code 0.0}
     * when no hit was counted at all. A {@link Breakdown} with the number of relevant and counted
     * documents is attached as metric details, together with the joined hits and ratings.
     *
     * @param taskId    identifier passed through to the returned {@link EvalQueryQuality}
     * @param hits      the search hits to evaluate
     * @param ratedDocs the rated documents to join with the hits
     * @return the evaluation result for this query
     */
    @Override
    public EvalQueryQuality evaluate(String taskId, SearchHit[] hits, List<RatedDocument> ratedDocs) {
        int truePositives = 0;
        int falsePositives = 0;
        List<RatedSearchHit> ratedSearchHits = EvaluationMetric.joinHitsWithRatings(hits, ratedDocs);
        for (RatedSearchHit hit : ratedSearchHits) {
            Optional<Integer> rating = hit.getRating();
            if (rating.isPresent()) {
                if (rating.get() >= this.relevantRatingThreshhold) {
                    truePositives++;
                } else {
                    falsePositives++;
                }
            } else if (this.ignoreUnlabeled == false) {
                // Unlabeled hits are treated as irrelevant unless explicitly ignored.
                falsePositives++;
            }
        }
        int retrieved = truePositives + falsePositives;
        // Guard against division by zero when nothing was counted.
        double precision = retrieved > 0 ? (double) truePositives / (double) retrieved : 0.0;
        EvalQueryQuality evalQueryQuality = new EvalQueryQuality(taskId, precision);
        evalQueryQuality.setMetricDetails(new Breakdown(truePositives, retrieved));
        evalQueryQuality.addHitsAndRatings(ratedSearchHits);
        return evalQueryQuality;
    }

    /**
     * Renders the metric as an object containing a single {@value #NAME} object
     * with the threshold, ignore-unlabeled flag and window size.
     *
     * @param builder the builder to write to
     * @param params  rendering parameters (not used by this method)
     * @return the same builder, for chaining
     * @throws IOException if writing to the builder fails
     */
    public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params params) throws IOException {
        builder.startObject();
        builder.startObject(NAME);
        builder.field(RELEVANT_RATING_FIELD.getPreferredName(), this.relevantRatingThreshhold);
        builder.field(IGNORE_UNLABELED_FIELD.getPreferredName(), this.ignoreUnlabeled);
        builder.field(K_FIELD.getPreferredName(), this.k);
        builder.endObject();
        builder.endObject();
        return builder;
    }

    /** Two metrics are equal when threshold, window size and ignore-unlabeled flag all match. */
    @Override
    public final boolean equals(Object obj) {
        if (this == obj) {
            return true;
        }
        if (obj == null || this.getClass() != obj.getClass()) {
            return false;
        }
        PrecisionAtK other = (PrecisionAtK) obj;
        return this.relevantRatingThreshhold == other.relevantRatingThreshhold
            && this.k == other.k
            && this.ignoreUnlabeled == other.ignoreUnlabeled;
    }

    @Override
    public final int hashCode() {
        return Objects.hash(this.relevantRatingThreshhold, this.ignoreUnlabeled, this.k);
    }

    /**
     * Metric details for {@link PrecisionAtK}: how many documents were counted
     * and how many of those were relevant.
     */
    static class Breakdown implements MetricDetails {
        private static final String DOCS_RETRIEVED_FIELD = "docs_retrieved";
        private static final String RELEVANT_DOCS_RETRIEVED_FIELD = "relevant_docs_retrieved";

        private final int relevantRetrieved;
        private final int retrieved;

        /**
         * @param relevantRetrieved number of relevant documents counted
         * @param retrieved         total number of documents counted
         */
        Breakdown(int relevantRetrieved, int retrieved) {
            this.relevantRetrieved = relevantRetrieved;
            this.retrieved = retrieved;
        }

        /**
         * Reads the breakdown from a stream, in the same field order used by {@link #writeTo(StreamOutput)}.
         *
         * @param in the stream to read from
         * @throws IOException if reading from the stream fails
         */
        Breakdown(StreamInput in) throws IOException {
            this.relevantRetrieved = in.readVInt();
            this.retrieved = in.readVInt();
        }

        /**
         * Writes the two counters as fields into the builder. This method does not
         * open or close an enclosing object itself.
         *
         * @param builder the builder to write to
         * @param params  rendering parameters (not used by this method)
         * @return the same builder, for chaining
         * @throws IOException if writing to the builder fails
         */
        public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params params) throws IOException {
            builder.field(RELEVANT_DOCS_RETRIEVED_FIELD, this.relevantRetrieved);
            builder.field(DOCS_RETRIEVED_FIELD, this.retrieved);
            return builder;
        }

        /**
         * Writes the relevant count followed by the total count.
         *
         * @param out the stream to write to
         * @throws IOException if writing to the stream fails
         */
        public void writeTo(StreamOutput out) throws IOException {
            out.writeVInt(this.relevantRetrieved);
            out.writeVInt(this.retrieved);
        }

        /** @return the enclosing metric's {@link PrecisionAtK#NAME} */
        public String getWriteableName() {
            return PrecisionAtK.NAME;
        }

        /** @return number of relevant documents counted */
        int getRelevantRetrieved() {
            return this.relevantRetrieved;
        }

        /** @return total number of documents counted */
        int getRetrieved() {
            return this.retrieved;
        }

        /** Two breakdowns are equal when both counters match. */
        @Override
        public final boolean equals(Object obj) {
            if (this == obj) {
                return true;
            }
            if (obj == null || this.getClass() != obj.getClass()) {
                return false;
            }
            Breakdown other = (Breakdown) obj;
            return this.relevantRetrieved == other.relevantRetrieved && this.retrieved == other.retrieved;
        }

        @Override
        public final int hashCode() {
            return Objects.hash(this.relevantRetrieved, this.retrieved);
        }
    }
}

