home | career | drupal | java | mac | mysql | perl | scala | uml | unix  

Java example source code file (ParameterAveragingTrainingMaster.java)

This example Java source code file (ParameterAveragingTrainingMaster.java) is included in the alvinalexander.com "Java Source Code Warehouse" project. The intent of this project is to help you "Learn Java by Example" TM.

Learn more about this Java project at its project page.

Java - Java tags/keywords

builder, computationgraph, configured, flatmapfunction, intentional, javardd, multilayernetwork, netbroadcasttuple, override, parameteraveragingtrainingmaster, parameteraveragingtrainingworker, reflection, sparktrainingstats, starting, util, workerconfiguration

The ParameterAveragingTrainingMaster.java Java example source code

package org.deeplearning4j.spark.impl.paramavg;

import lombok.Data;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.function.FlatMapFunction;
import org.apache.spark.broadcast.Broadcast;
import org.apache.spark.storage.StorageLevel;
import org.deeplearning4j.nn.api.Updater;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.updater.aggregate.UpdaterAggregator;
import org.deeplearning4j.nn.updater.graph.ComputationGraphUpdater;
import org.deeplearning4j.optimize.api.IterationListener;
import org.deeplearning4j.spark.api.TrainingMaster;
import org.deeplearning4j.spark.api.WorkerConfiguration;
import org.deeplearning4j.spark.api.stats.SparkTrainingStats;
import org.deeplearning4j.spark.api.worker.ExecuteWorkerFlatMap;
import org.deeplearning4j.spark.api.worker.ExecuteWorkerMultiDataSetFlatMap;
import org.deeplearning4j.spark.api.worker.NetBroadcastTuple;
import org.deeplearning4j.spark.impl.graph.SparkComputationGraph;
import org.deeplearning4j.spark.impl.graph.dataset.DataSetToMultiDataSetFn;
import org.deeplearning4j.spark.impl.multilayer.SparkDl4jMultiLayer;
import org.deeplearning4j.spark.impl.paramavg.aggregator.ParameterAveragingAggregationTuple;
import org.deeplearning4j.spark.impl.paramavg.aggregator.ParameterAveragingElementAddFunction;
import org.deeplearning4j.spark.impl.paramavg.aggregator.ParameterAveragingElementCombineFunction;
import org.deeplearning4j.spark.impl.paramavg.stats.ParameterAveragingTrainingMasterStats;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.api.MultiDataSet;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.lang.reflect.Array;
import java.util.Collection;
import java.util.Iterator;

/**
 * ParameterAveragingTrainingMaster: A {@link TrainingMaster} implementation for spark-only training.
 * This is standard parameter averaging with a configurable averaging period.
 *
 * @author Alex Black
 */
@Data
public class ParameterAveragingTrainingMaster implements TrainingMaster<ParameterAveragingTrainingResult, ParameterAveragingTrainingWorker> {

    private static final Logger log = LoggerFactory.getLogger(ParameterAveragingTrainingMaster.class);

    private boolean saveUpdater;
    private int numWorkers;
    private int batchSizePerWorker;
    private int averagingFrequency;
    private int prefetchNumBatches;
    private boolean collectTrainingStats;
    private ParameterAveragingTrainingMasterStats.parameterAveragingTrainingMasterStatsHelper stats;
    private Collection<IterationListener> listeners;
    private int iterationCount = 0;


    private ParameterAveragingTrainingMaster(Builder builder) {
        this.saveUpdater = builder.saveUpdater;
        this.numWorkers = builder.numWorkers;
        this.batchSizePerWorker = builder.batchSizePerWorker;
        this.averagingFrequency = builder.averagingFrequency;
        this.prefetchNumBatches = builder.prefetchNumBatches;
    }

    public ParameterAveragingTrainingMaster(boolean saveUpdater, int numWorkers, int batchSizePerWorker, int averagingFrequency, int prefetchNumBatches) {
        this(saveUpdater, numWorkers, batchSizePerWorker, averagingFrequency, prefetchNumBatches, false);
    }

    public ParameterAveragingTrainingMaster(boolean saveUpdater, int numWorkers, int batchSizePerWorker, int averagingFrequency, int prefetchNumBatches, boolean collectTrainingStats) {
        this.saveUpdater = saveUpdater;
        this.numWorkers = numWorkers;
        this.batchSizePerWorker = batchSizePerWorker;
        this.averagingFrequency = averagingFrequency;
        this.prefetchNumBatches = prefetchNumBatches;
        this.collectTrainingStats = collectTrainingStats;
        if (collectTrainingStats)
            stats = new ParameterAveragingTrainingMasterStats.parameterAveragingTrainingMasterStatsHelper();
    }

    @Override
    public ParameterAveragingTrainingWorker getWorkerInstance(SparkDl4jMultiLayer network) {
        NetBroadcastTuple tuple = new NetBroadcastTuple(network.getNetwork().getLayerWiseConfigurations(),
                network.getNetwork().params(),
                network.getNetwork().getUpdater());

        if (collectTrainingStats) stats.logBroadcastStart();
        Broadcast<NetBroadcastTuple> broadcast = network.getSparkContext().broadcast(tuple);
        if (collectTrainingStats) stats.logBroadcastEnd();

        WorkerConfiguration configuration = new WorkerConfiguration(false, batchSizePerWorker, averagingFrequency, prefetchNumBatches, collectTrainingStats);
        return new ParameterAveragingTrainingWorker(broadcast, saveUpdater, configuration);
    }

    @Override
    public ParameterAveragingTrainingWorker getWorkerInstance(SparkComputationGraph graph) {
        NetBroadcastTuple tuple = new NetBroadcastTuple(graph.getNetwork().getConfiguration(),
                graph.getNetwork().params(),
                graph.getNetwork().getUpdater());

        if (collectTrainingStats) stats.logBroadcastStart();
        Broadcast<NetBroadcastTuple> broadcast = graph.getSparkContext().broadcast(tuple);
        if (collectTrainingStats) stats.logBroadcastEnd();

        WorkerConfiguration configuration = new WorkerConfiguration(true, batchSizePerWorker, averagingFrequency, prefetchNumBatches, collectTrainingStats);
        return new ParameterAveragingTrainingWorker(broadcast, saveUpdater, configuration);
    }

    @Override
    public void executeTraining(SparkDl4jMultiLayer network, JavaRDD<DataSet> trainingData) {
        if (collectTrainingStats) stats.logFitStart();
        //For "vanilla" parameter averaging training, we need to split the full data set into batches of size N, such that we can process the specified
        // number of minibatches between averagings
        //But to do that, wee need to know: (a) the number of examples, and (b) the number of workers
        trainingData.persist(StorageLevel.MEMORY_ONLY());

        //TODO: The assumption here is that each DataSet represents a single example. But this may not always be the case
        long totalCount = trainingData.count();
        int examplesPerSplit = numWorkers * batchSizePerWorker * averagingFrequency;

        JavaRDD<DataSet>[] splits;
        if (collectTrainingStats) stats.logSplitStart();
        if (totalCount <= examplesPerSplit) {
            splits = (JavaRDD<DataSet>[]) Array.newInstance(JavaRDD.class, 1);
            splits[0] = trainingData;
        } else {
            int numSplits = (int) (totalCount / examplesPerSplit); //Intentional round down
            double[] weights = new double[numSplits];
            for (int i = 0; i < weights.length; i++) weights[i] = 1.0 / numSplits;
            splits = trainingData.randomSplit(weights);
        }
        if (collectTrainingStats) stats.logSplitEnd();

        int splitNum = 1;
        for (JavaRDD<DataSet> split : splits) {
            log.info("Starting training of split {} of {}. workerMiniBatchSize={}, averagingFreq={}, dataSetTotalExamples={}. Configured for {} executors",
                    splitNum, splits.length, batchSizePerWorker, averagingFrequency, totalCount, numWorkers);

            FlatMapFunction<Iterator function = new ExecuteWorkerFlatMap<>(getWorkerInstance(network));
            JavaRDD<ParameterAveragingTrainingResult> result = split.mapPartitions(function);
            processResults(network, null, result, splitNum, splits.length);

            splitNum++;
        }

        if (collectTrainingStats) stats.logFitEnd();
    }

    @Override
    public void executeTraining(SparkComputationGraph graph, JavaRDD<DataSet> trainingData) {
        JavaRDD<MultiDataSet> mdsTrainingData = trainingData.map(new DataSetToMultiDataSetFn());

        executeTrainingMDS(graph, mdsTrainingData);
    }

    @Override
    public void executeTrainingMDS(SparkComputationGraph graph, JavaRDD<MultiDataSet> trainingData) {
        if (collectTrainingStats) stats.logFitStart();
        //For "vanilla" parameter averaging training, we need to split the full data set into batches of size N, such that we can process the specified
        // number of minibatches between averagings
        //But to do that, wee need to know: (a) the number of examples, and (b) the number of workers
        trainingData.persist(StorageLevel.MEMORY_ONLY());

        //TODO: The assumption here is that each DataSet represents a single example. But this may not always be the case
        long totalCount = trainingData.count();
        int examplesPerSplit = numWorkers * batchSizePerWorker * averagingFrequency;

        JavaRDD<MultiDataSet>[] splits = randomSplit((int) totalCount, examplesPerSplit, trainingData);


        int splitNum = 1;
        for (JavaRDD<MultiDataSet> split : splits) {
            log.info("Starting graph training of split {} of {}. workerMiniBatchSize={}, averagingFreq={}, dataSetTotalExamples={}. Configured for {} executors",
                    splitNum, splits.length, batchSizePerWorker, averagingFrequency, totalCount, numWorkers);

            FlatMapFunction<Iterator function = new ExecuteWorkerMultiDataSetFlatMap<>(getWorkerInstance(graph));
            JavaRDD<ParameterAveragingTrainingResult> result = split.mapPartitions(function);
            processResults(null, graph, result, splitNum, splits.length);

            splitNum++;
        }

        if (collectTrainingStats) stats.logFitEnd();
    }

    private <T> JavaRDD[] randomSplit(int totalCount, int examplesPerSplit, JavaRDD data) {
        JavaRDD<T>[] splits;
        if (collectTrainingStats) stats.logSplitStart();
        if (totalCount <= examplesPerSplit) {
            splits = (JavaRDD<T>[]) Array.newInstance(JavaRDD.class, 1);
            splits[0] = data;
        } else {
            int numSplits = totalCount / examplesPerSplit; //Intentional round down
            double[] weights = new double[numSplits];
            for (int i = 0; i < weights.length; i++) weights[i] = 1.0 / numSplits;
            splits = data.randomSplit(weights);
        }
        if (collectTrainingStats) stats.logSplitEnd();
        return splits;
    }


    @Override
    public void setCollectTrainingStats(boolean collectTrainingStats) {
        this.collectTrainingStats = collectTrainingStats;
        if (collectTrainingStats) {
            if (this.stats == null)
                this.stats = new ParameterAveragingTrainingMasterStats.parameterAveragingTrainingMasterStatsHelper();
        } else {
            this.stats = null;
        }
    }

    @Override
    public boolean getIsCollectTrainingStats() {
        return collectTrainingStats;
    }

    @Override
    public SparkTrainingStats getTrainingStats() {
        if (stats != null) return stats.build();
        return null;
    }


    private void processResults(SparkDl4jMultiLayer network, SparkComputationGraph graph, JavaRDD<ParameterAveragingTrainingResult> results, int splitNum, int totalSplits) {
        //Need to do parameter averaging, and where necessary also do averaging of the updaters

        //Let's do all of this in ONE step, such that we don't have extra synchronization costs

        if (collectTrainingStats) stats.logAggregateStartTime();
        ParameterAveragingAggregationTuple tuple = results.aggregate(null,
                new ParameterAveragingElementAddFunction(),
                new ParameterAveragingElementCombineFunction());
        INDArray params = tuple.getParametersSum();
        int aggCount = tuple.getAggregationsCount();
        SparkTrainingStats aggregatedStats = tuple.getSparkTrainingStats();
        if (collectTrainingStats) stats.logAggregationEndTime();


        if (collectTrainingStats) stats.logProcessParamsUpdaterStart();
        params.divi(aggCount);
        if (network != null) {
            MultiLayerNetwork net = network.getNetwork();
            UpdaterAggregator updaterAg = tuple.getUpdaterAggregator();
            Updater updater = (updaterAg != null ? updaterAg.getUpdater() : null);
            net.setParameters(params);
            net.setUpdater(updater);

            network.setScore(tuple.getScoreSum() / tuple.getAggregationsCount());
        } else {
            ComputationGraph g = graph.getNetwork();
            ComputationGraphUpdater.Aggregator updaterAg = tuple.getUpdaterAggregatorGraph();
            ComputationGraphUpdater updater = (updaterAg != null ? updaterAg.getUpdater() : null);
            g.setParams(params);
            g.setUpdater(updater);

            graph.setScore(tuple.getScoreSum() / tuple.getAggregationsCount());
        }

        if (collectTrainingStats) {
            stats.logProcessParamsUpdaterEnd();
            stats.addWorkerStats(aggregatedStats);
        }

        log.info("Completed training of split {} of {}", splitNum, totalSplits);

        if(listeners != null){
            if(network != null){
                MultiLayerNetwork net = network.getNetwork();
                net.setScore(network.getScore());
                for(IterationListener il : listeners){
                    il.iterationDone(net, iterationCount);
                }
            } else {
                ComputationGraph g = graph.getNetwork();
                g.setScore(graph.getScore());
                for(IterationListener il : listeners){
                    il.iterationDone(g, iterationCount);
                }
            }
        }

        iterationCount++;
    }


    public static class Builder {

        private boolean saveUpdater;
        private int numWorkers;
        private int batchSizePerWorker = 16;
        private int averagingFrequency = 5;
        private int prefetchNumBatches = 0;

        /**
         * Create a builder, where the following number of workers (Spark executors) are used.
         * Note: this should match the
         *
         * @param numWorkers Number of Spark executors in the cluster
         */
        public Builder(int numWorkers) {
            this.numWorkers = numWorkers;
        }

        /**
         * Batch size per worker
         *
         * @param batchSizePerWorker Size of each minibatch to use for each worker
         * @return
         */
        public Builder batchSizePerWorker(int batchSizePerWorker) {
            this.batchSizePerWorker = batchSizePerWorker;
            return this;
        }

        /**
         * Frequency with which to average worker parameters.<br>
         * <b>Note: Too high or too low can be bad for different reasons.
* - Too low (such as 1) can result in a lot of network traffic<br> * - Too high (>> 20 or so) can result in accuracy issues or problems with network convergence * * @param averagingFrequency Frequency (in number of minibatches of size 'batchSizePerWorker') to average parameters */ public Builder averagingFrequency(int averagingFrequency) { if (averagingFrequency <= 0) throw new IllegalArgumentException("Ivalid input: averaging frequency must be >= 1"); this.averagingFrequency = averagingFrequency; return this; } /** * Set the number of minibatches to asynchronously prefetch in the worker. * <p> * Default: 0 (no prefetching) * * @param prefetchNumBatches Number of minibatches (DataSets of size batchSizePerWorker) to fetch */ public Builder workerPrefetchNumBatches(int prefetchNumBatches) { this.prefetchNumBatches = prefetchNumBatches; return this; } /** * Set whether the updater (i.e., historical state for momentum, adagrad, etc should be saved). * <b>NOTE: This can double (or more) the amount of network traffic in each direction, but might * improve network training performance (and can be more stable for certain updaters such as adagrad).<br> * <p> * This is <b>enabled by default. * * @param saveUpdater If true: retain the updater state (default). If false, don't retain (updaters will be * reinitalized in each worker after averaging). */ public Builder saveUpdater(boolean saveUpdater) { this.saveUpdater = saveUpdater; return this; } public ParameterAveragingTrainingMaster build() { return new ParameterAveragingTrainingMaster(this); } } }

Other Java examples (source code examples)

Here is a short list of links related to this Java ParameterAveragingTrainingMaster.java source code file:



my book on functional programming

 

new blog posts

 

Copyright 1998-2019 Alvin Alexander, alvinalexander.com
All Rights Reserved.

A percentage of advertising revenue from
pages under the /java/jwarehouse URI on this website is
paid back to open source projects.