/*
 * Decompiled with CFR 0.152.
 */
package org.bouncycastle.pqc.crypto.sphincsplus;

import java.util.LinkedList;
import org.bouncycastle.pqc.crypto.sphincsplus.ADRS;
import org.bouncycastle.pqc.crypto.sphincsplus.NodeEntry;
import org.bouncycastle.pqc.crypto.sphincsplus.SIG_FORS;
import org.bouncycastle.pqc.crypto.sphincsplus.SPHINCSPlusEngine;
import org.bouncycastle.util.Arrays;

/**
 * FORS helper for SPHINCS+: derives authentication paths when signing and
 * recomputes the FORS public key from a signature when verifying. All hashing
 * is delegated to the supplied {@link SPHINCSPlusEngine}.
 */
class Fors {
    SPHINCSPlusEngine engine;

    /**
     * @param engine engine supplying the PRF/F/H/T_l primitives and the
     *               K (tree count), A (tree height) and T parameters
     */
    public Fors(SPHINCSPlusEngine engine) {
        this.engine = engine;
    }

    /**
     * Hashes the {@code 2^z} consecutive leaves starting at index {@code s}
     * into a single node, merging equal-height nodes on a stack as it goes.
     *
     * @param skSeed    secret seed fed to the engine's PRF to derive each leaf secret
     * @param s         index of the first leaf; must be a multiple of {@code 2^z}
     * @param z         log2 of the number of leaves to process
     * @param pkSeed    public seed passed to every engine call
     * @param adrsParam address template; it is copied, the caller's instance is not modified here
     * @return the value of the node left on top of the stack, or {@code null}
     *         when {@code s} is not aligned to {@code 2^z}
     */
    byte[] treehash(byte[] skSeed, int s, int z, byte[] pkSeed, ADRS adrsParam) {
        // Clearing the low z bits must leave s unchanged, otherwise s is misaligned.
        int aligned = (s >>> z) << z;
        if (aligned != s) {
            return null;
        }

        ADRS adrs = new ADRS(adrsParam);
        LinkedList<NodeEntry> pending = new LinkedList<NodeEntry>();
        int leafCount = 1 << z;

        for (int offset = 0; offset < leafCount; offset++) {
            int index = s + offset;

            // Derive the leaf secret under ADRS type code 6, then switch to type
            // code 3 to hash it into a leaf node.
            // NOTE(review): 6 / 3 are presumably the FORS-PRF and FORS-tree
            // address types - confirm against ADRS.
            adrs.setTypeAndClear(6);
            adrs.setKeyPairAddress(adrsParam.getKeyPairAddress());
            adrs.setTreeHeight(0);
            adrs.setTreeIndex(index);
            byte[] secret = this.engine.PRF(pkSeed, skSeed, adrs);
            adrs.changeType(3);
            byte[] node = this.engine.F(pkSeed, adrs, secret);

            int height = 1;
            adrs.setTreeHeight(height);

            // While the most recent pending node sits at the same height,
            // combine it (as left child) with the current node (right child).
            NodeEntry top = pending.peekFirst();
            while (top != null && top.nodeHeight == height) {
                index = (index - 1) / 2;
                adrs.setTreeIndex(index);
                pending.removeFirst();
                node = this.engine.H(pkSeed, adrs, top.nodeValue, node);
                height++;
                adrs.setTreeHeight(height);
                top = pending.peekFirst();
            }
            pending.addFirst(new NodeEntry(node, height));
        }

        return pending.get(0).nodeValue;
    }

    /**
     * Produces one {@link SIG_FORS} element per tree: the selected leaf secret
     * plus the authentication path of sibling nodes up the tree.
     *
     * @param md        message digest from which the per-tree leaf indices are extracted
     * @param skSeed    secret seed used for leaf-secret derivation
     * @param pkSeed    public seed passed to every engine call
     * @param paramAdrs address template; copied before use
     * @return array of {@code engine.K} signature elements
     */
    public SIG_FORS[] sign(byte[] md, byte[] skSeed, byte[] pkSeed, ADRS paramAdrs) {
        ADRS adrs = new ADRS(paramAdrs);
        int treeCount = this.engine.K;
        int treeHeight = this.engine.A;
        int[] leafIndices = Fors.message_to_idxs(md, treeCount, treeHeight);
        SIG_FORS[] signature = new SIG_FORS[treeCount];
        int t = this.engine.T;

        for (int tree = 0; tree < treeCount; tree++) {
            int leaf = leafIndices[tree];
            int treeBase = tree * t;

            // Leaf secret for the selected index of this tree.
            adrs.setTypeAndClear(6);
            adrs.setKeyPairAddress(paramAdrs.getKeyPairAddress());
            adrs.setTreeHeight(0);
            adrs.setTreeIndex(treeBase + leaf);
            byte[] secret = this.engine.PRF(pkSeed, skSeed, adrs);
            adrs.changeType(3);

            // At each level, the sibling is the subtree whose index differs
            // from ours in the lowest bit.
            byte[][] authPath = new byte[treeHeight][];
            for (int level = 0; level < treeHeight; level++) {
                int span = 1 << level;
                int sibling = (leaf / span) ^ 1;
                authPath[level] = this.treehash(skSeed, treeBase + sibling * span, level, pkSeed, adrs);
            }

            signature[tree] = new SIG_FORS(secret, authPath);
        }
        return signature;
    }

    /**
     * Recomputes the FORS public key from a signature by hashing each revealed
     * leaf secret up its authentication path and compressing the resulting roots.
     * The supplied {@code adrs} has its tree height and tree index modified.
     *
     * @param sig_fors signature elements, one per tree
     * @param message  message digest from which the per-tree leaf indices are extracted
     * @param pkSeed   public seed passed to every engine call
     * @param adrs     working address; mutated by this method
     * @return the result of {@code engine.T_l} over the concatenated tree roots
     */
    public byte[] pkFromSig(SIG_FORS[] sig_fors, byte[] message, byte[] pkSeed, ADRS adrs) {
        int treeCount = this.engine.K;
        int treeHeight = this.engine.A;
        byte[][] roots = new byte[treeCount][];
        int t = this.engine.T;
        int[] leafIndices = Fors.message_to_idxs(message, treeCount, treeHeight);

        for (int tree = 0; tree < treeCount; tree++) {
            int leaf = leafIndices[tree];
            int leafAddress = tree * t + leaf;

            // Rebuild the leaf node from the revealed secret.
            byte[] secret = sig_fors[tree].getSK();
            adrs.setTreeHeight(0);
            adrs.setTreeIndex(leafAddress);
            byte[] current = this.engine.F(pkSeed, adrs, secret);

            byte[][] authPath = sig_fors[tree].getAuthPath();
            // The index is set again before the climb, as in the original flow.
            adrs.setTreeIndex(leafAddress);

            for (int level = 0; level < treeHeight; level++) {
                adrs.setTreeHeight(level + 1);
                boolean currentIsLeftChild = leaf / (1 << level) % 2 == 0;
                if (currentIsLeftChild) {
                    adrs.setTreeIndex(adrs.getTreeIndex() / 2);
                    current = this.engine.H(pkSeed, adrs, current, authPath[level]);
                } else {
                    adrs.setTreeIndex((adrs.getTreeIndex() - 1) / 2);
                    current = this.engine.H(pkSeed, adrs, authPath[level], current);
                }
            }
            roots[tree] = current;
        }

        // Compress all roots under ADRS type code 4.
        // NOTE(review): 4 is presumably the FORS public-key/roots address type - confirm against ADRS.
        ADRS pkAdrs = new ADRS(adrs);
        pkAdrs.setTypeAndClear(4);
        pkAdrs.setKeyPairAddress(adrs.getKeyPairAddress());
        return this.engine.T_l(pkSeed, pkAdrs, Arrays.concatenate(roots));
    }

    /**
     * Splits a message into {@code fors_trees} indices of {@code fors_height}
     * bits each. Bits are consumed sequentially, least-significant bit of each
     * byte first, and the first bit read becomes bit 0 of the index.
     *
     * @param msg         source bytes; must hold at least
     *                    {@code fors_trees * fors_height} bits
     * @param fors_trees  number of indices to extract
     * @param fors_height number of bits per index
     * @return the extracted indices, one per tree
     */
    static int[] message_to_idxs(byte[] msg, int fors_trees, int fors_height) {
        int[] idxs = new int[fors_trees];
        int bitPos = 0;
        for (int tree = 0; tree < fors_trees; tree++) {
            int value = 0;
            for (int bit = 0; bit < fors_height; bit++, bitPos++) {
                int msgBit = (msg[bitPos >> 3] >> (bitPos & 7)) & 1;
                value ^= msgBit << bit;
            }
            idxs[tree] = value;
        }
        return idxs;
    }
}

