/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.spark.impl.computationgraph;

import java.io.Serializable;
import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
import java.util.concurrent.atomic.AtomicInteger;
import lombok.NonNull;
import org.apache.spark.Accumulator;
import org.apache.spark.SparkContext;
import org.apache.spark.api.java.JavaDoubleRDD;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.function.DoubleFlatMapFunction;
import org.apache.spark.api.java.function.DoubleFunction;
import org.apache.spark.api.java.function.FlatMapFunction;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.api.java.function.Function2;
import org.apache.spark.api.java.function.VoidFunction;
import org.apache.spark.broadcast.Broadcast;
import org.canova.api.records.reader.RecordReader;
import org.deeplearning4j.nn.api.Model;
import org.deeplearning4j.nn.conf.ComputationGraphConfiguration;
import org.deeplearning4j.nn.conf.graph.GraphVertex;
import org.deeplearning4j.nn.conf.graph.LayerVertex;
import org.deeplearning4j.nn.conf.layers.FeedForwardLayer;
import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.updater.graph.ComputationGraphUpdater;
import org.deeplearning4j.optimize.api.IterationListener;
import org.deeplearning4j.spark.canova.RecordReaderFunction;
import org.deeplearning4j.spark.impl.common.Adder;
import org.deeplearning4j.spark.impl.common.gradient.GradientAdder;
import org.deeplearning4j.spark.impl.common.misc.GradientFromTupleFunctionCG;
import org.deeplearning4j.spark.impl.common.misc.INDArrayFromTupleFunctionCG;
import org.deeplearning4j.spark.impl.common.misc.ScoreReport;
import org.deeplearning4j.spark.impl.common.misc.UpdaterFromGradientTupleFunctionCG;
import org.deeplearning4j.spark.impl.common.misc.UpdaterFromTupleFunctionCG;
import org.deeplearning4j.spark.impl.common.updater.UpdaterAggregatorCombinerCG;
import org.deeplearning4j.spark.impl.common.updater.UpdaterElementCombinerCG;
import org.deeplearning4j.spark.impl.computationgraph.IterativeReduceFlatMapCG;
import org.deeplearning4j.spark.impl.computationgraph.ScoreFlatMapFunctionCGDataSet;
import org.deeplearning4j.spark.impl.computationgraph.ScoreFlatMapFunctionCGMultiDataSet;
import org.deeplearning4j.spark.impl.computationgraph.dataset.DataSetToMultiDataSetFn;
import org.deeplearning4j.spark.impl.computationgraph.dataset.PairDataSetToMultiDataSetFn;
import org.deeplearning4j.spark.impl.computationgraph.gradientaccum.GradientAccumFlatMapCG;
import org.deeplearning4j.spark.impl.computationgraph.scoring.ScoreExamplesFunction;
import org.deeplearning4j.spark.impl.computationgraph.scoring.ScoreExamplesWithKeyFunction;
import org.deeplearning4j.util.ModelSerializer;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.api.MultiDataSet;
import org.nd4j.linalg.heartbeat.Heartbeat;
import org.nd4j.linalg.heartbeat.reports.Environment;
import org.nd4j.linalg.heartbeat.reports.Event;
import org.nd4j.linalg.heartbeat.reports.Task;
import org.nd4j.linalg.heartbeat.utils.EnvironmentUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import scala.Tuple3;

public class SparkComputationGraph
implements Serializable {
    public static final int DEFAULT_EVAL_SCORE_BATCH_SIZE = 50;
    private transient JavaSparkContext sc;
    private ComputationGraphConfiguration conf;
    private ComputationGraph network;
    private Broadcast<INDArray> params;
    private Broadcast<ComputationGraphUpdater> updater;
    private boolean averageEachIteration = false;
    private boolean initDone = false;
    public static final String AVERAGE_EACH_ITERATION = "org.deeplearning4j.spark.iteration.average";
    public static final String ACCUM_GRADIENT = "org.deeplearning4j.spark.iteration.accumgrad";
    public static final String DIVIDE_ACCUM_GRADIENT = "org.deeplearning4j.spark.iteration.dividegrad";
    private double lastScore;
    private static final Logger log = LoggerFactory.getLogger(SparkComputationGraph.class);
    private transient AtomicInteger iterationsCount = new AtomicInteger(0);
    private List<IterationListener> listeners = new ArrayList<IterationListener>();

    public SparkComputationGraph(SparkContext sparkContext, ComputationGraph network) {
        this(new JavaSparkContext(sparkContext), network);
    }

    public SparkComputationGraph(JavaSparkContext javaSparkContext, ComputationGraph network) {
        this.sc = javaSparkContext;
        this.conf = network.getConfiguration().clone();
        this.network = network;
        this.network.init();
        this.updater = this.sc.broadcast((Object)network.getUpdater());
        this.averageEachIteration = this.sc.getConf().getBoolean(AVERAGE_EACH_ITERATION, false);
    }

    public SparkComputationGraph(SparkContext sparkContext, ComputationGraphConfiguration conf) {
        this(new JavaSparkContext(sparkContext), conf);
    }

    public SparkComputationGraph(JavaSparkContext sparkContext, ComputationGraphConfiguration conf) {
        this.sc = sparkContext;
        this.conf = conf.clone();
        this.network = new ComputationGraph(conf);
        this.network.init();
        this.averageEachIteration = sparkContext.sc().conf().getBoolean(AVERAGE_EACH_ITERATION, false);
        this.updater = this.sc.broadcast((Object)this.network.getUpdater());
    }

    public ComputationGraph fit(String path, int labelIndex, RecordReader recordReader, int examplesPerFit, int totalExamples, int numPartitions) {
        if (this.network.getNumInputArrays() != 1 || this.network.getNumOutputArrays() != 1) {
            throw new UnsupportedOperationException("Cannot train ComputationGraph with multiple inputs/outputs from text file + record reader");
        }
        JavaRDD points = this.loadFromTextFile(path, labelIndex, recordReader).map((Function)new DataSetToMultiDataSetFn());
        return this.fitMultiDataSet((JavaRDD<MultiDataSet>)points, examplesPerFit, totalExamples, numPartitions);
    }

    private JavaRDD<DataSet> loadFromTextFile(String path, int labelIndex, RecordReader recordReader) {
        JavaRDD lines = this.sc.textFile(path);
        int nOut = ((FeedForwardLayer)this.network.getOutputLayer(0)).getNOut();
        return lines.map((Function)new RecordReaderFunction(recordReader, labelIndex, nOut));
    }

    public ComputationGraph getNetwork() {
        return this.network;
    }

    public void setNetwork(ComputationGraph network) {
        this.network = network;
    }

    public ComputationGraph fitMultiDataSet(JavaRDD<MultiDataSet> rdd, int examplesPerFit, int totalExamples, int numPartitions) {
        int nSplits = examplesPerFit == Integer.MAX_VALUE || examplesPerFit >= totalExamples ? 1 : (totalExamples % examplesPerFit == 0 ? totalExamples / examplesPerFit : totalExamples / examplesPerFit + 1);
        if (nSplits == 1) {
            this.fitDataSet(rdd);
        } else {
            double[] splitWeights = new double[nSplits];
            for (int i = 0; i < nSplits; ++i) {
                splitWeights[i] = 1.0 / (double)nSplits;
            }
            JavaRDD[] subsets = rdd.randomSplit(splitWeights);
            for (int i = 0; i < subsets.length; ++i) {
                log.info("Initiating distributed training of subset {} of {}", (Object)(i + 1), (Object)subsets.length);
                JavaRDD next = subsets[i].repartition(numPartitions);
                this.fitDataSet((JavaRDD<MultiDataSet>)next);
            }
        }
        return this.network;
    }

    public ComputationGraph fitDataSet(JavaRDD<DataSet> rdd, int examplesPerFit, int totalExamples, int numPartitions) {
        if (this.network.getNumInputArrays() != 1 || this.network.getNumOutputArrays() != 1) {
            throw new UnsupportedOperationException("Cannot train ComputationGraph with multiple inputs/outputs from DataSet");
        }
        JavaRDD mds = rdd.map((Function)new DataSetToMultiDataSetFn());
        return this.fitMultiDataSet((JavaRDD<MultiDataSet>)mds, examplesPerFit, totalExamples, numPartitions);
    }

    public ComputationGraph fitDataSet(JavaRDD<MultiDataSet> rdd) {
        int iterations = this.network.getLayer(0).conf().getNumIterations();
        log.info("Running distributed training: (averaging each iteration = " + this.averageEachIteration + "), (iterations = " + iterations + "), (num partions = " + rdd.partitions().size() + ")");
        if (!this.averageEachIteration) {
            this.runIteration(rdd);
        } else {
            for (GraphVertex gv : this.conf.getVertices().values()) {
                if (!(gv instanceof LayerVertex)) continue;
                ((LayerVertex)gv).getLayerConf().setNumIterations(1);
            }
            for (int i = 0; i < iterations; ++i) {
                this.runIteration(rdd);
            }
            if (iterations > 1) {
                for (GraphVertex gv : this.conf.getVertices().values()) {
                    if (!(gv instanceof LayerVertex)) continue;
                    ((LayerVertex)gv).getLayerConf().setNumIterations(iterations);
                }
            }
        }
        return this.network;
    }

    protected void runIteration(JavaRDD<MultiDataSet> rdd) {
        log.info("Broadcasting initial parameters of length " + this.network.numParams(false));
        int maxRep = 0;
        long maxSm = 0L;
        INDArray valToBroadcast = this.network.params(false);
        this.params = this.sc.broadcast((Object)valToBroadcast);
        ComputationGraphUpdater updater = this.network.getUpdater();
        if (updater == null) {
            this.network.setUpdater(new ComputationGraphUpdater(this.network));
            log.warn("Unable to propagate null updater");
            updater = this.network.getUpdater();
        }
        this.updater = this.sc.broadcast((Object)updater);
        int paramsLength = this.network.numParams(true);
        boolean accumGrad = this.sc.getConf().getBoolean(ACCUM_GRADIENT, false);
        if (accumGrad) {
            JavaRDD results = rdd.mapPartitions((FlatMapFunction)new GradientAccumFlatMapCG(this.conf.toJson(), this.params, this.updater), true).cache();
            JavaRDD resultsGradient = results.map((Function)new GradientFromTupleFunctionCG());
            log.info("Ran iterative reduce... averaging gradients now.");
            GradientAdder a = new GradientAdder(paramsLength);
            resultsGradient.foreach((VoidFunction)a);
            INDArray accumulatedGradient = (INDArray)a.getAccumulator().value();
            boolean divideGrad = this.sc.getConf().getBoolean(DIVIDE_ACCUM_GRADIENT, false);
            if (divideGrad) {
                maxRep = results.partitions().size();
                accumulatedGradient.divi((Number)maxRep);
            }
            log.info("Accumulated parameters");
            log.info("Summed gradients.");
            this.network.setParams(this.network.params(false).addi(accumulatedGradient));
            log.info("Set parameters");
            log.info("Processing updaters");
            JavaRDD resultsUpdater = results.map((Function)new UpdaterFromGradientTupleFunctionCG());
            JavaDoubleRDD scores = results.mapToDouble((DoubleFunction)new ScoreMappingG());
            if (!this.initDone) {
                JavaDoubleRDD sm = results.mapToDouble((DoubleFunction)new SMappingG());
                maxSm = sm.mean().longValue();
            }
            this.lastScore = scores.mean();
            ComputationGraphUpdater.Aggregator aggregator = (ComputationGraphUpdater.Aggregator)resultsUpdater.aggregate(null, (Function2)new UpdaterElementCombinerCG(), (Function2)new UpdaterAggregatorCombinerCG());
            ComputationGraphUpdater combinedUpdater = aggregator.getUpdater();
            this.network.setUpdater(combinedUpdater);
            log.info("Set updater");
        } else {
            JavaRDD results = rdd.mapPartitions((FlatMapFunction)new IterativeReduceFlatMapCG(this.conf.toJson(), this.params, this.updater), true).cache();
            JavaRDD resultsParams = results.map((Function)new INDArrayFromTupleFunctionCG());
            log.info("Running iterative reduce and averaging parameters");
            Adder a = new Adder(paramsLength, (Accumulator<Integer>)this.sc.accumulator(0));
            resultsParams.foreach((VoidFunction)a);
            INDArray newParams = (INDArray)a.getAccumulator().value();
            maxRep = (Integer)a.getCounter().value();
            newParams.divi((Number)maxRep);
            this.network.setParams(newParams);
            log.info("Accumulated and set parameters");
            JavaRDD resultsUpdater = results.map((Function)new UpdaterFromTupleFunctionCG());
            JavaDoubleRDD scores = results.mapToDouble((DoubleFunction)new ScoreMapping());
            if (!this.initDone) {
                JavaDoubleRDD sm = results.mapToDouble((DoubleFunction)new SMapping());
                maxSm = sm.mean().longValue();
            }
            this.lastScore = scores.mean();
            ComputationGraphUpdater.Aggregator aggregator = (ComputationGraphUpdater.Aggregator)resultsUpdater.aggregate(null, (Function2)new UpdaterElementCombinerCG(), (Function2)new UpdaterAggregatorCombinerCG());
            ComputationGraphUpdater combinedUpdater = aggregator.getUpdater();
            this.network.setUpdater(combinedUpdater);
            log.info("Processed and set updater");
        }
        if (this.listeners.size() > 0) {
            this.network.setScore(this.lastScore);
            this.invokeListeners(this.network, this.iterationsCount.incrementAndGet());
        }
        if (!this.initDone) {
            this.initDone = true;
            this.update(maxRep, maxSm);
        }
    }

    public void setListeners(@NonNull Collection<IterationListener> listeners) {
        if (listeners == null) {
            throw new NullPointerException("listeners");
        }
        this.listeners.clear();
        this.listeners.addAll(listeners);
    }

    protected void invokeListeners(ComputationGraph network, int iteration) {
        for (IterationListener listener : this.listeners) {
            try {
                listener.iterationDone((Model)network, iteration);
            }
            catch (Exception e) {
                log.error("Exception caught at IterationListener invocation" + e.getMessage());
                e.printStackTrace();
            }
        }
    }

    public double getScore() {
        return this.lastScore;
    }

    public double calculateScoreDataSet(JavaRDD<DataSet> data, boolean average) {
        long n = data.count();
        JavaRDD scores = data.mapPartitions((FlatMapFunction)new ScoreFlatMapFunctionCGDataSet(this.conf.toJson(), (Broadcast<INDArray>)this.sc.broadcast((Object)this.network.params(false))));
        List scoresList = scores.collect();
        double sum = 0.0;
        for (Double d : scoresList) {
            sum += d.doubleValue();
        }
        if (average) {
            return sum / (double)n;
        }
        return sum;
    }

    public double calculateScore(JavaRDD<MultiDataSet> data, boolean average) {
        long n = data.count();
        JavaRDD scores = data.mapPartitions((FlatMapFunction)new ScoreFlatMapFunctionCGMultiDataSet(this.conf.toJson(), (Broadcast<INDArray>)this.sc.broadcast((Object)this.network.params(false))));
        List scoresList = scores.collect();
        double sum = 0.0;
        for (Double d : scoresList) {
            sum += d.doubleValue();
        }
        if (average) {
            return sum / (double)n;
        }
        return sum;
    }

    public JavaDoubleRDD scoreExamplesDataSet(JavaRDD<DataSet> data, boolean includeRegularizationTerms) {
        return this.scoreExamples((JavaRDD<MultiDataSet>)data.map((Function)new DataSetToMultiDataSetFn()), includeRegularizationTerms);
    }

    public JavaDoubleRDD scoreExamplesDataSet(JavaRDD<DataSet> data, boolean includeRegularizationTerms, int batchSize) {
        return this.scoreExamples((JavaRDD<MultiDataSet>)data.map((Function)new DataSetToMultiDataSetFn()), includeRegularizationTerms, batchSize);
    }

    public <K> JavaPairRDD<K, Double> scoreExamplesDataSet(JavaPairRDD<K, DataSet> data, boolean includeRegularizationTerms) {
        return this.scoreExamples(data.mapToPair(new PairDataSetToMultiDataSetFn()), includeRegularizationTerms, 50);
    }

    public <K> JavaPairRDD<K, Double> scoreExamplesDataSet(JavaPairRDD<K, DataSet> data, boolean includeRegularizationTerms, int batchSize) {
        return this.scoreExamples(data.mapToPair(new PairDataSetToMultiDataSetFn()), includeRegularizationTerms, batchSize);
    }

    public JavaDoubleRDD scoreExamples(JavaRDD<MultiDataSet> data, boolean includeRegularizationTerms) {
        return this.scoreExamples(data, includeRegularizationTerms, 50);
    }

    public JavaDoubleRDD scoreExamples(JavaRDD<MultiDataSet> data, boolean includeRegularizationTerms, int batchSize) {
        return data.mapPartitionsToDouble((DoubleFlatMapFunction)new ScoreExamplesFunction((Broadcast<INDArray>)this.sc.broadcast((Object)this.network.params()), (Broadcast<String>)this.sc.broadcast((Object)this.conf.toJson()), includeRegularizationTerms, batchSize));
    }

    public <K> JavaPairRDD<K, Double> scoreExamples(JavaPairRDD<K, MultiDataSet> data, boolean includeRegularizationTerms) {
        return this.scoreExamples(data, includeRegularizationTerms, 50);
    }

    private void update(int mr, long mg) {
        Environment env = EnvironmentUtils.buildEnvironment();
        env.setNumCores(mr);
        env.setAvailableMemory(mg);
        Task task = ModelSerializer.taskByModel((Model)this.network);
        Heartbeat.getInstance().reportEvent(Event.SPARK, env, task);
    }

    public <K> JavaPairRDD<K, Double> scoreExamples(JavaPairRDD<K, MultiDataSet> data, boolean includeRegularizationTerms, int batchSize) {
        return data.mapPartitionsToPair(new ScoreExamplesWithKeyFunction((Broadcast<INDArray>)this.sc.broadcast((Object)this.network.params()), (Broadcast<String>)this.sc.broadcast((Object)this.conf.toJson()), includeRegularizationTerms, batchSize));
    }

    private static class SMappingG
    implements DoubleFunction<Tuple3<Gradient, ComputationGraphUpdater, ScoreReport>> {
        private SMappingG() {
        }

        public double call(Tuple3<Gradient, ComputationGraphUpdater, ScoreReport> t3) throws Exception {
            return ((ScoreReport)t3._3()).getM();
        }
    }

    private static class SMapping
    implements DoubleFunction<Tuple3<INDArray, ComputationGraphUpdater, ScoreReport>> {
        private SMapping() {
        }

        public double call(Tuple3<INDArray, ComputationGraphUpdater, ScoreReport> t3) throws Exception {
            return ((ScoreReport)t3._3()).getM();
        }
    }

    private static class ScoreMappingG
    implements DoubleFunction<Tuple3<Gradient, ComputationGraphUpdater, ScoreReport>> {
        private ScoreMappingG() {
        }

        public double call(Tuple3<Gradient, ComputationGraphUpdater, ScoreReport> t3) throws Exception {
            return ((ScoreReport)t3._3()).getS();
        }
    }

    private static class ScoreMapping
    implements DoubleFunction<Tuple3<INDArray, ComputationGraphUpdater, ScoreReport>> {
        private ScoreMapping() {
        }

        public double call(Tuple3<INDArray, ComputationGraphUpdater, ScoreReport> t3) throws Exception {
            return ((ScoreReport)t3._3()).getS();
        }
    }
}

