/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.spark.impl.multilayer.scoring;

import java.util.ArrayList;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;
import org.apache.spark.api.java.function.PairFlatMapFunction;
import org.apache.spark.broadcast.Broadcast;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import scala.Tuple2;

public class ScoreExamplesWithKeyFunction<K>
implements PairFlatMapFunction<Iterator<Tuple2<K, DataSet>>, K, Double> {
    protected static Logger log = LoggerFactory.getLogger(ScoreExamplesWithKeyFunction.class);
    private final Broadcast<INDArray> params;
    private final Broadcast<String> jsonConfig;
    private final boolean addRegularization;
    private final int batchSize;

    public ScoreExamplesWithKeyFunction(Broadcast<INDArray> params, Broadcast<String> jsonConfig, boolean addRegularizationTerms, int batchSize) {
        this.params = params;
        this.jsonConfig = jsonConfig;
        this.addRegularization = addRegularizationTerms;
        this.batchSize = batchSize;
    }

    public Iterable<Tuple2<K, Double>> call(Iterator<Tuple2<K, DataSet>> iterator) throws Exception {
        if (!iterator.hasNext()) {
            return Collections.emptyList();
        }
        MultiLayerNetwork network = new MultiLayerNetwork(MultiLayerConfiguration.fromJson((String)((String)this.jsonConfig.getValue())));
        network.init();
        INDArray val = (INDArray)this.params.value();
        if (val.length() != network.numParams(false)) {
            throw new IllegalStateException("Network did not have same number of parameters as the broadcasted set parameters");
        }
        network.setParameters(val);
        ArrayList<Tuple2<K, Double>> ret = new ArrayList<Tuple2<K, Double>>();
        ArrayList<DataSet> collect = new ArrayList<DataSet>(this.batchSize);
        ArrayList<Object> collectKey = new ArrayList<Object>(this.batchSize);
        int totalCount = 0;
        while (iterator.hasNext()) {
            int nExamples;
            int n;
            collect.clear();
            collectKey.clear();
            for (nExamples = 0; iterator.hasNext() && nExamples < this.batchSize; nExamples += n) {
                Tuple2<K, DataSet> t2 = iterator.next();
                DataSet ds = (DataSet)t2._2();
                n = ds.numExamples();
                if (n != 1) {
                    throw new IllegalStateException("Cannot score examples with one key per data set if data set contains more than 1 example (numExamples: " + n + ")");
                }
                collect.add(ds);
                collectKey.add(t2._1());
            }
            totalCount += nExamples;
            DataSet data = DataSet.merge(collect, (boolean)false);
            INDArray scores = network.scoreExamples(data, this.addRegularization);
            double[] doubleScores = scores.data().asDouble();
            for (int i = 0; i < doubleScores.length; ++i) {
                ret.add(new Tuple2(collectKey.get(i), (Object)doubleScores[i]));
            }
        }
        if (log.isDebugEnabled()) {
            log.debug("Scored {} examples ", (Object)totalCount);
        }
        return ret;
    }
}

