/*
 * Decompiled with CFR 0.152.
 */
package org.tribuo.evaluation;

import java.util.Arrays;
import java.util.Iterator;
import java.util.NoSuchElementException;
import java.util.Objects;
import java.util.SplittableRandom;
import java.util.logging.Logger;
import java.util.stream.IntStream;
import org.tribuo.Dataset;
import org.tribuo.Output;
import org.tribuo.dataset.DatasetView;
import org.tribuo.util.Util;

/**
 * Splits a {@link Dataset} into {@code nsplits} train/test folds for k-fold cross-validation.
 * <p>
 * For each call to {@link #split(Dataset, boolean)}, the example indices are optionally shuffled
 * and then partitioned into {@code nsplits} contiguous, disjoint held-out (test) ranges that
 * together cover every example exactly once. Each fold's training view contains every example
 * not held out in that fold. Fold sizes differ by at most one: the first
 * {@code nsamples % nsplits} folds receive one extra example.
 * <p>
 * The shuffling RNG is created once from the constructor seed and shared across calls to
 * {@code split}, so repeated shuffled splits on the same instance use different permutations.
 *
 * @param <T> The output type of the dataset being split.
 */
public class KFoldSplitter<T extends Output<T>> {
    private static final Logger logger = Logger.getLogger(KFoldSplitter.class.getName());

    /** The number of folds produced by each call to split. */
    protected final int nsplits;
    /** The seed used to initialise {@link #rng}; also embedded in the generated fold names. */
    protected final long seed;
    /** RNG used to permute indices when shuffling; shared across calls to split. */
    protected final SplittableRandom rng;

    /**
     * Builds a splitter producing {@code nsplits} folds, shuffling with the supplied seed.
     *
     * @param nsplits The number of folds; must be at least 2.
     * @param randomSeed The seed for the shuffling RNG.
     * @throws IllegalArgumentException If {@code nsplits < 2}.
     */
    public KFoldSplitter(int nsplits, long randomSeed) {
        if (nsplits < 2) {
            throw new IllegalArgumentException("nsplits must be at least 2");
        }
        this.nsplits = nsplits;
        this.seed = randomSeed;
        this.rng = new SplittableRandom(randomSeed);
    }

    /**
     * Builds a splitter producing {@code nsplits} folds, using the fixed default seed 12345.
     *
     * @param nsplits The number of folds; must be at least 2.
     * @throws IllegalArgumentException If {@code nsplits < 2}.
     */
    public KFoldSplitter(int nsplits) {
        this(nsplits, 12345L);
    }

    /**
     * Splits the dataset into {@code nsplits} train/test folds.
     * <p>
     * The returned iterator lazily builds one {@link TrainTestFold} per call to
     * {@link Iterator#next()}. Fold {@code k} (1-based) holds out the k-th contiguous range of the
     * (possibly shuffled) index order and trains on all remaining indices.
     *
     * @param dataset The dataset to split.
     * @param shuffle If true, permute the indices using this splitter's RNG before partitioning;
     *                otherwise use the dataset's natural order.
     * @return An iterator over exactly {@code nsplits} folds.
     * @throws NullPointerException If {@code dataset} is null.
     * @throws IllegalArgumentException If the dataset is empty or has fewer examples than
     *                                  {@code nsplits}.
     */
    public Iterator<TrainTestFold<T>> split(final Dataset<T> dataset, boolean shuffle) {
        Objects.requireNonNull(dataset, "dataset must not be null");
        final int nsamples = dataset.size();
        if (nsamples == 0) {
            throw new IllegalArgumentException("empty input data");
        }
        if (this.nsplits > nsamples) {
            throw new IllegalArgumentException("cannot have nsplits > nsamples");
        }
        final int[] indices = shuffle ? Util.randperm(nsamples, this.rng) : IntStream.range(0, nsamples).toArray();
        final int[] foldSizes = computeFoldSizes(nsamples, this.nsplits);

        return new Iterator<TrainTestFold<T>>() {
            /** Number of folds already returned (the next fold's 0-based index). */
            private int foldPtr = 0;
            /** Start offset in {@code indices} of the next held-out range. */
            private int dataPtr = 0;

            @Override
            public boolean hasNext() {
                return foldPtr < foldSizes.length;
            }

            @Override
            public TrainTestFold<T> next() {
                // Honour the Iterator contract instead of failing with an array index error.
                if (!hasNext()) {
                    throw new NoSuchElementException(
                            "All " + foldSizes.length + " folds have already been returned");
                }
                final int start = dataPtr;
                final int stop = start + foldSizes[foldPtr];
                foldPtr++;
                dataPtr = stop;

                // The held-out range is [start, stop); training uses everything either side of it.
                final int[] holdOut = Arrays.copyOfRange(indices, start, stop);
                final int[] rest = new int[indices.length - holdOut.length];
                System.arraycopy(indices, 0, rest, 0, start);
                System.arraycopy(indices, stop, rest, start, nsamples - stop);

                // foldPtr has already been incremented, so fold identifiers are 1-based.
                final String foldId = "(seed=" + KFoldSplitter.this.seed + "," + foldPtr
                        + " of " + KFoldSplitter.this.nsplits + ")";
                return new TrainTestFold<>(
                        new DatasetView<>(dataset, rest, "TrainFold" + foldId),
                        new DatasetView<>(dataset, holdOut, "TestFold" + foldId));
            }
        };
    }

    /**
     * Computes per-fold held-out sizes that sum to {@code nsamples}, with the first
     * {@code nsamples % nsplits} folds one example larger than the rest.
     *
     * @param nsamples The total number of examples (positive).
     * @param nsplits The number of folds (positive).
     * @return An array of length {@code nsplits} holding each fold's size.
     */
    private static int[] computeFoldSizes(int nsamples, int nsplits) {
        final int[] sizes = new int[nsplits];
        final int base = nsamples / nsplits;
        final int remainder = nsamples % nsplits;
        for (int i = 0; i < nsplits; i++) {
            sizes[i] = i < remainder ? base + 1 : base;
        }
        return sizes;
    }

    /**
     * A single cross-validation fold: a training view and its disjoint held-out test view
     * over the same underlying dataset.
     *
     * @param <T> The output type of the dataset.
     */
    public static class TrainTestFold<T extends Output<T>> {
        /** The training portion of this fold. */
        public final DatasetView<T> train;
        /** The held-out (test) portion of this fold. */
        public final DatasetView<T> test;

        /**
         * Bundles a train view and a test view into a fold.
         *
         * @param train The training view.
         * @param test The held-out test view.
         */
        TrainTestFold(DatasetView<T> train, DatasetView<T> test) {
            this.train = train;
            this.test = test;
        }
    }
}

