/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.spark.canova;

import java.io.Serializable;
import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
import org.apache.spark.api.java.function.Function;
import org.canova.api.io.WritableConverter;
import org.canova.api.io.converters.WritableConverterException;
import org.canova.api.writable.Writable;
import org.canova.common.data.NDArrayWritable;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.api.DataSet;
import org.nd4j.linalg.dataset.api.DataSetPreProcessor;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.util.FeatureUtil;

public class CanovaDataSetFunction
implements Function<Collection<Writable>, org.nd4j.linalg.dataset.DataSet>,
Serializable {
    private final int labelIndex;
    private final int numPossibleLabels;
    private final boolean regression;
    private final DataSetPreProcessor preProcessor;
    private final WritableConverter converter;
    protected int batchSize = -1;

    public CanovaDataSetFunction(int labelIndex, int numPossibleLabels, boolean regression) {
        this(labelIndex, numPossibleLabels, regression, null, null);
    }

    public CanovaDataSetFunction(int labelIndex, int numPossibleLabels, boolean regression, DataSetPreProcessor preProcessor, WritableConverter converter) {
        this.labelIndex = labelIndex;
        this.numPossibleLabels = numPossibleLabels;
        this.regression = regression;
        this.preProcessor = preProcessor;
        this.converter = converter;
    }

    public org.nd4j.linalg.dataset.DataSet call(Collection<Writable> writables) throws Exception {
        ArrayList<Writable> list = writables instanceof List ? (ArrayList<Writable>)writables : new ArrayList<Writable>(writables);
        int labelIndex = this.labelIndex;
        if (this.numPossibleLabels >= 1 && labelIndex < 0) {
            labelIndex = list.size() - 1;
        }
        INDArray label = null;
        INDArray featureVector = null;
        int featureCount = 0;
        for (int j = 0; j < list.size(); ++j) {
            Writable current = (Writable)list.get(j);
            if (this.converter != null) {
                current = this.converter.convert(current);
            }
            if (labelIndex >= 0 && j == labelIndex) {
                if (this.converter != null) {
                    try {
                        current = this.converter.convert(current);
                    }
                    catch (WritableConverterException e) {
                        e.printStackTrace();
                    }
                }
                if (this.numPossibleLabels < 1) {
                    throw new IllegalStateException("Number of possible labels invalid, must be >= 1");
                }
                if (this.regression) {
                    label = Nd4j.scalar((double)current.toDouble());
                    continue;
                }
                int curr = current.toInt();
                if (curr >= this.numPossibleLabels) {
                    throw new IllegalStateException("Invalid input: class label is " + curr + " with numPossibleLables = " + this.numPossibleLabels + " (class label must be 0 <= labelIdx < numPossibleLabels)");
                }
                label = FeatureUtil.toOutcomeVector((int)curr, (int)this.numPossibleLabels);
                continue;
            }
            try {
                double value = current.toDouble();
                if (featureVector == null) {
                    featureVector = Nd4j.create((int)(labelIndex >= 0 ? list.size() - 1 : list.size()));
                }
                featureVector.putScalar(featureCount++, value);
                continue;
            }
            catch (UnsupportedOperationException e) {
                if (current instanceof NDArrayWritable) {
                    assert (featureVector == null);
                    featureVector = ((NDArrayWritable)current).get();
                    continue;
                }
                throw e;
            }
        }
        org.nd4j.linalg.dataset.DataSet ds = new org.nd4j.linalg.dataset.DataSet(featureVector, labelIndex >= 0 ? label : featureVector);
        if (this.preProcessor != null) {
            this.preProcessor.preProcess((DataSet)ds);
        }
        return ds;
    }
}

