|
Java example source code file (TrainingWorker.java)
The TrainingWorker.java Java example source codepackage org.deeplearning4j.spark.api; import org.deeplearning4j.berkeley.Pair; import org.deeplearning4j.nn.graph.ComputationGraph; import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; import org.deeplearning4j.spark.api.stats.SparkTrainingStats; import org.nd4j.linalg.dataset.api.DataSet; import org.nd4j.linalg.dataset.api.MultiDataSet; import java.io.Serializable; /** * TrainingWorker is a small serializable class that can be passed (in serialized form) to each Spark executor * for actually conducting training. The results are then passed back to the {@link TrainingMaster} for processing.<br> * <p> * TrainingWorker implementations provide a layer of abstraction for network learning tha should allow for more flexibility/ * control over how learning is conducted (including for example asynchronous communication) * * @author Alex Black */ public interface TrainingWorker<R extends TrainingResult> extends Serializable { /** * Get the initial model when training a MultiLayerNetwork/SparkDl4jMultiLayer * * @return Initial model for this worker */ MultiLayerNetwork getInitialModel(); /** * Get the initial model when training a ComputationGraph/SparkComputationGraph * * @return Initial model for this worker */ ComputationGraph getInitialModelGraph(); /** * Process (fit) a minibatch for a MultiLayerNetwork * * @param dataSet Data set to train on * @param network Network to train * @param isLast If true: last data set currently available. If false: more data sets will be processed for this executor * @return Null, or a training result if training should be terminated immediately. */ R processMinibatch(DataSet dataSet, MultiLayerNetwork network, boolean isLast); /** * Process (fit) a minibatch for a ComputationGraph * * @param dataSet Data set to train on * @param graph Network to train * @param isLast If true: last data set currently available. 
If false: more data sets will be processed for this executor * @return Null, or a training result if training should be terminated immediately. */ R processMinibatch(DataSet dataSet, ComputationGraph graph, boolean isLast); /** * Process (fit) a minibatch for a ComputationGraph using a MultiDataSet * * @param dataSet Data set to train on * @param graph Network to train * @param isLast If true: last data set currently available. If false: more data sets will be processed for this executor * @return Null, or a training result if training should be terminated immediately. */ R processMinibatch(MultiDataSet dataSet, ComputationGraph graph, boolean isLast); /** * As per {@link #processMinibatch(DataSet, MultiLayerNetwork, boolean)} but used when {@link SparkTrainingStats} are being collecte */ Pair<R, SparkTrainingStats> processMinibatchWithStats(DataSet dataSet, MultiLayerNetwork network, boolean isLast); /** * As per {@link #processMinibatch(DataSet, ComputationGraph, boolean)} but used when {@link SparkTrainingStats} are being collecte */ Pair<R, SparkTrainingStats> processMinibatchWithStats(DataSet dataSet, ComputationGraph graph, boolean isLast); /** * As per {@link #processMinibatch(MultiDataSet, ComputationGraph, boolean)} but used when {@link SparkTrainingStats} are being collecte */ Pair<R, SparkTrainingStats> processMinibatchWithStats(MultiDataSet dataSet, ComputationGraph graph, boolean isLast); /** * Get the final result to be returned to the driver * * @param network Current state of the network * @return Result to return to the driver */ R getFinalResult(MultiLayerNetwork network); /** * Get the final result to be returned to the driver * * @param graph Current state of the network * @return Result to return to the driver */ R getFinalResult(ComputationGraph graph); /** * Get the final result to be returned to the driver, if no data was available for this executor * * @return Result to return to the driver */ R getFinalResultNoData(); /** * As per {@link 
#getFinalResultNoData()} but used when {@link SparkTrainingStats} are being collected */ Pair<R, SparkTrainingStats> getFinalResultNoDataWithStats(); /** * As per {@link #getFinalResult(MultiLayerNetwork)} but used when {@link SparkTrainingStats} are being collected */ Pair<R, SparkTrainingStats> getFinalResultWithStats(MultiLayerNetwork network); /** * As per {@link #getFinalResult(ComputationGraph)} but used when {@link SparkTrainingStats} are being collected */ Pair<R, SparkTrainingStats> getFinalResultWithStats(ComputationGraph graph); /** * Get the {@link WorkerConfiguration} that contains information such as minibatch sizes, etc * * @return Worker configuration */ WorkerConfiguration getDataConfiguration(); } Other Java examples (source code examples)Here is a short list of links related to this Java TrainingWorker.java source code file: |
... this post is sponsored by my books ...
#1 New Release!
FP Best Seller
Copyright 1998-2021 Alvin Alexander, alvinalexander.com
All Rights Reserved.
A percentage of advertising revenue from
pages under the /java/jwarehouse
URI on this website is
paid back to open source projects.