package org.neuroph.contrib.model.modelselection;

import java.util.ArrayList;
import java.util.HashSet;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Objects;
import java.util.Set;
import java.util.concurrent.ExecutionException;
import java.util.logging.Level;
import org.neuroph.core.NeuralNetwork;
import org.neuroph.core.data.DataSet;
import org.neuroph.core.events.LearningEvent;
import org.neuroph.core.events.LearningEventListener;
import org.neuroph.eval.KFoldCrossValidation;
import org.neuroph.eval.classification.ClassificationMetrics;
import org.neuroph.nnet.MultiLayerPerceptron;
import org.neuroph.nnet.learning.BackPropagation;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/neuroph/contrib/model/modelselection/MultilayerPerceptronOptimazer.class */
public class MultilayerPerceptronOptimazer<T extends BackPropagation> {
    private static Logger LOG = LoggerFactory.getLogger(MultilayerPerceptronOptimazer.class);
    private List<Integer> optimalArchitecure;
    private NeuralNetwork<BackPropagation> optimalClassifier;
    private ClassificationMetrics optimalResult;
    private KFoldCrossValidation errorEstimationMethod;
    private BackPropagation learningRule;
    private Set<List<Integer>> allArchitectures = new HashSet();
    private int maxLayers = 1;
    private int minNeuronsPerLayer = 1;
    private int maxNeuronsPerLayer = 30;
    private int neuronIncrement = 1;

    /* loaded from: input_file:org/neuroph/contrib/model/modelselection/MultilayerPerceptronOptimazer$LearningListener.class */
    static class LearningListener implements LearningEventListener {
        private double[] foldErrors;
        private int foldSize;

        public LearningListener(int i, int i2) {
            this.foldSize = i;
            this.foldErrors = new double[i2];
        }

        public void handleLearningEvent(LearningEvent learningEvent) {
            BackPropagation backPropagation = (BackPropagation) learningEvent.getSource();
            double[] dArr = this.foldErrors;
            int currentIteration = backPropagation.getCurrentIteration() - 1;
            dArr[currentIteration] = dArr[currentIteration] + (backPropagation.getTotalNetworkError() / this.foldSize);
        }
    }

    public MultilayerPerceptronOptimazer withMaxLayers(int i) {
        this.maxLayers = i;
        return this;
    }

    public MultilayerPerceptronOptimazer withNeuronIncrement(int i) {
        this.neuronIncrement = i;
        return this;
    }

    public MultilayerPerceptronOptimazer withMaxNeurons(int i) {
        this.maxNeuronsPerLayer = i;
        return this;
    }

    public MultilayerPerceptronOptimazer withMinNeurons(int i) {
        this.minNeuronsPerLayer = i;
        return this;
    }

    public MultilayerPerceptronOptimazer withErrorEstimationMethod(KFoldCrossValidation kFoldCrossValidation) {
        this.errorEstimationMethod = kFoldCrossValidation;
        return this;
    }

    public MultilayerPerceptronOptimazer withLearningRule(BackPropagation backPropagation) {
        this.learningRule = backPropagation;
        return this;
    }

    public NeuralNetwork createOptimalModel(DataSet dataSet) {
        ArrayList arrayList = new ArrayList();
        arrayList.add(Integer.valueOf(this.minNeuronsPerLayer));
        findArchitectures(1, this.minNeuronsPerLayer, arrayList);
        LOG.info("Total [{}] different network topologies found", Integer.valueOf(this.allArchitectures.size()));
        for (List<Integer> list : this.allArchitectures) {
            try {
                list.add(0, Integer.valueOf(dataSet.getInputSize()));
                list.add(Integer.valueOf(dataSet.getOutputSize()));
                LOG.info("Architecture: [{}]", list);
                MultiLayerPerceptron multiLayerPerceptron = new MultiLayerPerceptron(list);
                this.learningRule.addListener(new LearningListener(10, this.learningRule.getMaxIterations()));
                multiLayerPerceptron.setLearningRule(this.learningRule);
                this.errorEstimationMethod = new KFoldCrossValidation(multiLayerPerceptron, dataSet, 10);
                this.errorEstimationMethod.run();
                ClassificationMetrics[] classificationMetricsArr = null;
                if (this.optimalResult == null || this.optimalResult.getFMeasure() < classificationMetricsArr[0].getFMeasure()) {
                    LOG.info("Architecture [{}] became optimal architecture  with metrics {}", list, (Object) null);
                    this.optimalResult = classificationMetricsArr[0];
                    this.optimalClassifier = multiLayerPerceptron;
                    this.optimalArchitecure = list;
                }
                LOG.info("#################################################################");
            } catch (InterruptedException e) {
                java.util.logging.Logger.getLogger(MultilayerPerceptronOptimazer.class.getName()).log(Level.SEVERE, (String) null, (Throwable) e);
            } catch (ExecutionException e2) {
                java.util.logging.Logger.getLogger(MultilayerPerceptronOptimazer.class.getName()).log(Level.SEVERE, (String) null, (Throwable) e2);
            }
        }
        LOG.info("Optimal Architecture: {}", this.optimalArchitecure);
        return this.optimalClassifier;
    }

    private void findArchitectures(int i, int i2, List<Integer> list) {
        this.allArchitectures.add(new ArrayList(list));
        if (i2 + this.neuronIncrement <= this.maxNeuronsPerLayer) {
            int size = list.size() - 1;
            ArrayList arrayList = new ArrayList(list);
            arrayList.set(size, Integer.valueOf(i2 + this.neuronIncrement));
            findArchitectures(i, i2 + this.neuronIncrement, arrayList);
        }
        if (i + 1 <= this.maxLayers) {
            ArrayList arrayList2 = new ArrayList(list);
            arrayList2.add(1);
            findArchitectures(i + 1, this.minNeuronsPerLayer, arrayList2);
        }
    }
}
