|
Java example source code file (SparkDl4jMultiLayer.java)
The SparkDl4jMultiLayer.java Java example source code/* * * * Copyright 2015 Skymind,Inc. * * * * Licensed under the Apache License, Version 2.0 (the "License"); * * you may not use this file except in compliance with the License. * * You may obtain a copy of the License at * * * * http://www.apache.org/licenses/LICENSE-2.0 * * * * Unless required by applicable law or agreed to in writing, software * * distributed under the License is distributed on an "AS IS" BASIS, * * WÏITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * * See the License for the specific language governing permissions and * * limitations under the License. * */ package org.deeplearning4j.spark.impl.multilayer; import lombok.NonNull; import org.apache.spark.SparkContext; import org.apache.spark.api.java.JavaDoubleRDD; import org.apache.spark.api.java.JavaPairRDD; import org.apache.spark.api.java.JavaRDD; import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.broadcast.Broadcast; import org.apache.spark.mllib.linalg.Matrix; import org.apache.spark.mllib.linalg.Vector; import org.apache.spark.mllib.regression.LabeledPoint; import org.apache.spark.rdd.RDD; import org.deeplearning4j.eval.Evaluation; import org.deeplearning4j.nn.conf.MultiLayerConfiguration; import org.deeplearning4j.nn.conf.layers.FeedForwardLayer; import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; import org.deeplearning4j.optimize.api.IterationListener; import org.deeplearning4j.spark.api.TrainingMaster; import org.deeplearning4j.spark.api.stats.SparkTrainingStats; import org.deeplearning4j.spark.impl.multilayer.evaluation.EvaluateFlatMapFunction; import org.deeplearning4j.spark.impl.multilayer.evaluation.EvaluationReduceFunction; import org.deeplearning4j.spark.impl.multilayer.scoring.ScoreExamplesFunction; import org.deeplearning4j.spark.impl.multilayer.scoring.ScoreExamplesWithKeyFunction; import org.deeplearning4j.spark.impl.multilayer.scoring.ScoreFlatMapFunction; import 
org.deeplearning4j.spark.util.MLLibUtil; import org.deeplearning4j.util.ModelSerializer; import org.nd4j.linalg.dataset.DataSet; import org.nd4j.linalg.heartbeat.Heartbeat; import org.nd4j.linalg.heartbeat.reports.Environment; import org.nd4j.linalg.heartbeat.reports.Event; import org.nd4j.linalg.heartbeat.reports.Task; import org.nd4j.linalg.heartbeat.utils.EnvironmentUtils; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import java.io.Serializable; import java.util.ArrayList; import java.util.Collection; import java.util.List; /** * Master class for spark * * @author Adam Gibson, Alex Black */ public class SparkDl4jMultiLayer implements Serializable { private static final Logger log = LoggerFactory.getLogger(SparkDl4jMultiLayer.class); public static final int DEFAULT_EVAL_SCORE_BATCH_SIZE = 64; private transient JavaSparkContext sc; private TrainingMaster trainingMaster; private MultiLayerConfiguration conf; private MultiLayerNetwork network; private double lastScore; private List<IterationListener> listeners = new ArrayList<>(); /** * Instantiate a multi layer spark instance * with the given context and network. * This is the prediction constructor * @param sparkContext the spark context to use * @param network the network to use */ public SparkDl4jMultiLayer(SparkContext sparkContext, MultiLayerNetwork network, TrainingMaster trainingMaster) { this(new JavaSparkContext(sparkContext),network, trainingMaster); } /** * Training constructor. Instantiate with a configuration * @param sparkContext the spark context to use * @param conf the configuration of the network */ public SparkDl4jMultiLayer(SparkContext sparkContext, MultiLayerConfiguration conf, TrainingMaster trainingMaster) { this(new JavaSparkContext(sparkContext), initNetwork(conf), trainingMaster); } /** * Training constructor. 
Instantiate with a configuration * @param sc the spark context to use * @param conf the configuration of the network */ public SparkDl4jMultiLayer(JavaSparkContext sc, MultiLayerConfiguration conf, TrainingMaster trainingMaster) { this(sc.sc(),conf, trainingMaster); } public SparkDl4jMultiLayer(JavaSparkContext javaSparkContext, MultiLayerNetwork network, TrainingMaster trainingMaster){ sc = javaSparkContext; this.conf = network.getLayerWiseConfigurations().clone(); this.network = network; if(!network.isInitCalled()) network.init(); this.trainingMaster = trainingMaster; } private static MultiLayerNetwork initNetwork(MultiLayerConfiguration conf){ MultiLayerNetwork net = new MultiLayerNetwork(conf); net.init(); return net; } public JavaSparkContext getSparkContext(){ return sc; } /** * @return The MultiLayerNetwork underlying the SparkDl4jMultiLayer */ public MultiLayerNetwork getNetwork() { return network; } /** * Set the network that underlies this SparkDl4jMultiLayer instacne * @param network network to set */ public void setNetwork(MultiLayerNetwork network) { this.network = network; } /** * Set whether training statistics should be collected for debugging purposes. Statistics collection is disabled by default * * @param collectTrainingStats If true: collect training statistics. If false: don't collect. 
*/
    public void setCollectTrainingStats(boolean collectTrainingStats) {
        this.trainingMaster.setCollectTrainingStats(collectTrainingStats);
    }

    /**
     * Retrieve the statistics gathered during training. Collection must first be switched on via
     * {@link #setCollectTrainingStats(boolean)}.
     *
     * @return the training statistics reported by the TrainingMaster
     */
    public SparkTrainingStats getSparkTrainingStats() {
        SparkTrainingStats stats = this.trainingMaster.getTrainingStats();
        return stats;
    }

    /**
     * Run a local feature matrix through the underlying network.
     *
     * @param features the feature matrix to run through the network
     * @return the network output, converted back to an MLlib Matrix
     */
    public Matrix predict(Matrix features) {
        Matrix predictions = MLLibUtil.toMatrix(this.network.output(MLLibUtil.toMatrix(features)));
        return predictions;
    }

    /**
     * Run a single local feature vector through the underlying network.
     *
     * @param point the feature vector to run through the network
     * @return the network output, converted back to an MLlib Vector
     */
    public Vector predict(Vector point) {
        Vector prediction = MLLibUtil.toVector(this.network.output(MLLibUtil.toVector(point)));
        return prediction;
    }

    /**
     * Fit the network on a Scala RDD of DataSets by converting it to a JavaRDD and delegating to
     * {@link #fit(JavaRDD)}.
     *
     * @param trainingData the training data RDD to fit
     * @return the MultiLayerNetwork after training
     */
    public MultiLayerNetwork fit(RDD<DataSet> trainingData) {
        JavaRDD<DataSet> javaTrainingData = trainingData.toJavaRDD();
        return fit(javaTrainingData);
    }

    /**
     * Fit the network on a JavaRDD of DataSets. The actual training is performed by the TrainingMaster.
     *
     * @param trainingData the training data RDD to fit
     * @return the MultiLayerNetwork held by this instance once training has finished
     */
    public MultiLayerNetwork fit(JavaRDD<DataSet> trainingData) {
        this.trainingMaster.executeTraining(this, trainingData);
        return this.network;
    }

    /**
     * Fit a MultiLayerNetwork using Spark MLLib LabeledPoint instances.
* This will convert the labeled points to the internal DL4J data format and train the model on that * @param rdd the rdd to fitDataSet * @return the multi layer network that was fitDataSet */ public MultiLayerNetwork fitLabeledPoint(JavaRDD<LabeledPoint> rdd) { int nLayers = network.getLayerWiseConfigurations().getConfs().size(); FeedForwardLayer ffl = (FeedForwardLayer)network.getLayerWiseConfigurations().getConf(nLayers-1).getLayer(); JavaRDD<DataSet> ds = MLLibUtil.fromLabeledPoint(sc, rdd, ffl.getNOut()); return fit(ds); } /** * This method allows you to specify IterationListeners for this model. * * PLEASE NOTE: * 1. These iteration listeners should be configured to use remote UiServer * 2. Remote UiServer should be accessible via network from Spark master node. * * @param listeners */ public void setListeners(@NonNull Collection<IterationListener> listeners) { this.listeners.clear(); this.listeners.addAll(listeners); if(trainingMaster != null) trainingMaster.setListeners(this.listeners); } protected void invokeListeners(MultiLayerNetwork network, int iteration) { for (IterationListener listener: listeners) { try { listener.iterationDone(network, iteration); } catch (Exception e) { log.error("Exception caught at IterationListener invocation" + e.getMessage()); e.printStackTrace(); } } } /** Gets the last (average) minibatch score from calling fit. This is the average score across all executors for the * last minibatch executed in each worker */ public double getScore(){ return lastScore; } public void setScore(double lastScore){ this.lastScore = lastScore; } /** * Overload of {@link #calculateScore(JavaRDD, boolean)} for {@code RDD<DataSet>} instead of {@code JavaRDD Other Java examples (source code examples)Here is a short list of links related to this Java SparkDl4jMultiLayer.java source code file: |
... this post is sponsored by my books ... | |
#1 New Release! |
FP Best Seller |
Copyright 1998-2021 Alvin Alexander, alvinalexander.com
All Rights Reserved.
A percentage of advertising revenue from
pages under the /java/jwarehouse
URI on this website is
paid back to open source projects.