/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.spark.earlystopping;

import java.io.IOException;
import java.lang.reflect.Array;
import java.util.LinkedHashMap;
import org.apache.spark.SparkContext;
import org.apache.spark.api.java.JavaRDD;
import org.deeplearning4j.earlystopping.EarlyStoppingConfiguration;
import org.deeplearning4j.earlystopping.EarlyStoppingResult;
import org.deeplearning4j.earlystopping.listener.EarlyStoppingListener;
import org.deeplearning4j.earlystopping.scorecalc.ScoreCalculator;
import org.deeplearning4j.earlystopping.termination.EpochTerminationCondition;
import org.deeplearning4j.earlystopping.termination.IterationTerminationCondition;
import org.deeplearning4j.earlystopping.trainer.IEarlyStoppingTrainer;
import org.deeplearning4j.nn.api.Model;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.api.MultiDataSet;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Base class for early-stopping training of a model on Spark.
 *
 * <p>Each epoch the training RDD (either {@link DataSet} or {@link MultiDataSet}; {@code train} is used when
 * both are supplied) is randomly split into {@code ceil(totalExamples / examplesPerFit)} equal-weight subsets,
 * and each subset is passed to {@link #fit(JavaRDD)} or {@link #fitMulti(JavaRDD)} in turn. After every subset
 * the iteration termination conditions are checked against {@link #getScore()}. On every epoch whose index is
 * a multiple of {@code evaluateEveryNEpochs}, the configured score calculator is run, the best/latest models
 * are saved as configured, and the epoch termination conditions are checked.
 *
 * @param <T> the model type being trained
 */
public abstract class BaseSparkEarlyStoppingTrainer<T extends Model>
implements IEarlyStoppingTrainer<T> {
    private static final Logger log = LoggerFactory.getLogger(BaseSparkEarlyStoppingTrainer.class);
    private final SparkContext sc;
    private final EarlyStoppingConfiguration<T> esConfig;
    private final T net;
    private final JavaRDD<DataSet> train;
    private final JavaRDD<MultiDataSet> trainMulti;
    /** Target number of examples per fit/fitMulti call; determines how many subsets each epoch is split into. */
    protected final int examplesPerFit;
    /** Total number of training examples; combined with examplesPerFit to compute the number of splits. */
    protected final int totalExamples;
    /** Stored for subclasses; not used directly by this base class. */
    protected final int numPartitions;
    private EarlyStoppingListener<T> listener;
    /** Best (lowest) score seen so far; Double.MAX_VALUE until a score calculator produces a score. */
    private double bestModelScore = Double.MAX_VALUE;
    /** Epoch at which bestModelScore was achieved, or -1 if no best model has been recorded yet. */
    private int bestModelEpoch = -1;

    /**
     * @param sc             Spark context (retained; not otherwise used by this base class)
     * @param esConfig       early stopping configuration; must contain at least one iteration or epoch
     *                       termination condition
     * @param net            the model to train
     * @param train          training data, or null when {@code trainMulti} is supplied
     * @param trainMulti     multi-dataset training data, or null when {@code train} is supplied
     * @param examplesPerFit target number of examples per fit call
     * @param totalExamples  total number of examples in the training data
     * @param numPartitions  number of partitions, exposed to subclasses via a protected field
     * @param listener       optional listener; may be null
     * @throws IllegalArgumentException if no termination conditions are configured, or if both {@code train}
     *                                  and {@code trainMulti} are null
     */
    protected BaseSparkEarlyStoppingTrainer(SparkContext sc, EarlyStoppingConfiguration<T> esConfig, T net, JavaRDD<DataSet> train, JavaRDD<MultiDataSet> trainMulti, int examplesPerFit, int totalExamples, int numPartitions, EarlyStoppingListener<T> listener) {
        boolean hasEpochConditions = esConfig.getEpochTerminationConditions() != null
                && !esConfig.getEpochTerminationConditions().isEmpty();
        boolean hasIterationConditions = esConfig.getIterationTerminationConditions() != null
                && !esConfig.getIterationTerminationConditions().isEmpty();
        if (!hasEpochConditions && !hasIterationConditions) {
            throw new IllegalArgumentException("Cannot conduct early stopping without a termination condition (both Iteration and Epoch termination conditions are null/empty)");
        }
        // Fail fast instead of hitting a NullPointerException partway through fit()
        if (train == null && trainMulti == null) {
            throw new IllegalArgumentException("No training data provided: both train and trainMulti are null");
        }
        this.sc = sc;
        this.esConfig = esConfig;
        this.net = net;
        this.train = train;
        this.trainMulti = trainMulti;
        this.examplesPerFit = examplesPerFit;
        this.totalExamples = totalExamples;
        this.numPartitions = numPartitions;
        this.listener = listener;
    }

    /** Fits the model on one subset of {@link DataSet} training data. */
    protected abstract void fit(JavaRDD<DataSet> var1);

    /** Fits the model on one subset of {@link MultiDataSet} training data. */
    protected abstract void fitMulti(JavaRDD<MultiDataSet> var1);

    /** Returns the model score after the most recent fit call; checked by the iteration termination conditions. */
    protected abstract double getScore();

    /**
     * Runs early-stopping training until an iteration or epoch termination condition fires, or a fit call
     * throws.
     *
     * @return the result, including the best saved model and the score recorded for each evaluated epoch
     * @throws RuntimeException if saving or loading a model through the configured model saver fails
     */
    @Override
    public EarlyStoppingResult<T> fit() {
        log.info("Starting early stopping training");
        if (this.esConfig.getScoreCalculator() == null) {
            log.warn("No score calculator provided for early stopping. Score will be reported as 0.0 to epoch termination conditions");
        }
        this.initializeTerminationConditions();
        if (this.listener != null) {
            this.listener.onStart(this.esConfig, this.net);
        }
        LinkedHashMap<Integer, Double> scoreVsEpoch = new LinkedHashMap<Integer, Double>();
        // The source RDD is randomly re-split and read on every epoch, so cache it once up front
        if (this.train != null) {
            this.train.cache();
        } else {
            this.trainMulti.cache();
        }
        int nSplits = this.splitsPerEpoch();
        int epochCount = 0;
        while (true) {
            EarlyStoppingResult<T> stopped = this.trainEpoch(epochCount, nSplits, scoreVsEpoch);
            if (stopped != null) {
                return stopped;
            }
            log.info("Completed training epoch {}", epochCount);
            // When evaluateEveryNEpochs == 1 this holds for every epoch, including epoch 0
            if (epochCount % this.esConfig.getEvaluateEveryNEpochs() == 0) {
                stopped = this.evaluateEpoch(epochCount, scoreVsEpoch);
                if (stopped != null) {
                    return stopped;
                }
            }
            // Advance on every epoch, including non-evaluated ones; skipping this would retrain the same epoch forever
            ++epochCount;
        }
    }

    /** Calls initialize() on every configured iteration and epoch termination condition (null lists are skipped). */
    private void initializeTerminationConditions() {
        if (this.esConfig.getIterationTerminationConditions() != null) {
            for (IterationTerminationCondition c : this.esConfig.getIterationTerminationConditions()) {
                c.initialize();
            }
        }
        if (this.esConfig.getEpochTerminationConditions() != null) {
            for (EpochTerminationCondition c : this.esConfig.getEpochTerminationConditions()) {
                c.initialize();
            }
        }
    }

    /**
     * Returns ceil(totalExamples / examplesPerFit), clamped to at least 1 so every epoch trains on the data
     * even when totalExamples is 0.
     */
    private int splitsPerEpoch() {
        int n = this.totalExamples / this.examplesPerFit;
        if (this.totalExamples % this.examplesPerFit != 0) {
            ++n;
        }
        return Math.max(1, n);
    }

    /**
     * Trains one epoch: splits the training RDD into {@code nSplits} random subsets and fits each in turn,
     * checking the iteration termination conditions after every subset.
     *
     * @return a terminal result if a fit call threw or an iteration termination condition fired; null if the
     *         whole epoch completed
     */
    private EarlyStoppingResult<T> trainEpoch(int epochCount, int nSplits, LinkedHashMap<Integer, Double> scoreVsEpoch) {
        JavaRDD<DataSet>[] subsets = null;
        JavaRDD<MultiDataSet>[] subsetsMulti = null;
        int nSubsets;
        if (this.train != null) {
            subsets = splitEvenly(this.train, nSplits);
            nSubsets = subsets.length;
        } else {
            subsetsMulti = splitEvenly(this.trainMulti, nSplits);
            nSubsets = subsetsMulti.length;
        }
        for (int i = 0; i < nSubsets; ++i) {
            log.info("Initiating distributed training of subset {} of {}", i + 1, nSubsets);
            try {
                if (subsets != null) {
                    this.fit(subsets[i]);
                } else {
                    this.fitMulti(subsetsMulti[i]);
                }
            }
            catch (Exception e) {
                log.warn("Early stopping training terminated due to exception at epoch {}, iteration {}", new Object[]{epochCount, i, e});
                return this.complete(new EarlyStoppingResult<T>(EarlyStoppingResult.TerminationReason.Error, e.toString(), scoreVsEpoch, this.bestModelEpoch, this.bestModelScore, epochCount, this.loadBestModel()));
            }
            IterationTerminationCondition reason = this.firstTriggeredIterationCondition(this.getScore());
            if (reason != null) {
                log.info("Hit per iteration epoch termination condition at epoch {}, iteration {}. Reason: {}", new Object[]{epochCount, i, reason});
                if (this.esConfig.isSaveLastModel()) {
                    // No epoch score has been computed at this point, so 0.0 is passed as the score
                    this.saveLatestModel(0.0);
                }
                return this.complete(new EarlyStoppingResult<T>(EarlyStoppingResult.TerminationReason.IterationTerminationCondition, reason.toString(), scoreVsEpoch, this.bestModelEpoch, this.bestModelScore, epochCount, this.loadBestModel()));
            }
        }
        return null;
    }

    /**
     * Scores the model at the end of an evaluated epoch. It records the score, saves the best and latest
     * models as configured, notifies the listener, and checks the epoch termination conditions.
     *
     * @return a terminal result if an epoch termination condition fired; null to continue training
     */
    private EarlyStoppingResult<T> evaluateEpoch(int epochCount, LinkedHashMap<Integer, Double> scoreVsEpoch) {
        boolean hasScoreCalculator = this.esConfig.getScoreCalculator() != null;
        double score = hasScoreCalculator ? this.esConfig.getScoreCalculator().calculateScore(this.net) : 0.0;
        // Keyed by the same epoch index reported to the listener and used for bestModelEpoch
        scoreVsEpoch.put(epochCount, score);
        if (hasScoreCalculator && score < this.bestModelScore) {
            if (this.bestModelEpoch == -1) {
                log.info("Score at epoch {}: {}", epochCount, score);
            } else {
                log.info("New best model: score = {}, epoch = {} (previous: score = {}, epoch = {})", new Object[]{score, epochCount, this.bestModelScore, this.bestModelEpoch});
            }
            this.bestModelScore = score;
            this.bestModelEpoch = epochCount;
            try {
                this.esConfig.getModelSaver().saveBestModel(this.net, score);
            }
            catch (IOException e) {
                throw new RuntimeException("Error saving best model", e);
            }
        }
        if (this.esConfig.isSaveLastModel()) {
            this.saveLatestModel(score);
        }
        if (this.listener != null) {
            this.listener.onEpoch(epochCount, score, this.esConfig, this.net);
        }
        EpochTerminationCondition reason = this.firstTriggeredEpochCondition(epochCount, score);
        if (reason == null) {
            return null;
        }
        log.info("Hit epoch termination condition at epoch {}. Details: {}", epochCount, reason.toString());
        return this.complete(new EarlyStoppingResult<T>(EarlyStoppingResult.TerminationReason.EpochTerminationCondition, reason.toString(), scoreVsEpoch, this.bestModelEpoch, this.bestModelScore, epochCount + 1, this.loadBestModel()));
    }

    /** Returns the first iteration termination condition that fires for {@code lastScore}, or null if none fires. */
    private IterationTerminationCondition firstTriggeredIterationCondition(double lastScore) {
        if (this.esConfig.getIterationTerminationConditions() != null) {
            for (IterationTerminationCondition c : this.esConfig.getIterationTerminationConditions()) {
                if (c.terminate(lastScore)) {
                    return c;
                }
            }
        }
        return null;
    }

    /** Returns the first epoch termination condition that fires for this epoch and score, or null if none fires. */
    private EpochTerminationCondition firstTriggeredEpochCondition(int epochCount, double score) {
        if (this.esConfig.getEpochTerminationConditions() != null) {
            for (EpochTerminationCondition c : this.esConfig.getEpochTerminationConditions()) {
                if (c.terminate(epochCount, score)) {
                    return c;
                }
            }
        }
        return null;
    }

    /** Saves the current model as the latest model, wrapping any IOException in a RuntimeException. */
    private void saveLatestModel(double score) {
        try {
            this.esConfig.getModelSaver().saveLatestModel(this.net, score);
        }
        catch (IOException e) {
            throw new RuntimeException("Error saving most recent model", e);
        }
    }

    /** Loads the best model from the configured model saver, wrapping any IOException in a RuntimeException. */
    private T loadBestModel() {
        try {
            return this.esConfig.getModelSaver().getBestModel();
        }
        catch (IOException e) {
            throw new RuntimeException(e);
        }
    }

    /** Notifies the listener (if any) that training has finished, then returns {@code result}. */
    private EarlyStoppingResult<T> complete(EarlyStoppingResult<T> result) {
        if (this.listener != null) {
            this.listener.onCompletion(result);
        }
        return result;
    }

    /**
     * Splits {@code data} into {@code nSplits} subsets with equal sampling weights via
     * {@link JavaRDD#randomSplit(double[])}. When {@code nSplits} is 1, the input RDD itself is returned as the
     * only subset.
     */
    private static <D> JavaRDD<D>[] splitEvenly(JavaRDD<D> data, int nSplits) {
        if (nSplits == 1) {
            // Generic arrays cannot be created directly; the single element is a JavaRDD<D>, so the cast is safe
            @SuppressWarnings("unchecked")
            JavaRDD<D>[] single = (JavaRDD<D>[]) Array.newInstance(JavaRDD.class, 1);
            single[0] = data;
            return single;
        }
        double[] weights = new double[nSplits];
        for (int i = 0; i < nSplits; ++i) {
            weights[i] = 1.0 / nSplits;
        }
        return data.randomSplit(weights);
    }

    /** Replaces the listener notified on start, after each evaluated epoch, and on completion. */
    public void setListener(EarlyStoppingListener<T> listener) {
        this.listener = listener;
    }
}

