/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.spark.impl.multilayer;

import java.io.Serializable;
import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
import java.util.concurrent.atomic.AtomicInteger;
import lombok.NonNull;
import org.apache.spark.Accumulator;
import org.apache.spark.SparkContext;
import org.apache.spark.api.java.JavaDoubleRDD;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.function.DoubleFlatMapFunction;
import org.apache.spark.api.java.function.DoubleFunction;
import org.apache.spark.api.java.function.FlatMapFunction;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.api.java.function.Function2;
import org.apache.spark.api.java.function.VoidFunction;
import org.apache.spark.broadcast.Broadcast;
import org.apache.spark.mllib.linalg.Matrix;
import org.apache.spark.mllib.linalg.Vector;
import org.apache.spark.mllib.regression.LabeledPoint;
import org.canova.api.records.reader.RecordReader;
import org.deeplearning4j.eval.Evaluation;
import org.deeplearning4j.nn.api.Model;
import org.deeplearning4j.nn.api.Updater;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.layers.FeedForwardLayer;
import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.updater.UpdaterCreator;
import org.deeplearning4j.nn.updater.aggregate.UpdaterAggregator;
import org.deeplearning4j.optimize.api.IterationListener;
import org.deeplearning4j.spark.canova.RecordReaderFunction;
import org.deeplearning4j.spark.impl.common.Adder;
import org.deeplearning4j.spark.impl.common.BestScoreAccumulator;
import org.deeplearning4j.spark.impl.common.gradient.GradientAdder;
import org.deeplearning4j.spark.impl.common.misc.GradientFromTupleFunction;
import org.deeplearning4j.spark.impl.common.misc.INDArrayFromTupleFunction;
import org.deeplearning4j.spark.impl.common.misc.ScoreReport;
import org.deeplearning4j.spark.impl.common.misc.UpdaterFromGradientTupleFunction;
import org.deeplearning4j.spark.impl.common.misc.UpdaterFromTupleFunction;
import org.deeplearning4j.spark.impl.common.updater.UpdaterAggregatorCombiner;
import org.deeplearning4j.spark.impl.common.updater.UpdaterElementCombiner;
import org.deeplearning4j.spark.impl.multilayer.IterativeReduceFlatMap;
import org.deeplearning4j.spark.impl.multilayer.ScoreFlatMapFunction;
import org.deeplearning4j.spark.impl.multilayer.evaluation.EvaluateFlatMapFunction;
import org.deeplearning4j.spark.impl.multilayer.evaluation.EvaluationReduceFunction;
import org.deeplearning4j.spark.impl.multilayer.gradientaccum.GradientAccumFlatMap;
import org.deeplearning4j.spark.impl.multilayer.scoring.ScoreExamplesFunction;
import org.deeplearning4j.spark.impl.multilayer.scoring.ScoreExamplesWithKeyFunction;
import org.deeplearning4j.spark.util.MLLibUtil;
import org.deeplearning4j.util.ModelSerializer;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.heartbeat.Heartbeat;
import org.nd4j.linalg.heartbeat.reports.Environment;
import org.nd4j.linalg.heartbeat.reports.Event;
import org.nd4j.linalg.heartbeat.reports.Task;
import org.nd4j.linalg.heartbeat.utils.EnvironmentUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import scala.Tuple3;

public class SparkDl4jMultiLayer
implements Serializable {
    public static final int DEFAULT_EVAL_SCORE_BATCH_SIZE = 50;
    private transient SparkContext sparkContext;
    private transient JavaSparkContext sc;
    private MultiLayerConfiguration conf;
    private MultiLayerNetwork network;
    private Broadcast<INDArray> params;
    private Broadcast<Updater> updater;
    private boolean averageEachIteration = false;
    public static final String AVERAGE_EACH_ITERATION = "org.deeplearning4j.spark.iteration.average";
    public static final String ACCUM_GRADIENT = "org.deeplearning4j.spark.iteration.accumgrad";
    public static final String DIVIDE_ACCUM_GRADIENT = "org.deeplearning4j.spark.iteration.dividegrad";
    private Accumulator<Double> bestScoreAcc = null;
    private double lastScore;
    private transient boolean initDone = false;
    private transient AtomicInteger iterationsCount = new AtomicInteger(0);
    private List<IterationListener> listeners = new ArrayList<IterationListener>();
    private static final Logger log = LoggerFactory.getLogger(SparkDl4jMultiLayer.class);

    public SparkDl4jMultiLayer(SparkContext sparkContext, MultiLayerNetwork network) {
        this(new JavaSparkContext(sparkContext), network);
    }

    public SparkDl4jMultiLayer(JavaSparkContext javaSparkContext, MultiLayerNetwork network) {
        this.sparkContext = javaSparkContext.sc();
        this.sc = javaSparkContext;
        this.conf = network.getLayerWiseConfigurations().clone();
        this.network = network;
        this.network.init();
        this.updater = this.sc.broadcast((Object)network.getUpdater());
        this.averageEachIteration = this.sparkContext.conf().getBoolean(AVERAGE_EACH_ITERATION, false);
        this.bestScoreAcc = BestScoreAccumulator.create(this.sparkContext);
    }

    public SparkDl4jMultiLayer(SparkContext sparkContext, MultiLayerConfiguration conf) {
        this.sparkContext = sparkContext;
        this.sc = new JavaSparkContext(this.sparkContext);
        this.conf = conf.clone();
        this.network = new MultiLayerNetwork(conf);
        this.network.init();
        this.averageEachIteration = sparkContext.conf().getBoolean(AVERAGE_EACH_ITERATION, false);
        this.bestScoreAcc = BestScoreAccumulator.create(sparkContext);
        this.updater = this.sc.broadcast((Object)this.network.getUpdater());
    }

    public SparkDl4jMultiLayer(JavaSparkContext sc, MultiLayerConfiguration conf) {
        this(sc.sc(), conf);
    }

    public MultiLayerNetwork fit(String path, int labelIndex, RecordReader recordReader) {
        JavaRDD<DataSet> points = this.loadFromTextFile(path, labelIndex, recordReader);
        return this.fitDataSet(points);
    }

    public MultiLayerNetwork fit(String path, int labelIndex, RecordReader recordReader, int examplesPerFit, int numPartitions) {
        JavaRDD<DataSet> points = this.loadFromTextFile(path, labelIndex, recordReader);
        points.cache();
        int count = (int)points.count();
        return this.fitDataSet(points, examplesPerFit, count, numPartitions);
    }

    public MultiLayerNetwork fit(String path, int labelIndex, RecordReader recordReader, int examplesPerFit, int totalExamples, int numPartitions) {
        JavaRDD<DataSet> points = this.loadFromTextFile(path, labelIndex, recordReader);
        return this.fitDataSet(points, examplesPerFit, totalExamples, numPartitions);
    }

    private JavaRDD<DataSet> loadFromTextFile(String path, int labelIndex, RecordReader recordReader) {
        JavaRDD lines = this.sc.textFile(path);
        FeedForwardLayer outputLayer = (FeedForwardLayer)this.conf.getConf(this.conf.getConfs().size() - 1).getLayer();
        return lines.map((Function)new RecordReaderFunction(recordReader, labelIndex, outputLayer.getNOut()));
    }

    public MultiLayerNetwork getNetwork() {
        return this.network;
    }

    public void setNetwork(MultiLayerNetwork network) {
        this.network = network;
    }

    public Matrix predict(Matrix features) {
        return MLLibUtil.toMatrix(this.network.output(MLLibUtil.toMatrix(features)));
    }

    public Vector predict(Vector point) {
        return MLLibUtil.toVector(this.network.output(MLLibUtil.toVector(point)));
    }

    public MultiLayerNetwork fit(JavaRDD<LabeledPoint> rdd, int batchSize) {
        FeedForwardLayer outputLayer = (FeedForwardLayer)this.conf.getConf(this.conf.getConfs().size() - 1).getLayer();
        return this.fitDataSet(MLLibUtil.fromLabeledPoint(rdd, outputLayer.getNOut(), batchSize));
    }

    public MultiLayerNetwork fit(JavaSparkContext sc, JavaRDD<LabeledPoint> rdd) {
        FeedForwardLayer outputLayer = (FeedForwardLayer)this.conf.getConf(this.conf.getConfs().size() - 1).getLayer();
        return this.fitDataSet(MLLibUtil.fromLabeledPoint(sc, rdd, outputLayer.getNOut()));
    }

    public MultiLayerNetwork fitDataSet(JavaRDD<DataSet> rdd, int examplesPerFit, int numPartitions) {
        rdd.cache();
        int count = (int)rdd.count();
        return this.fitDataSet(rdd, examplesPerFit, count, numPartitions);
    }

    public MultiLayerNetwork fitDataSet(JavaRDD<DataSet> rdd, int examplesPerFit, int totalExamples, int numPartitions) {
        int nSplits = examplesPerFit == Integer.MAX_VALUE || examplesPerFit >= totalExamples ? 1 : (totalExamples % examplesPerFit == 0 ? totalExamples / examplesPerFit : totalExamples / examplesPerFit + 1);
        if (nSplits == 1) {
            this.fitDataSet(rdd);
        } else {
            double[] splitWeights = new double[nSplits];
            for (int i = 0; i < nSplits; ++i) {
                splitWeights[i] = 1.0 / (double)nSplits;
            }
            JavaRDD[] subsets = rdd.randomSplit(splitWeights);
            for (int i = 0; i < subsets.length; ++i) {
                log.info("Initiating distributed training of subset {} of {}", (Object)(i + 1), (Object)subsets.length);
                JavaRDD next = subsets[i].repartition(numPartitions);
                this.fitDataSet((JavaRDD<DataSet>)next);
            }
        }
        return this.network;
    }

    public MultiLayerNetwork fitDataSet(JavaRDD<DataSet> rdd) {
        int iterations = this.conf.getConf(0).getNumIterations();
        log.info("Running distributed training:  (averaging each iteration = " + this.averageEachIteration + "), (iterations = " + iterations + "), (num partions = " + rdd.partitions().size() + ")");
        if (!this.averageEachIteration) {
            this.runIteration(rdd);
        } else {
            for (NeuralNetConfiguration conf : this.conf.getConfs()) {
                conf.setNumIterations(1);
            }
            for (int i = 0; i < iterations; ++i) {
                this.runIteration(rdd);
            }
            if (iterations > 1) {
                for (NeuralNetConfiguration conf : this.conf.getConfs()) {
                    conf.setNumIterations(iterations);
                }
            }
        }
        return this.network;
    }

    protected void runIteration(JavaRDD<DataSet> rdd) {
        int maxRep = 0;
        long maxSm = 0L;
        int paramsLength = this.network.numParams(false);
        log.info("Broadcasting initial parameters of length " + paramsLength);
        INDArray valToBroadcast = this.network.params(false);
        this.params = this.sc.broadcast((Object)valToBroadcast);
        Updater updater = this.network.getUpdater();
        if (updater == null) {
            this.network.setUpdater(UpdaterCreator.getUpdater((Model)this.network));
            log.warn("Unable to propagate null updater");
            updater = this.network.getUpdater();
        }
        this.updater = this.sc.broadcast((Object)updater);
        boolean accumGrad = this.sc.getConf().getBoolean(ACCUM_GRADIENT, false);
        if (accumGrad) {
            JavaRDD results = rdd.mapPartitions((FlatMapFunction)new GradientAccumFlatMap(this.conf.toJson(), this.params, this.updater), true).cache();
            JavaRDD resultsGradient = results.map((Function)new GradientFromTupleFunction());
            log.info("Ran iterative reduce... averaging results now.");
            GradientAdder a = new GradientAdder(paramsLength);
            resultsGradient.foreach((VoidFunction)a);
            INDArray accumulatedGradient = (INDArray)a.getAccumulator().value();
            boolean divideGrad = this.sc.getConf().getBoolean(DIVIDE_ACCUM_GRADIENT, false);
            if (divideGrad) {
                maxRep = results.partitions().size();
                accumulatedGradient.divi((Number)maxRep);
            }
            log.info("Accumulated parameters");
            log.info("Summed gradients.");
            this.network.setParameters(this.network.params(false).addi(accumulatedGradient));
            log.info("Set parameters");
            JavaDoubleRDD scores = results.mapToDouble((DoubleFunction)new ScoreMappingG());
            this.lastScore = scores.mean();
            if (!this.initDone) {
                JavaDoubleRDD sm = results.mapToDouble((DoubleFunction)new SMappingG());
                maxSm = sm.mean().longValue();
            }
            log.info("Processing updaters");
            JavaRDD resultsUpdater = results.map((Function)new UpdaterFromGradientTupleFunction());
            UpdaterAggregator aggregator = (UpdaterAggregator)resultsUpdater.aggregate((Object)((Updater)resultsUpdater.first()).getAggregator(false), (Function2)new UpdaterElementCombiner(), (Function2)new UpdaterAggregatorCombiner());
            Updater combinedUpdater = aggregator.getUpdater();
            this.network.setUpdater(combinedUpdater);
            log.info("Set updater");
        } else {
            JavaRDD results = rdd.mapPartitions((FlatMapFunction)new IterativeReduceFlatMap(this.conf.toJson(), this.params, this.updater, this.bestScoreAcc), true).cache();
            JavaRDD resultsParams = results.map((Function)new INDArrayFromTupleFunction());
            log.info("Running iterative reduce and averaging parameters");
            Adder a = new Adder(paramsLength, (Accumulator<Integer>)this.sc.accumulator(0));
            resultsParams.foreach((VoidFunction)a);
            INDArray newParams = (INDArray)a.getAccumulator().value();
            maxRep = (Integer)a.getCounter().value();
            newParams.divi((Number)maxRep);
            this.network.setParameters(newParams);
            log.info("Accumulated and set parameters");
            JavaDoubleRDD scores = results.mapToDouble((DoubleFunction)new ScoreMapping());
            this.lastScore = scores.mean();
            if (!this.initDone) {
                JavaDoubleRDD sm = results.mapToDouble((DoubleFunction)new SMapping());
                maxSm = sm.mean().longValue();
            }
            JavaRDD resultsUpdater = results.map((Function)new UpdaterFromTupleFunction());
            UpdaterAggregator aggregator = (UpdaterAggregator)resultsUpdater.aggregate(null, (Function2)new UpdaterElementCombiner(), (Function2)new UpdaterAggregatorCombiner());
            Updater combinedUpdater = aggregator.getUpdater();
            this.network.setUpdater(combinedUpdater);
            log.info("Processed and set updater");
        }
        if (this.listeners.size() > 0) {
            log.debug("Invoking IterationListeners");
            this.network.setScore(this.lastScore);
            this.invokeListeners(this.network, this.iterationsCount.incrementAndGet());
        }
        if (!this.initDone) {
            this.initDone = true;
            this.update(maxRep, maxSm);
        }
    }

    public static MultiLayerNetwork train(JavaRDD<LabeledPoint> data, MultiLayerConfiguration conf) {
        SparkDl4jMultiLayer multiLayer = new SparkDl4jMultiLayer(data.context(), conf);
        return multiLayer.fit(new JavaSparkContext(data.context()), data);
    }

    public void setListeners(@NonNull Collection<IterationListener> listeners) {
        if (listeners == null) {
            throw new NullPointerException("listeners");
        }
        this.listeners.clear();
        this.listeners.addAll(listeners);
    }

    protected void invokeListeners(MultiLayerNetwork network, int iteration) {
        for (IterationListener listener : this.listeners) {
            try {
                listener.iterationDone((Model)network, iteration);
            }
            catch (Exception e) {
                log.error("Exception caught at IterationListener invocation" + e.getMessage());
                e.printStackTrace();
            }
        }
    }

    public double getScore() {
        return this.lastScore;
    }

    public double calculateScore(JavaRDD<DataSet> data, boolean average) {
        long n = data.count();
        JavaRDD scores = data.mapPartitions((FlatMapFunction)new ScoreFlatMapFunction(this.conf.toJson(), (Broadcast<INDArray>)this.sc.broadcast((Object)this.network.params(false))));
        List scoresList = scores.collect();
        double sum = 0.0;
        for (Double d : scoresList) {
            sum += d.doubleValue();
        }
        if (average) {
            return sum / (double)n;
        }
        return sum;
    }

    public JavaDoubleRDD scoreExamples(JavaRDD<DataSet> data, boolean includeRegularizationTerms) {
        return this.scoreExamples(data, includeRegularizationTerms, 50);
    }

    public JavaDoubleRDD scoreExamples(JavaRDD<DataSet> data, boolean includeRegularizationTerms, int batchSize) {
        return data.mapPartitionsToDouble((DoubleFlatMapFunction)new ScoreExamplesFunction((Broadcast<INDArray>)this.sc.broadcast((Object)this.network.params()), (Broadcast<String>)this.sc.broadcast((Object)this.conf.toJson()), includeRegularizationTerms, batchSize));
    }

    public <K> JavaPairRDD<K, Double> scoreExamples(JavaPairRDD<K, DataSet> data, boolean includeRegularizationTerms) {
        return this.scoreExamples(data, includeRegularizationTerms, 50);
    }

    public <K> JavaPairRDD<K, Double> scoreExamples(JavaPairRDD<K, DataSet> data, boolean includeRegularizationTerms, int batchSize) {
        return data.mapPartitionsToPair(new ScoreExamplesWithKeyFunction((Broadcast<INDArray>)this.sc.broadcast((Object)this.network.params()), (Broadcast<String>)this.sc.broadcast((Object)this.conf.toJson()), includeRegularizationTerms, batchSize));
    }

    public Evaluation evaluate(JavaRDD<DataSet> data) {
        return this.evaluate(data, null);
    }

    public Evaluation evaluate(JavaRDD<DataSet> data, List<String> labelsList) {
        return this.evaluate(data, labelsList, 50);
    }

    private void update(int mr, long mg) {
        Environment env = EnvironmentUtils.buildEnvironment();
        env.setNumCores(mr);
        env.setAvailableMemory(mg);
        Task task = ModelSerializer.taskByModel((Model)this.network);
        Heartbeat.getInstance().reportEvent(Event.SPARK, env, task);
    }

    public Evaluation evaluate(JavaRDD<DataSet> data, List<String> labelsList, int evalBatchSize) {
        Broadcast listBroadcast = labelsList == null ? null : this.sc.broadcast(labelsList);
        JavaRDD evaluations = data.mapPartitions((FlatMapFunction)new EvaluateFlatMapFunction((Broadcast<String>)this.sc.broadcast((Object)this.conf.toJson()), (Broadcast<INDArray>)this.sc.broadcast((Object)this.network.params()), evalBatchSize, (Broadcast<List<String>>)listBroadcast));
        return (Evaluation)evaluations.reduce((Function2)new EvaluationReduceFunction());
    }

    private static class SMappingG
    implements DoubleFunction<Tuple3<Gradient, Updater, ScoreReport>> {
        private SMappingG() {
        }

        public double call(Tuple3<Gradient, Updater, ScoreReport> t3) throws Exception {
            return ((ScoreReport)t3._3()).getM();
        }
    }

    private static class SMapping
    implements DoubleFunction<Tuple3<INDArray, Updater, ScoreReport>> {
        private SMapping() {
        }

        public double call(Tuple3<INDArray, Updater, ScoreReport> t3) throws Exception {
            return ((ScoreReport)t3._3()).getM();
        }
    }

    private static class ScoreMappingG
    implements DoubleFunction<Tuple3<Gradient, Updater, ScoreReport>> {
        private ScoreMappingG() {
        }

        public double call(Tuple3<Gradient, Updater, ScoreReport> t3) throws Exception {
            return ((ScoreReport)t3._3()).getS();
        }
    }

    private static class ScoreMapping
    implements DoubleFunction<Tuple3<INDArray, Updater, ScoreReport>> {
        private ScoreMapping() {
        }

        public double call(Tuple3<INDArray, Updater, ScoreReport> t3) throws Exception {
            return ((ScoreReport)t3._3()).getS();
        }
    }
}

