
Java example source code file (TrainingWorker.java)

This example Java source code file (TrainingWorker.java) is included in the alvinalexander.com "Java Source Code Warehouse" project, whose intent is to help you "Learn Java by Example"™.

Learn more about this Java project at its project page.

Java tags/keywords

computationgraph, multilayernetwork, pair, serializable, sparktrainingstats, trainingresult, trainingworker, workerconfiguration

The TrainingWorker.java Java example source code

package org.deeplearning4j.spark.api;

import org.deeplearning4j.berkeley.Pair;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.spark.api.stats.SparkTrainingStats;
import org.nd4j.linalg.dataset.api.DataSet;
import org.nd4j.linalg.dataset.api.MultiDataSet;

import java.io.Serializable;

/**
 * TrainingWorker is a small serializable class that can be passed (in serialized form) to each Spark executor
 * for actually conducting training. The results are then passed back to the {@link TrainingMaster} for processing.<br>
 * <p>
 * TrainingWorker implementations provide a layer of abstraction for network learning that should allow for more flexibility/
 * control over how learning is conducted (including, for example, asynchronous communication).
 *
 * @author Alex Black
 */
public interface TrainingWorker<R extends TrainingResult> extends Serializable {

    /**
     * Get the initial model when training a MultiLayerNetwork/SparkDl4jMultiLayer
     *
     * @return Initial model for this worker
     */
    MultiLayerNetwork getInitialModel();

    /**
     * Get the initial model when training a ComputationGraph/SparkComputationGraph
     *
     * @return Initial model for this worker
     */
    ComputationGraph getInitialModelGraph();

    /**
     * Process (fit) a minibatch for a MultiLayerNetwork
     *
     * @param dataSet Data set to train on
     * @param network Network to train
     * @param isLast  If true: last data set currently available. If false: more data sets will be processed for this executor
     * @return Null, or a training result if training should be terminated immediately.
     */
    R processMinibatch(DataSet dataSet, MultiLayerNetwork network, boolean isLast);

    /**
     * Process (fit) a minibatch for a ComputationGraph
     *
     * @param dataSet Data set to train on
     * @param graph   Network to train
     * @param isLast  If true: last data set currently available. If false: more data sets will be processed for this executor
     * @return Null, or a training result if training should be terminated immediately.
     */
    R processMinibatch(DataSet dataSet, ComputationGraph graph, boolean isLast);

    /**
     * Process (fit) a minibatch for a ComputationGraph using a MultiDataSet
     *
     * @param dataSet Data set to train on
     * @param graph   Network to train
     * @param isLast  If true: last data set currently available. If false: more data sets will be processed for this executor
     * @return Null, or a training result if training should be terminated immediately.
     */
    R processMinibatch(MultiDataSet dataSet, ComputationGraph graph, boolean isLast);

    /**
     * As per {@link #processMinibatch(DataSet, MultiLayerNetwork, boolean)} but used when {@link SparkTrainingStats} are being collected
     */
    Pair<R, SparkTrainingStats> processMinibatchWithStats(DataSet dataSet, MultiLayerNetwork network, boolean isLast);

    /**
     * As per {@link #processMinibatch(DataSet, ComputationGraph, boolean)} but used when {@link SparkTrainingStats} are being collected
     */
    Pair<R, SparkTrainingStats> processMinibatchWithStats(DataSet dataSet, ComputationGraph graph, boolean isLast);

    /**
     * As per {@link #processMinibatch(MultiDataSet, ComputationGraph, boolean)} but used when {@link SparkTrainingStats} are being collected
     */
    Pair<R, SparkTrainingStats> processMinibatchWithStats(MultiDataSet dataSet, ComputationGraph graph, boolean isLast);

    /**
     * Get the final result to be returned to the driver
     *
     * @param network Current state of the network
     * @return Result to return to the driver
     */
    R getFinalResult(MultiLayerNetwork network);

    /**
     * Get the final result to be returned to the driver
     *
     * @param graph Current state of the network
     * @return Result to return to the driver
     */
    R getFinalResult(ComputationGraph graph);

    /**
     * Get the final result to be returned to the driver, if no data was available for this executor
     *
     * @return Result to return to the driver
     */
    R getFinalResultNoData();

    /**
     * As per {@link #getFinalResultNoData()} but used when {@link SparkTrainingStats} are being collected
     */
    Pair<R, SparkTrainingStats> getFinalResultNoDataWithStats();

    /**
     * As per {@link #getFinalResult(MultiLayerNetwork)} but used when {@link SparkTrainingStats} are being collected
     */
    Pair<R, SparkTrainingStats> getFinalResultWithStats(MultiLayerNetwork network);

    /**
     * As per {@link #getFinalResult(ComputationGraph)} but used when {@link SparkTrainingStats} are being collected
     */
    Pair<R, SparkTrainingStats> getFinalResultWithStats(ComputationGraph graph);

    /**
     * Get the {@link WorkerConfiguration} that contains information such as minibatch sizes, etc
     *
     * @return Worker configuration
     */
    WorkerConfiguration getDataConfiguration();
}
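
To make the contract above concrete, here is a small, hypothetical sketch of how an executor-side loop might drive a TrainingWorker when fitting a MultiLayerNetwork: each minibatch goes through processMinibatch, a non-null return value is treated as an early-termination signal, and getFinalResultNoData() covers the case where no data reached the executor. The class and method names (ExecutorFitLoop, fitPartition) are illustrative only and are not part of the DL4J API.

import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.spark.api.TrainingResult;
import org.deeplearning4j.spark.api.TrainingWorker;
import org.nd4j.linalg.dataset.api.DataSet;

import java.util.Iterator;

/**
 * Hypothetical sketch (not part of DL4J): one possible way an executor-side
 * loop could consume the TrainingWorker interface for a MultiLayerNetwork.
 */
public class ExecutorFitLoop {

    /** Feed every DataSet in this partition to the worker and return its result. */
    public static <R extends TrainingResult> R fitPartition(TrainingWorker<R> worker,
                                                            Iterator<DataSet> partition) {
        if (!partition.hasNext()) {
            // No data reached this executor; the worker still reports a result to the master
            return worker.getFinalResultNoData();
        }

        MultiLayerNetwork network = worker.getInitialModel();
        while (partition.hasNext()) {
            DataSet dataSet = partition.next();
            boolean isLast = !partition.hasNext();
            // Per the Javadoc, a non-null result means training should terminate immediately
            R earlyResult = worker.processMinibatch(dataSet, network, isLast);
            if (earlyResult != null) {
                return earlyResult;
            }
        }
        return worker.getFinalResult(network);
    }
}

The same pattern would apply to the ComputationGraph and WithStats variants, with the Pair<R, SparkTrainingStats> return value carrying the collected SparkTrainingStats alongside the training result.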
