|
Java example source code file (TestSparkComputationGraph.java)
The TestSparkComputationGraph.java Java example source codepackage org.deeplearning4j.spark.impl.graph; import org.apache.spark.api.java.JavaPairRDD; import org.apache.spark.api.java.JavaRDD; import org.apache.spark.api.java.JavaSparkContext; import org.canova.api.records.reader.RecordReader; import org.canova.api.records.reader.impl.CSVRecordReader; import org.canova.api.split.FileSplit; import org.deeplearning4j.datasets.canova.RecordReaderMultiDataSetIterator; import org.deeplearning4j.optimize.api.IterationListener; import org.deeplearning4j.optimize.listeners.ScoreIterationListener; import org.deeplearning4j.spark.api.TrainingMaster; import org.deeplearning4j.spark.impl.paramavg.ParameterAveragingTrainingMaster; import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; import org.deeplearning4j.datasets.iterator.impl.IrisDataSetIterator; import org.deeplearning4j.nn.api.OptimizationAlgorithm; import org.deeplearning4j.nn.conf.ComputationGraphConfiguration; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.Updater; import org.deeplearning4j.nn.conf.layers.DenseLayer; import org.deeplearning4j.nn.conf.layers.OutputLayer; import org.deeplearning4j.nn.graph.ComputationGraph; import org.deeplearning4j.spark.BaseSparkTest; import org.junit.Test; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.dataset.DataSet; import org.nd4j.linalg.dataset.api.MultiDataSet; import org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.io.ClassPathResource; import org.nd4j.linalg.lossfunctions.LossFunctions; import scala.Tuple2; import java.util.*; import static org.junit.Assert.assertEquals; public class TestSparkComputationGraph extends BaseSparkTest { @Test public void testBasic() throws Exception { JavaSparkContext sc = this.sc; RecordReader rr = new CSVRecordReader(0, ","); rr.initialize(new FileSplit(new ClassPathResource("iris.txt").getFile())); 
MultiDataSetIterator iter = new RecordReaderMultiDataSetIterator.Builder(1) .addReader("iris", rr) .addInput("iris", 0, 3) .addOutputOneHot("iris", 4, 3) .build(); List<MultiDataSet> list = new ArrayList<>(150); while (iter.hasNext()) list.add(iter.next()); ComputationGraphConfiguration config = new NeuralNetConfiguration.Builder() .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT) .learningRate(0.1) .graphBuilder() .addInputs("in") .addLayer("dense", new DenseLayer.Builder().nIn(4).nOut(2).build(), "in") .addLayer("out", new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT).nIn(2).nOut(3).build(), "dense") .setOutputs("out") .pretrain(false).backprop(true) .build(); ComputationGraph cg = new ComputationGraph(config); cg.init(); TrainingMaster tm = new ParameterAveragingTrainingMaster(true, numExecutors(), 10, 1, 0); SparkComputationGraph scg = new SparkComputationGraph(sc, cg, tm); scg.setListeners(Collections.singleton((IterationListener)new ScoreIterationListener(1))); JavaRDD<MultiDataSet> rdd = sc.parallelize(list); scg.fitMultiDataSet(rdd); //Try: fitting using DataSet DataSetIterator iris = new IrisDataSetIterator(1, 150); List<DataSet> list2 = new ArrayList<>(); while (iris.hasNext()) list2.add(iris.next()); JavaRDD<DataSet> rddDS = sc.parallelize(list2); scg.fit(rddDS); } @Test public void testDistributedScoring() { ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder() .regularization(true).l1(0.1).l2(0.1) .seed(123) .updater(Updater.NESTEROVS) .learningRate(0.1) .momentum(0.9) .graphBuilder() .addInputs("in") .addLayer("0", new org.deeplearning4j.nn.conf.layers.DenseLayer.Builder() .nIn(nIn).nOut(3) .activation("tanh").build(), "in") .addLayer("1", new org.deeplearning4j.nn.conf.layers.OutputLayer.Builder(LossFunctions.LossFunction.MCXENT) .nIn(3).nOut(nOut) .activation("softmax") .build(), "0") .setOutputs("1") .backprop(true) .pretrain(false) .build(); TrainingMaster tm = new 
ParameterAveragingTrainingMaster(true, numExecutors(), 10, 1, 0); SparkComputationGraph sparkNet = new SparkComputationGraph(sc, conf, tm); ComputationGraph netCopy = sparkNet.getNetwork().clone(); int nRows = 100; INDArray features = Nd4j.rand(nRows, nIn); INDArray labels = Nd4j.zeros(nRows, nOut); Random r = new Random(12345); for (int i = 0; i < nRows; i++) { labels.putScalar(new int[]{i, r.nextInt(nOut)}, 1.0); } INDArray localScoresWithReg = netCopy.scoreExamples(new DataSet(features, labels), true); INDArray localScoresNoReg = netCopy.scoreExamples(new DataSet(features, labels), false); List<Tuple2 Other Java examples (source code examples)Here is a short list of links related to this Java TestSparkComputationGraph.java source code file: |
... this post is sponsored by my books ...
#1 New Release!
FP Best Seller
Copyright 1998-2021 Alvin Alexander, alvinalexander.com
All Rights Reserved.
A percentage of advertising revenue from
pages under the /java/jwarehouse
URI on this website is
paid back to open source projects.