/*
 * Decompiled with CFR 0.152.
 */
package com.hankcs.hanlp.dependency.perceptron.learning;

import com.hankcs.hanlp.dependency.perceptron.structures.CompactArray;
import com.hankcs.hanlp.dependency.perceptron.structures.ParserModel;
import com.hankcs.hanlp.dependency.perceptron.transition.parser.Action;
import java.util.HashMap;

/**
 * Weight store for an averaged perceptron used by a transition-based dependency parser.
 *
 * <p>Each feature slot has one table per action:
 * <ul>
 *   <li>{@code Shift} and {@code Reduce} map a feature to a single scalar weight.</li>
 *   <li>{@code LeftArc} and {@code RightArc} map a feature to a {@link CompactArray} of
 *       per-label weights.</li>
 * </ul>
 *
 * <p>Every update is applied twice. The raw change goes to the "raw" tables. The change
 * multiplied by the current {@link #iteration} goes to the matching "averaged" tables.
 * This class does not turn the accumulated averaged tables into final averaged weights.
 * NOTE(review): that step presumably happens in {@link ParserModel}; confirm.
 *
 * <p>Instances are not synchronized; all tables are plain {@link HashMap}s mutated in place.
 */
public class AveragedPerceptron {
    /**
     * Scoring of {@code Shift} and {@code Reduce} ignores feature slots in the half-open range
     * [{@value #SCALAR_SKIP_FROM}, {@value #SCALAR_SKIP_TO}). The arc scores still use those
     * slots. NOTE(review): presumably these slots hold label-dependent features; confirm
     * against the feature extractor.
     */
    private static final int SCALAR_SKIP_FROM = 26;
    private static final int SCALAR_SKIP_TO = 32;

    // Raw (non-averaged) weight tables, indexed by feature slot.
    // These are null when the instance is created from a ParserModel.
    public HashMap<Object, Float>[] shiftFeatureWeights;
    public HashMap<Object, Float>[] reduceFeatureWeights;
    public HashMap<Object, CompactArray>[] leftArcFeatureWeights;
    public HashMap<Object, CompactArray>[] rightArcFeatureWeights;
    /** Current training iteration. Scales the updates written to the averaged tables. */
    public int iteration;
    /** Length of the score vectors returned by the arc-scoring methods. */
    public int dependencySize;
    // Iteration-weighted accumulators, indexed by feature slot. They are used when decoding.
    public HashMap<Object, Float>[] shiftFeatureAveragedWeights;
    public HashMap<Object, Float>[] reduceFeatureAveragedWeights;
    public HashMap<Object, CompactArray>[] leftArcFeatureAveragedWeights;
    public HashMap<Object, CompactArray>[] rightArcFeatureAveragedWeights;

    /**
     * Creates an empty perceptron ready for training.
     *
     * @param featSize       number of feature slots; one table per slot is allocated for every
     *                       raw and averaged weight array
     * @param dependencySize length of the vectors returned by {@link #leftArcScores} and
     *                       {@link #rightArcScores}
     */
    public AveragedPerceptron(int featSize, int dependencySize) {
        this.shiftFeatureWeights = newTables(featSize);
        this.reduceFeatureWeights = newTables(featSize);
        this.leftArcFeatureWeights = newTables(featSize);
        this.rightArcFeatureWeights = newTables(featSize);
        this.shiftFeatureAveragedWeights = newTables(featSize);
        this.reduceFeatureAveragedWeights = newTables(featSize);
        this.leftArcFeatureAveragedWeights = newTables(featSize);
        this.rightArcFeatureAveragedWeights = newTables(featSize);
        this.iteration = 1;
        this.dependencySize = dependencySize;
    }

    /**
     * Wraps existing averaged tables for decoding only. The raw weight arrays stay null and
     * {@link #iteration} stays 0, so only {@code decode == true} scoring is usable.
     */
    private AveragedPerceptron(HashMap<Object, Float>[] shiftFeatureAveragedWeights,
                               HashMap<Object, Float>[] reduceFeatureAveragedWeights,
                               HashMap<Object, CompactArray>[] leftArcFeatureAveragedWeights,
                               HashMap<Object, CompactArray>[] rightArcFeatureAveragedWeights,
                               int dependencySize) {
        this.shiftFeatureAveragedWeights = shiftFeatureAveragedWeights;
        this.reduceFeatureAveragedWeights = reduceFeatureAveragedWeights;
        this.leftArcFeatureAveragedWeights = leftArcFeatureAveragedWeights;
        this.rightArcFeatureAveragedWeights = rightArcFeatureAveragedWeights;
        this.dependencySize = dependencySize;
    }

    /**
     * Builds a decode-only perceptron that shares (does not copy) the averaged tables of
     * {@code parserModel}.
     *
     * @param parserModel model supplying the averaged weight tables and the dependency size
     */
    public AveragedPerceptron(ParserModel parserModel) {
        this(parserModel.shiftFeatureAveragedWeights, parserModel.reduceFeatureAveragedWeights,
                parserModel.leftArcFeatureAveragedWeights, parserModel.rightArcFeatureAveragedWeights,
                parserModel.dependencySize);
    }

    /**
     * Adds {@code change} to the weight of one feature for one action.
     *
     * <p>The raw table receives {@code change}. The averaged table receives
     * {@code iteration * change}. For arc actions, the update goes to the label at
     * {@code labelIndex}. An action other than the four handled ones leaves every table
     * untouched.
     *
     * @param actionType  action whose tables are updated
     * @param slotNum     feature slot index
     * @param featureName feature key; {@code null} means "no feature" and nothing is updated
     * @param labelIndex  dependency label index (only used for arc actions)
     * @param change      amount added to the raw weight
     * @return {@code change}, or {@code 0.0f} when {@code featureName} is {@code null}
     */
    public float changeWeight(Action actionType, int slotNum, Object featureName, int labelIndex, float change) {
        if (featureName == null) {
            return 0.0f;
        }
        if (actionType == Action.Shift) {
            addScalar(this.shiftFeatureWeights[slotNum], featureName, change);
            addScalar(this.shiftFeatureAveragedWeights[slotNum], featureName, (float) this.iteration * change);
        } else if (actionType == Action.Reduce) {
            addScalar(this.reduceFeatureWeights[slotNum], featureName, change);
            addScalar(this.reduceFeatureAveragedWeights[slotNum], featureName, (float) this.iteration * change);
        } else if (actionType == Action.RightArc) {
            this.changeFeatureWeight(this.rightArcFeatureWeights[slotNum], this.rightArcFeatureAveragedWeights[slotNum],
                    featureName, labelIndex, change, this.dependencySize);
        } else if (actionType == Action.LeftArc) {
            this.changeFeatureWeight(this.leftArcFeatureWeights[slotNum], this.leftArcFeatureAveragedWeights[slotNum],
                    featureName, labelIndex, change, this.dependencySize);
        }
        return change;
    }

    /**
     * Applies a per-label update to one raw/averaged pair of arc tables.
     *
     * <p>A feature that is absent from a table gets a new single-element {@link CompactArray}
     * at {@code labelIndex}. An existing entry is updated through {@link CompactArray#set}.
     * NOTE(review): {@code set} is assumed to accumulate (add) rather than overwrite; confirm
     * in {@code CompactArray}.
     *
     * <p>The two tables are handled independently. A feature present in only one of them
     * therefore cannot cause a {@code NullPointerException}.
     *
     * @param map         raw per-label weight table
     * @param aMap        averaged (iteration-weighted) per-label weight table
     * @param featureName feature key
     * @param labelIndex  dependency label index to update
     * @param change      amount added to the raw weight; {@code iteration * change} goes to
     *                    {@code aMap}
     * @param size        unused; kept for API compatibility
     */
    public void changeFeatureWeight(HashMap<Object, CompactArray> map, HashMap<Object, CompactArray> aMap,
                                    Object featureName, int labelIndex, float change, int size) {
        float averagedChange = (float) this.iteration * change;

        CompactArray values = map.get(featureName);
        if (values == null) {
            map.put(featureName, new CompactArray(labelIndex, new float[]{change}));
        } else {
            values.set(labelIndex, change);
        }

        CompactArray aValues = aMap.get(featureName);
        if (aValues == null) {
            aMap.put(featureName, new CompactArray(labelIndex, new float[]{averagedChange}));
        } else {
            aValues.set(labelIndex, averagedChange);
        }
    }

    /** Advances the training iteration counter used to weight averaged updates. */
    public void incrementIteration() {
        ++this.iteration;
    }

    /**
     * Sums the {@code Shift} weights of the given features.
     *
     * @param features one feature per slot; {@code null} entries are skipped
     * @param decode   {@code true} to read the averaged tables, {@code false} for the raw ones
     * @return the summed score; slots in [26, 32) are ignored
     */
    public float shiftScore(Object[] features, boolean decode) {
        return scalarScore(decode ? this.shiftFeatureAveragedWeights : this.shiftFeatureWeights, features);
    }

    /**
     * Sums the {@code Reduce} weights of the given features.
     *
     * @param features one feature per slot; {@code null} entries are skipped
     * @param decode   {@code true} to read the averaged tables, {@code false} for the raw ones
     * @return the summed score; slots in [26, 32) are ignored
     */
    public float reduceScore(Object[] features, boolean decode) {
        return scalarScore(decode ? this.reduceFeatureAveragedWeights : this.reduceFeatureWeights, features);
    }

    /**
     * Computes per-label {@code LeftArc} scores for the given features.
     *
     * @param features one feature per slot; {@code null} entries are skipped
     * @param decode   {@code true} to read the averaged tables, {@code false} for the raw ones
     * @return a new array of length {@link #dependencySize} with the summed per-label weights
     */
    public float[] leftArcScores(Object[] features, boolean decode) {
        return arcScores(decode ? this.leftArcFeatureAveragedWeights : this.leftArcFeatureWeights, features);
    }

    /**
     * Computes per-label {@code RightArc} scores for the given features.
     *
     * @param features one feature per slot; {@code null} entries are skipped
     * @param decode   {@code true} to read the averaged tables, {@code false} for the raw ones
     * @return a new array of length {@link #dependencySize} with the summed per-label weights
     */
    public float[] rightArcScores(Object[] features, boolean decode) {
        return arcScores(decode ? this.rightArcFeatureAveragedWeights : this.rightArcFeatureWeights, features);
    }

    /** @return the number of feature slots, taken from the averaged {@code Shift} tables */
    public int featureSize() {
        return this.shiftFeatureAveragedWeights.length;
    }

    /** @return the total number of stored per-label weights in the averaged {@code RightArc} tables */
    public int raSize() {
        // Iterate the right-arc tables themselves. The original bounded this loop by the
        // left-arc array length.
        return totalLength(this.rightArcFeatureAveragedWeights);
    }

    /** @return the number of non-zero per-label weights in the averaged {@code RightArc} tables */
    public int effectiveRaSize() {
        return countNonZero(this.rightArcFeatureAveragedWeights);
    }

    /** @return the total number of stored per-label weights in the averaged {@code LeftArc} tables */
    public int laSize() {
        return totalLength(this.leftArcFeatureAveragedWeights);
    }

    /** @return the number of non-zero per-label weights in the averaged {@code LeftArc} tables */
    public int effectiveLaSize() {
        return countNonZero(this.leftArcFeatureAveragedWeights);
    }

    /** Allocates {@code size} empty tables (Java forbids creating a generic array directly). */
    @SuppressWarnings("unchecked")
    private static <V> HashMap<Object, V>[] newTables(int size) {
        HashMap<Object, V>[] tables = new HashMap[size];
        for (int i = 0; i < size; ++i) {
            tables[i] = new HashMap<Object, V>();
        }
        return tables;
    }

    /** Adds {@code delta} to the scalar stored under {@code feature}, starting from 0 if absent. */
    private static void addScalar(HashMap<Object, Float> table, Object feature, float delta) {
        Float current = table.get(feature);
        table.put(feature, current == null ? delta : current + delta);
    }

    /** Sums the scalar weights of non-null features, skipping the excluded slot range. */
    private static float scalarScore(HashMap<Object, Float>[] tables, Object[] features) {
        float score = 0.0f;
        for (int i = 0; i < features.length; ++i) {
            if (features[i] == null || (i >= SCALAR_SKIP_FROM && i < SCALAR_SKIP_TO)) {
                continue;
            }
            Float weight = tables[i].get(features[i]);
            if (weight != null) {
                score += weight;
            }
        }
        return score;
    }

    /**
     * Adds each feature's per-label weight vector into a fresh score array, starting at that
     * vector's offset.
     */
    private float[] arcScores(HashMap<Object, CompactArray>[] tables, Object[] features) {
        float[] scores = new float[this.dependencySize];
        for (int i = 0; i < features.length; ++i) {
            if (features[i] == null) {
                continue;
            }
            CompactArray values = tables[i].get(features[i]);
            if (values == null) {
                continue;
            }
            int offset = values.getOffset();
            float[] weightVector = values.getArray();
            for (int j = 0; j < weightVector.length; ++j) {
                scores[offset + j] += weightVector[j];
            }
        }
        return scores;
    }

    /** Sums {@link CompactArray#length()} over every entry of every table. */
    private static int totalLength(HashMap<Object, CompactArray>[] tables) {
        int size = 0;
        for (HashMap<Object, CompactArray> table : tables) {
            for (CompactArray values : table.values()) {
                size += values.length();
            }
        }
        return size;
    }

    /** Counts stored weights that are not equal to {@code 0.0f} across every table. */
    private static int countNonZero(HashMap<Object, CompactArray>[] tables) {
        int size = 0;
        for (HashMap<Object, CompactArray> table : tables) {
            for (CompactArray values : table.values()) {
                for (float f : values.getArray()) {
                    if (f != 0.0f) {
                        ++size;
                    }
                }
            }
        }
        return size;
    }
}

