/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.spark.impl.multilayer.gradientaccum;

import org.apache.spark.api.java.function.Function;
import org.apache.spark.broadcast.Broadcast;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.api.DataSet;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public class GradientAccum
implements Function<org.nd4j.linalg.dataset.DataSet, Gradient> {
    private static final Logger log = LoggerFactory.getLogger(GradientAccum.class);
    private String json;
    private Broadcast<INDArray> params;

    public GradientAccum(Broadcast<INDArray> params, String json) {
        this.params = params;
        this.json = json;
    }

    public Gradient call(org.nd4j.linalg.dataset.DataSet dataSet) throws Exception {
        log.info("Training on " + dataSet.numExamples());
        MultiLayerConfiguration conf = MultiLayerConfiguration.fromJson((String)this.json);
        MultiLayerNetwork network = new MultiLayerNetwork(conf);
        network.init();
        network.setParameters(((INDArray)this.params.value()).dup());
        network.fit((DataSet)dataSet);
        return network.gradient();
    }
}

