/*
 * Decompiled with CFR 0.152.
 */
package org.nd4j.linalg.dataset;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.List;
import org.apache.commons.math3.util.Pair;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.INDArrayIndex;
import org.nd4j.linalg.indexing.NDArrayIndex;

/**
 * Default implementation of {@link org.nd4j.linalg.dataset.api.MultiDataSet}: holds any number of
 * feature arrays and label arrays, each of which may have an associated mask array.
 *
 * <p>Arrays passed to the constructors and setters are stored by reference, and the getters return
 * the internal arrays directly; no defensive copies are made.
 */
public class MultiDataSet
implements org.nd4j.linalg.dataset.api.MultiDataSet {
    private INDArray[] features;
    private INDArray[] labels;
    // Either mask array may be null (no masks at all); individual entries may also be null.
    private INDArray[] featuresMaskArrays;
    private INDArray[] labelsMaskArrays;

    /**
     * Creates a data set with exactly one feature array and one label array, and no mask arrays.
     *
     * @param features the single feature array
     * @param labels   the single label array
     */
    public MultiDataSet(INDArray features, INDArray labels) {
        this(new INDArray[]{features}, new INDArray[]{labels});
    }

    /**
     * Creates a data set with the given feature and label arrays and no mask arrays.
     *
     * @param features feature arrays (stored by reference)
     * @param labels   label arrays (stored by reference)
     */
    public MultiDataSet(INDArray[] features, INDArray[] labels) {
        this(features, labels, null, null);
    }

    /**
     * Creates a data set with the given feature, label and mask arrays.
     *
     * @param features           feature arrays (stored by reference)
     * @param labels             label arrays (stored by reference)
     * @param featuresMaskArrays mask arrays for the features, or null for no feature masks
     * @param labelsMaskArrays   mask arrays for the labels, or null for no label masks
     * @throws IllegalArgumentException if {@code features} and {@code featuresMaskArrays} are both
     *         non-null but of different lengths, or likewise for {@code labels} and {@code labelsMaskArrays}
     */
    public MultiDataSet(INDArray[] features, INDArray[] labels, INDArray[] featuresMaskArrays, INDArray[] labelsMaskArrays) {
        if (features != null && featuresMaskArrays != null && features.length != featuresMaskArrays.length) {
            throw new IllegalArgumentException("Invalid features / features mask arrays combination: features and features mask arrays must not be different lengths");
        }
        if (labels != null && labelsMaskArrays != null && labels.length != labelsMaskArrays.length) {
            throw new IllegalArgumentException("Invalid labels / labels mask arrays combination: labels and labels mask arrays must not be different lengths");
        }
        this.features = features;
        this.labels = labels;
        this.featuresMaskArrays = featuresMaskArrays;
        this.labelsMaskArrays = labelsMaskArrays;
    }

    /** Returns the number of feature arrays, or 0 if the features array is null. */
    @Override
    public int numFeatureArrays() {
        return this.features != null ? this.features.length : 0;
    }

    /** Returns the number of label arrays, or 0 if the labels array is null. */
    @Override
    public int numLabelsArrays() {
        return this.labels != null ? this.labels.length : 0;
    }

    /** Returns the internal feature arrays (not a copy); may be null. */
    @Override
    public INDArray[] getFeatures() {
        return this.features;
    }

    /** Returns the feature array at {@code index}. */
    @Override
    public INDArray getFeatures(int index) {
        return this.features[index];
    }

    /** Replaces all feature arrays (stored by reference). */
    @Override
    public void setFeatures(INDArray[] features) {
        this.features = features;
    }

    /** Replaces the feature array at {@code idx}. */
    @Override
    public void setFeatures(int idx, INDArray features) {
        this.features[idx] = features;
    }

    /** Returns the internal label arrays (not a copy); may be null. */
    @Override
    public INDArray[] getLabels() {
        return this.labels;
    }

    /** Returns the label array at {@code index}. */
    @Override
    public INDArray getLabels(int index) {
        return this.labels[index];
    }

    /** Replaces all label arrays (stored by reference). */
    @Override
    public void setLabels(INDArray[] labels) {
        this.labels = labels;
    }

    /** Replaces the label array at {@code idx}. */
    @Override
    public void setLabels(int idx, INDArray labels) {
        this.labels[idx] = labels;
    }

    /**
     * Returns true if at least one feature mask or label mask is present (non-null).
     * Mask containers that exist but hold only null entries do not count.
     */
    @Override
    public boolean hasMaskArrays() {
        return containsNonNull(this.featuresMaskArrays) || containsNonNull(this.labelsMaskArrays);
    }

    /** Returns true if {@code arrays} is non-null and contains at least one non-null element. */
    private static boolean containsNonNull(INDArray[] arrays) {
        if (arrays == null) {
            return false;
        }
        for (INDArray a : arrays) {
            if (a != null) {
                return true;
            }
        }
        return false;
    }

    /** Returns the internal feature mask arrays (not a copy); may be null. */
    @Override
    public INDArray[] getFeaturesMaskArrays() {
        return this.featuresMaskArrays;
    }

    /** Returns the feature mask at {@code index}, or null if there are no feature masks at all. */
    @Override
    public INDArray getFeaturesMaskArray(int index) {
        return this.featuresMaskArrays != null ? this.featuresMaskArrays[index] : null;
    }

    /** Replaces all feature mask arrays (stored by reference; may be null). */
    @Override
    public void setFeaturesMaskArrays(INDArray[] maskArrays) {
        this.featuresMaskArrays = maskArrays;
    }

    /**
     * Sets the feature mask at {@code idx}. If no feature mask arrays exist yet, one (initially
     * null) slot per feature array is allocated first.
     */
    @Override
    public void setFeaturesMaskArray(int idx, INDArray maskArray) {
        if (this.featuresMaskArrays == null) {
            // Allocate lazily instead of failing with a NullPointerException
            this.featuresMaskArrays = new INDArray[numFeatureArrays()];
        }
        this.featuresMaskArrays[idx] = maskArray;
    }

    /** Returns the internal label mask arrays (not a copy); may be null. */
    @Override
    public INDArray[] getLabelsMaskArrays() {
        return this.labelsMaskArrays;
    }

    /** Returns the label mask at {@code index}, or null if there are no label masks at all. */
    @Override
    public INDArray getLabelsMaskArray(int index) {
        return this.labelsMaskArrays != null ? this.labelsMaskArrays[index] : null;
    }

    /** Replaces all label mask arrays (stored by reference; may be null). */
    @Override
    public void setLabelsMaskArray(INDArray[] labelsMaskArrays) {
        this.labelsMaskArrays = labelsMaskArrays;
    }

    /**
     * Sets the label mask at {@code idx}. If no label mask arrays exist yet, one (initially null)
     * slot per label array is allocated first.
     */
    @Override
    public void setLabelsMaskArray(int idx, INDArray labelsMaskArray) {
        if (this.labelsMaskArrays == null) {
            // Allocate lazily instead of failing with a NullPointerException
            this.labelsMaskArrays = new INDArray[numLabelsArrays()];
        }
        this.labelsMaskArrays[idx] = labelsMaskArray;
    }

    /**
     * Merges several data sets into one. For each feature index and each label index, the arrays of
     * all data sets are concatenated along dimension 0.
     *
     * <p>Rank-2 and rank-4 arrays must agree in every dimension except 0. Rank-3 arrays must agree
     * in dimension 1. If their dimension-2 lengths differ, or any of them has a mask, the result is
     * allocated with the maximum length and a merged mask is built:
     * <ul>
     *   <li>existing masks are copied;</li>
     *   <li>missing masks default to ones;</li>
     *   <li>positions beyond an array's length (or its mask's length) are set to zero.</li>
     * </ul>
     *
     * <p>If {@code toMerge} has exactly one element that is already a {@code MultiDataSet}, that
     * same instance is returned (not a copy).
     *
     * @param toMerge data sets to merge; must be non-null and non-empty
     * @return the merged data set. Its mask arrays are null when no merged array needed a mask.
     * @throws IllegalArgumentException      if {@code toMerge} is null or empty
     * @throws IllegalStateException         if the data sets have different numbers of
     *                                       input/output arrays, or incompatible shapes
     * @throws UnsupportedOperationException if an array has a rank other than 2, 3 or 4
     */
    public static MultiDataSet merge(Collection<? extends org.nd4j.linalg.dataset.api.MultiDataSet> toMerge) {
        if (toMerge == null || toMerge.isEmpty()) {
            throw new IllegalArgumentException("Cannot merge null or empty collection of MultiDataSets");
        }
        if (toMerge.size() == 1) {
            org.nd4j.linalg.dataset.api.MultiDataSet mds = toMerge.iterator().next();
            if (mds instanceof MultiDataSet) {
                return (MultiDataSet)mds;
            }
            return new MultiDataSet(mds.getFeatures(), mds.getLabels(), mds.getFeaturesMaskArrays(), mds.getLabelsMaskArrays());
        }
        // Copy rather than cast: an arbitrary List (LinkedList, Arrays.asList, ...) is not an ArrayList
        List<org.nd4j.linalg.dataset.api.MultiDataSet> list = new ArrayList<org.nd4j.linalg.dataset.api.MultiDataSet>(toMerge);
        int nDataSets = list.size();
        int nInArrays = list.get(0).numFeatureArrays();
        int nOutArrays = list.get(0).numLabelsArrays();

        // All four are indexed [dataSet][inputOrOutputIndex]
        INDArray[][] features = new INDArray[nDataSets][];
        INDArray[][] labels = new INDArray[nDataSets][];
        INDArray[][] featuresMasks = new INDArray[nDataSets][];
        INDArray[][] labelsMasks = new INDArray[nDataSets][];
        for (int i = 0; i < nDataSets; i++) {
            org.nd4j.linalg.dataset.api.MultiDataSet mds = list.get(i);
            features[i] = mds.getFeatures();
            labels[i] = mds.getLabels();
            featuresMasks[i] = mds.getFeaturesMaskArrays();
            labelsMasks[i] = mds.getLabelsMaskArrays();
            if (features[i] == null || features[i].length != nInArrays) {
                throw new IllegalStateException("Cannot merge MultiDataSets with different number of input arrays: toMerge[0] has " + nInArrays + " input arrays; toMerge[" + i + "] has " + (features[i] != null ? String.valueOf(features[i].length) : "null") + " arrays");
            }
            if (labels[i] == null || labels[i].length != nOutArrays) {
                throw new IllegalStateException("Cannot merge MultiDataSets with different number of output arrays: toMerge[0] has " + nOutArrays + " output arrays; toMerge[" + i + "] has " + (labels[i] != null ? String.valueOf(labels[i].length) : "null") + " arrays");
            }
        }

        INDArray[] mergedFeatures = new INDArray[nInArrays];
        INDArray[] mergedFeaturesMasks = new INDArray[nInArrays];
        boolean needFeaturesMasks = false;
        for (int j = 0; j < nInArrays; j++) {
            Pair<INDArray, INDArray> pair = merge(features, featuresMasks, j);
            mergedFeatures[j] = pair.getFirst();
            mergedFeaturesMasks[j] = pair.getSecond();
            needFeaturesMasks |= pair.getSecond() != null;
        }

        INDArray[] mergedLabels = new INDArray[nOutArrays];
        INDArray[] mergedLabelsMasks = new INDArray[nOutArrays];
        boolean needLabelsMasks = false;
        for (int j = 0; j < nOutArrays; j++) {
            Pair<INDArray, INDArray> pair = merge(labels, labelsMasks, j);
            mergedLabels[j] = pair.getFirst();
            mergedLabelsMasks[j] = pair.getSecond();
            needLabelsMasks |= pair.getSecond() != null;
        }

        return new MultiDataSet(mergedFeatures, mergedLabels,
                needFeaturesMasks ? mergedFeaturesMasks : null,
                needLabelsMasks ? mergedLabelsMasks : null);
    }

    /**
     * Merges column {@code column} of {@code arrays} (indexed [dataSet][column]) across all data
     * sets, dispatching on the rank of the first data set's array for that column.
     *
     * @return pair of (merged array, merged mask); the mask is null when none is needed
     */
    private static Pair<INDArray, INDArray> merge(INDArray[][] arrays, INDArray[][] masks, int column) {
        // Index order is [dataSet][column]: look at data set 0's array for this column
        int rank = arrays[0][column].rank();
        switch (rank) {
            case 2:
                return new Pair<INDArray, INDArray>(merge2d(arrays, column), null);
            case 3:
                return mergeTimeSeries(arrays, masks, column);
            case 4:
                return new Pair<INDArray, INDArray>(merge4d(arrays, column), null);
            default:
                throw new UnsupportedOperationException("Cannot merge arrays with rank " + rank + ": only rank 2, 3 and 4 are supported (input/output number: " + column + ")");
        }
    }

    /**
     * Returns {@code masks[dataSet][inOutIdx]}. Returns null if there is no mask container for that
     * data set, or if the container is too short to hold that index.
     */
    private static INDArray maskFor(INDArray[][] masks, int dataSet, int inOutIdx) {
        if (masks == null || masks[dataSet] == null || masks[dataSet].length <= inOutIdx) {
            return null;
        }
        return masks[dataSet][inOutIdx];
    }

    /** Stacks rank-2 arrays row-wise; all must have the same number of columns. */
    private static INDArray merge2d(INDArray[][] arrays, int inOutIdx) {
        int cols = arrays[0][inOutIdx].columns();
        int nExamples = 0;
        for (int i = 0; i < arrays.length; ++i) {
            INDArray a = arrays[i][inOutIdx];
            if (a.rank() != 2) {
                throw new IllegalStateException("Cannot merge 2d arrays with non 2d arrays: data[" + i + "][" + inOutIdx + "].shape = " + Arrays.toString(a.shape()));
            }
            if (a.columns() != cols) {
                throw new IllegalStateException("Cannot merge 2d arrays with different numbers of columns (firstNCols=" + cols + ", ithNCols=" + a.columns() + ")");
            }
            nExamples += a.rows();
        }
        INDArray out = Nd4j.create(nExamples, cols);
        int rowsSoFar = 0;
        for (int i = 0; i < arrays.length; ++i) {
            int thisRows = arrays[i][inOutIdx].rows();
            out.put(new INDArrayIndex[]{NDArrayIndex.interval(rowsSoFar, rowsSoFar + thisRows), NDArrayIndex.all()}, arrays[i][inOutIdx]);
            rowsSoFar += thisRows;
        }
        return out;
    }

    /**
     * Concatenates rank-3 arrays along dimension 0; all must agree in dimension 1. Dimension 2 of
     * the result is the maximum dimension-2 length. A rank-2 mask of shape
     * [totalExamples, maxLength] is returned only if any input had a mask or the lengths differ.
     * Otherwise the mask is null.
     */
    private static Pair<INDArray, INDArray> mergeTimeSeries(INDArray[][] arrays, INDArray[][] masks, int inOutIdx) {
        INDArray first = arrays[0][inOutIdx];
        int firstLength = first.size(2);
        int size = first.size(1);
        int maxLength = firstLength;
        boolean hasMask = false;
        boolean lengthsDiffer = false;
        int totalExamples = 0;
        for (int i = 0; i < arrays.length; ++i) {
            INDArray a = arrays[i][inOutIdx];
            if (a.rank() != 3) {
                throw new IllegalStateException("Cannot merge time series with non 3d arrays: data[" + i + "][" + inOutIdx + "].shape = " + Arrays.toString(a.shape()));
            }
            if (a.size(1) != size) {
                throw new IllegalStateException("Cannot merge time series with different size for dimension 1 (first shape: " + Arrays.toString(first.shape()) + ", " + i + "th shape: " + Arrays.toString(a.shape()) + ")");
            }
            totalExamples += a.size(0);
            int thisLength = a.size(2);
            maxLength = Math.max(maxLength, thisLength);
            lengthsDiffer |= thisLength != firstLength;
            hasMask |= maskFor(masks, i, inOutIdx) != null;
        }

        INDArray arr = Nd4j.create(totalExamples, size, maxLength);
        if (!hasMask && !lengthsDiffer) {
            // Fast path: equal lengths and no masks, so no padding or mask is required
            int examplesSoFar = 0;
            for (int i = 0; i < arrays.length; ++i) {
                int thisNExamples = arrays[i][inOutIdx].size(0);
                arr.put(new INDArrayIndex[]{NDArrayIndex.interval(examplesSoFar, examplesSoFar + thisNExamples), NDArrayIndex.all(), NDArrayIndex.all()}, arrays[i][inOutIdx]);
                examplesSoFar += thisNExamples;
            }
            return new Pair<INDArray, INDArray>(arr, null);
        }

        // Mask starts as all ones; padded positions of each example are zeroed below
        INDArray mask = Nd4j.ones(totalExamples, maxLength);
        int examplesSoFar = 0;
        for (int i = 0; i < arrays.length; ++i) {
            INDArray a = arrays[i][inOutIdx];
            int thisNExamples = a.size(0);
            int thisLength = a.size(2);
            arr.put(new INDArrayIndex[]{NDArrayIndex.interval(examplesSoFar, examplesSoFar + thisNExamples), NDArrayIndex.all(), NDArrayIndex.interval(0, thisLength)}, a);

            INDArray origMask = maskFor(masks, i, inOutIdx);
            int validLength;
            if (origMask != null) {
                validLength = origMask.size(1);
                mask.put(new INDArrayIndex[]{NDArrayIndex.interval(examplesSoFar, examplesSoFar + thisNExamples), NDArrayIndex.interval(0, validLength)}, origMask);
            } else {
                validLength = thisLength;
            }
            if (validLength < maxLength) {
                mask.put(new INDArrayIndex[]{NDArrayIndex.interval(examplesSoFar, examplesSoFar + thisNExamples), NDArrayIndex.interval(validLength, maxLength)}, Nd4j.zeros(thisNExamples, maxLength - validLength));
            }
            examplesSoFar += thisNExamples;
        }
        return new Pair<INDArray, INDArray>(arr, mask);
    }

    /** Concatenates rank-4 arrays along dimension 0; dimensions 1-3 must match exactly. */
    private static INDArray merge4d(INDArray[][] arrays, int inOutIdx) {
        int[] shape = arrays[0][inOutIdx].shape();
        int nExamples = 0;
        for (int i = 0; i < arrays.length; ++i) {
            int[] thisShape = arrays[i][inOutIdx].shape();
            if (thisShape.length != 4) {
                throw new IllegalStateException("Cannot merge 4d arrays with non 4d arrays");
            }
            for (int j = 1; j < 4; ++j) {
                if (thisShape[j] != shape[j]) {
                    throw new IllegalStateException("Cannot merge 4d arrays with different shape (other than # examples):  data[0][" + inOutIdx + "].shape = " + Arrays.toString(shape) + ", data[" + i + "][" + inOutIdx + "].shape = " + Arrays.toString(thisShape));
                }
            }
            nExamples += thisShape[0];
        }
        INDArray out = Nd4j.create(nExamples, shape[1], shape[2], shape[3]);
        int rowsSoFar = 0;
        for (int i = 0; i < arrays.length; ++i) {
            int thisRows = arrays[i][inOutIdx].size(0);
            out.put(new INDArrayIndex[]{NDArrayIndex.interval(rowsSoFar, rowsSoFar + thisRows), NDArrayIndex.all(), NDArrayIndex.all(), NDArrayIndex.all()}, arrays[i][inOutIdx]);
            rowsSoFar += thisRows;
        }
        return out;
    }
}

