
Java example source code file (ParameterAveragingTrainingWorker.java)

This example Java source code file (ParameterAveragingTrainingWorker.java) is included in the alvinalexander.com "Java Source Code Warehouse" project. The intent of this project is to help you "Learn Java by Example"™.


Java - Java tags/keywords

computationgraph, multilayernetwork, netbroadcasttuple, override, pair, parameteraveragingtrainingresult, parameteraveragingtrainingworker, sparktrainingstats, workerconfiguration

The ParameterAveragingTrainingWorker.java Java example source code

package org.deeplearning4j.spark.impl.paramavg;

import org.apache.spark.broadcast.Broadcast;
import org.deeplearning4j.berkeley.Pair;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.graph.util.ComputationGraphUtil;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.spark.api.TrainingWorker;
import org.deeplearning4j.spark.api.WorkerConfiguration;
import org.deeplearning4j.spark.api.stats.SparkTrainingStats;
import org.deeplearning4j.spark.api.worker.NetBroadcastTuple;
import org.deeplearning4j.spark.impl.paramavg.stats.ParameterAveragingTrainingWorkerStats;
import org.nd4j.linalg.dataset.api.DataSet;
import org.nd4j.linalg.dataset.api.MultiDataSet;

/**
 * ParameterAveragingTrainingWorker implements standard parameter averaging every m iterations.
 *
 * @author Alex Black
 */
public class ParameterAveragingTrainingWorker implements TrainingWorker<ParameterAveragingTrainingResult> {

    private final Broadcast<NetBroadcastTuple> broadcast;
    private final boolean saveUpdater;
    private final WorkerConfiguration configuration;
    private ParameterAveragingTrainingWorkerStats.ParameterAveragingTrainingWorkerStatsHelper stats = null;

    public ParameterAveragingTrainingWorker(Broadcast<NetBroadcastTuple> broadcast, boolean saveUpdater, WorkerConfiguration configuration) {
        this.broadcast = broadcast;
        this.saveUpdater = saveUpdater;
        this.configuration = configuration;
    }

    @Override
    public MultiLayerNetwork getInitialModel() {
        if(configuration.isCollectTrainingStats()) stats = new ParameterAveragingTrainingWorkerStats.ParameterAveragingTrainingWorkerStatsHelper();

        if(configuration.isCollectTrainingStats()) stats.logBroadcastGetValueStart();
        NetBroadcastTuple tuple = broadcast.getValue();
        if(configuration.isCollectTrainingStats()) stats.logBroadcastGetValueEnd();

        MultiLayerNetwork net = new MultiLayerNetwork(tuple.getConfiguration());
        //Can't have shared parameter array across executors for parameter averaging, hence the 'true' for clone parameters array arg
        net.init(tuple.getParameters(), true);

        if(tuple.getUpdater() != null){
            net.setUpdater(tuple.getUpdater().clone()); //Again: can't have shared updaters
        }

        if(configuration.isCollectTrainingStats()) stats.logInitEnd();

        return net;
    }

    @Override
    public ComputationGraph getInitialModelGraph() {
        if(configuration.isCollectTrainingStats()) stats = new ParameterAveragingTrainingWorkerStats.ParameterAveragingTrainingWorkerStatsHelper();

        if(configuration.isCollectTrainingStats()) stats.logBroadcastGetValueStart();
        NetBroadcastTuple tuple = broadcast.getValue();
        if(configuration.isCollectTrainingStats()) stats.logBroadcastGetValueEnd();

        ComputationGraph net = new ComputationGraph(tuple.getGraphConfiguration());
        //Can't have shared parameter array across executors for parameter averaging, hence the 'true' for clone parameters array arg
        net.init(tuple.getParameters(), true);

        if(tuple.getGraphUpdater() != null){
            net.setUpdater(tuple.getGraphUpdater().clone()); //Again: can't have shared updaters
        }

        if(configuration.isCollectTrainingStats()) stats.logInitEnd();

        return net;
    }

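    /**
     * Fit the network on a single minibatch. A result is only built for the last
     * minibatch this worker processes (isLast == true); for earlier minibatches
     * the method returns null and training simply continues.
     */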
    @Override
    public ParameterAveragingTrainingResult processMinibatch(DataSet dataSet, MultiLayerNetwork network, boolean isLast) {
        if(configuration.isCollectTrainingStats()) stats.logFitStart();
        network.fit(dataSet);
        if(configuration.isCollectTrainingStats()) stats.logFitEnd();

        if(isLast) return getFinalResult(network);

        return null;
    }

    @Override
    public ParameterAveragingTrainingResult processMinibatch(DataSet dataSet, ComputationGraph graph, boolean isLast){
        return processMinibatch(ComputationGraphUtil.toMultiDataSet(dataSet), graph, isLast);
    }

    @Override
    public ParameterAveragingTrainingResult processMinibatch(MultiDataSet dataSet, ComputationGraph graph, boolean isLast){
        if(configuration.isCollectTrainingStats()) stats.logFitStart();
        graph.fit(dataSet);
        if(configuration.isCollectTrainingStats()) stats.logFitEnd();

        if(isLast) return getFinalResult(graph);

        return null;
    }

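    //The *WithStats overloads behave like the methods above, but additionally
    //return the SparkTrainingStats collected so far (null when stats collection is disabled)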
    @Override
    public Pair<ParameterAveragingTrainingResult, SparkTrainingStats> processMinibatchWithStats(DataSet dataSet, MultiLayerNetwork network, boolean isLast) {
        ParameterAveragingTrainingResult result = processMinibatch(dataSet,network,isLast);
        if(result == null) return null;

        SparkTrainingStats statsToReturn = (stats != null ? stats.build() : null);
        return new Pair<>(result, statsToReturn);
    }

    @Override
    public Pair<ParameterAveragingTrainingResult, SparkTrainingStats> processMinibatchWithStats(DataSet dataSet, ComputationGraph graph, boolean isLast) {
        return processMinibatchWithStats(ComputationGraphUtil.toMultiDataSet(dataSet), graph, isLast);
    }

    @Override
    public Pair<ParameterAveragingTrainingResult, SparkTrainingStats> processMinibatchWithStats(MultiDataSet dataSet, ComputationGraph graph, boolean isLast) {
        ParameterAveragingTrainingResult result = processMinibatch(dataSet,graph,isLast);
        if(result == null) return null;

        SparkTrainingStats statsToReturn = (stats != null ? stats.build() : null);
        return new Pair<>(result, statsToReturn);
    }

    @Override
    public ParameterAveragingTrainingResult getFinalResult(MultiLayerNetwork network) {
        //TODO: don't want to use java serialization for updater, in case worker is using cuda and master is using native, etc
        return new ParameterAveragingTrainingResult(network.params(), (saveUpdater ? network.getUpdater() : null), network.score());
    }

    @Override
    public ParameterAveragingTrainingResult getFinalResult(ComputationGraph network) {
        //TODO: don't want to use java serialization for updater, in case worker is using cuda and master is using native, etc
        return new ParameterAveragingTrainingResult(network.params(), (saveUpdater ? network.getUpdater() : null), network.score());
    }

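    //Used when this worker ends up with no data to train on (for example, an
    //empty partition): an empty result, rather than parameters/updater/score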
    @Override
    public ParameterAveragingTrainingResult getFinalResultNoData(){
        return new ParameterAveragingTrainingResult(null, null, null, 0.0, null);
    }

    @Override
    public Pair<ParameterAveragingTrainingResult, SparkTrainingStats> getFinalResultNoDataWithStats(){
        return new Pair<>(new ParameterAveragingTrainingResult(null, null, null, 0.0, null),null);
    }

    @Override
    public Pair<ParameterAveragingTrainingResult,SparkTrainingStats> getFinalResultWithStats(MultiLayerNetwork network) {
        ParameterAveragingTrainingResult result = getFinalResult(network);
        if(result == null) return null;

        SparkTrainingStats statsToReturn = (stats != null ? stats.build() : null);
        return new Pair<>(result,statsToReturn);
    }

    @Override
    public Pair<ParameterAveragingTrainingResult, SparkTrainingStats> getFinalResultWithStats(ComputationGraph graph){
        ParameterAveragingTrainingResult result = getFinalResult(graph);
        if(result == null) return null;

        SparkTrainingStats statsToReturn = (stats != null ? stats.build() : null);
        return new Pair<>(result,statsToReturn);
    }

    @Override
    public WorkerConfiguration getDataConfiguration() {
        return configuration;
    }

}
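
For context, here is a minimal sketch of the other half of the algorithm: how a training master might combine the ParameterAveragingTrainingResult objects returned by several workers. This is not part of the file above. The getParameters() accessor is an assumption based on how the result is constructed in getFinalResult(), and the loop is written locally for clarity; the real Spark training master aggregates results across an RDD instead.

import java.util.List;

import org.deeplearning4j.spark.impl.paramavg.ParameterAveragingTrainingResult;
import org.nd4j.linalg.api.ndarray.INDArray;

public class ParameterAveragingSketch {

    /**
     * Element-wise average of the parameter vectors returned by the workers.
     * Assumes getParameters() returns the INDArray that was passed into the
     * result's constructor in getFinalResult() above.
     */
    public static INDArray averageParameters(List<ParameterAveragingTrainingResult> results) {
        //dup() copies the first worker's parameters so accumulation doesn't mutate its result
        INDArray average = results.get(0).getParameters().dup();
        for (int i = 1; i < results.size(); i++) {
            average.addi(results.get(i).getParameters());   //in-place element-wise add
        }
        return average.divi(results.size());                //in-place element-wise divide by worker count
    }
}

When saveUpdater is true, the updater state returned alongside the parameters can be averaged in the same way before being applied to the global model.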

