/*
 * Decompiled with CFR 0.152.
 */
package org.languagetool.rules.neuralnetwork;

import java.io.InputStream;
import org.languagetool.rules.neuralnetwork.Classifier;
import org.languagetool.rules.neuralnetwork.Embedding;
import org.languagetool.rules.neuralnetwork.Matrix;

/**
 * {@link Classifier} made of an {@link Embedding} lookup followed by two fully connected
 * layers, with {@code relu()} applied between them.
 *
 * <p>All layer parameters are read from the supplied streams once, at construction time;
 * {@link #getScores(String[])} only runs the forward pass over the stored matrices.
 */
public class TwoLayerClassifier
implements Classifier {
    private final Embedding embedding;
    private final Matrix hiddenWeights;
    private final Matrix hiddenBias;
    private final Matrix outputWeights;
    private final Matrix outputBias;

    /**
     * Reads the parameters of both layers.
     *
     * <p>The streams are consumed in the order {@code W1}, {@code b1}, {@code W2}, {@code b2}.
     * This constructor does not close them itself.
     *
     * @param embedding source of the network input; its {@code lookup} is applied to the context
     * @param W1 stream with the first layer's weight matrix
     * @param b1 stream with the first layer's bias, transposed after loading
     * @param W2 stream with the second layer's weight matrix
     * @param b2 stream with the second layer's bias, transposed after loading
     */
    public TwoLayerClassifier(Embedding embedding, InputStream W1, InputStream b1, InputStream W2, InputStream b2) {
        this.embedding = embedding;
        hiddenWeights = new Matrix(W1);
        // NOTE(review): biases are presumably stored as columns and transposed into a row so
        // they line up with the layer output in add(); confirm against the model export.
        hiddenBias = new Matrix(b1).transpose();
        outputWeights = new Matrix(W2);
        outputBias = new Matrix(b2).transpose();
    }

    /**
     * Runs the forward pass for the given context.
     *
     * @param context tokens handed to the embedding lookup
     * @return the first row of the second layer's output
     */
    @Override
    public float[] getScores(String[] context) {
        Matrix input = embedding.lookup(context);
        Matrix hidden = input.mul(hiddenWeights).add(hiddenBias).relu();
        Matrix output = hidden.mul(outputWeights).add(outputBias);
        return output.row(0);
    }
}

