home | career | drupal | java | mac | mysql | perl | scala | uml | unix  

Java example source code file (TestTrainingStatsCollection.java)

This example Java source code file (TestTrainingStatsCollection.java) is included in the alvinalexander.com "Java Source Code Warehouse" project. The intent of this project is to help you "Learn Java by Example"™.

Learn more about this Java project at its project page.

Java - Java tags/keywords

arraylist, class, collection, commonsparktrainingstats, dataset, indarray, object, parameteraveragingtrainingmasterstats, parameteraveragingtrainingworkerstats, reflection, sparkconf, sparkdl4jmultilayer, sparktrainingstats, string, test, util

The TestTrainingStatsCollection.java Java example source code

package org.deeplearning4j.spark.impl.stats;

import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.deeplearning4j.nn.api.OptimizationAlgorithm;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.spark.api.stats.CommonSparkTrainingStats;
import org.deeplearning4j.spark.api.stats.SparkTrainingStats;
import org.deeplearning4j.spark.impl.multilayer.SparkDl4jMultiLayer;
import org.deeplearning4j.spark.impl.paramavg.ParameterAveragingTrainingMaster;
import org.deeplearning4j.spark.impl.paramavg.stats.ParameterAveragingTrainingMasterStats;
import org.deeplearning4j.spark.impl.paramavg.stats.ParameterAveragingTrainingWorkerStats;
import org.junit.Test;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.factory.Nd4j;

import java.lang.reflect.Field;
import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
import java.util.Set;

import static junit.framework.TestCase.assertNotNull;
import static junit.framework.TestCase.assertTrue;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNull;

/**
 * End-to-end test of Spark training-statistics collection for {@link SparkDl4jMultiLayer}
 * when trained with a {@link ParameterAveragingTrainingMaster}.
 * <p>
 * The test fits a small network on random data with stats collection enabled. It then checks that:
 * <ul>
 *     <li>the reported stat keys match the union of the {@code columnNames} fields declared by
 *     {@link CommonSparkTrainingStats}, {@link ParameterAveragingTrainingMasterStats} and
 *     {@link ParameterAveragingTrainingWorkerStats};</li>
 *     <li>{@link SparkTrainingStats#statsAsString()} emits one line per stat;</li>
 *     <li>each nested stats object (master, then common, then worker) reports the expected number
 *     of entries, with non-negative (or, where noted, strictly positive) values.</li>
 * </ul>
 * <p>
 * Created by Alex on 17/06/2016.
 */
public class TestTrainingStatsCollection {

    /** Name of the static field, present on each stats class, that lists its stat keys. */
    private static final String COLUMN_NAMES_FIELD = "columnNames";

    @Test
    public void testStatsCollection() throws Exception {

        int nWorkers = 4;

        SparkConf sparkConf = new SparkConf();
        sparkConf.setMaster("local[" + nWorkers + "]");
        sparkConf.setAppName("Test");

        JavaSparkContext sc = new JavaSparkContext(sparkConf);

        try {

            MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
                    .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
                    .iterations(1)
                    .list()
                    .layer(0, new DenseLayer.Builder().nIn(10).nOut(10).build())
                    .layer(1, new OutputLayer.Builder().nIn(10).nOut(10).build())
                    .pretrain(false).backprop(true)
                    .build();

            int miniBatchSizePerWorker = 10;
            int averagingFrequency = 5;
            int numberOfAveragings = 3;

            int totalExamples = nWorkers * miniBatchSizePerWorker * averagingFrequency * numberOfAveragings;

            Nd4j.getRandom().setSeed(12345);
            List<DataSet> list = new ArrayList<>(totalExamples);
            for (int i = 0; i < totalExamples; i++) {
                INDArray f = Nd4j.rand(1, 10);
                INDArray l = Nd4j.rand(1, 10);
                list.add(new DataSet(f, l));
            }

            JavaRDD<DataSet> rdd = sc.parallelize(list);
            // RDDs are immutable: repartition() returns a NEW RDD, so the result must be kept,
            // otherwise the call has no effect on what is passed to fit().
            rdd = rdd.repartition(nWorkers);

            ParameterAveragingTrainingMaster tm = new ParameterAveragingTrainingMaster.Builder(nWorkers)
                    .averagingFrequency(averagingFrequency)
                    .batchSizePerWorker(miniBatchSizePerWorker)
                    .saveUpdater(true)
                    .workerPrefetchNumBatches(0)
                    .build();

            SparkDl4jMultiLayer sparkNet = new SparkDl4jMultiLayer(sc, conf, tm);
            sparkNet.setCollectTrainingStats(true);
            sparkNet.fit(rdd);


            //Collect the expected keys:
            List<String> expectedStatNames = collectExpectedStatNames();
            System.out.println(expectedStatNames);


            SparkTrainingStats stats = sparkNet.getSparkTrainingStats();
            assertNotNull("Training stats should have been collected", stats);
            Set<String> actualKeySet = stats.getKeySet();
            assertEquals("Number of stat keys", expectedStatNames.size(), actualKeySet.size());
            for (String s : actualKeySet) {
                assertTrue("Unexpected stat key: " + s, expectedStatNames.contains(s));
                // Smoke check: retrieving any advertised key must not throw.
                stats.getValue(s);
            }

            String statsAsString = stats.statsAsString();
            System.out.println(statsAsString);
            assertEquals("Expected one line per stat", actualKeySet.size(), statsAsString.split("\n").length);


            //Go through nested stats
            //First: master stats
            assertTrue("Expected ParameterAveragingTrainingMasterStats but got " + stats.getClass().getName(),
                    stats instanceof ParameterAveragingTrainingMasterStats);
            ParameterAveragingTrainingMasterStats masterStats = (ParameterAveragingTrainingMasterStats) stats;
            int[] broadcastCreateTime = masterStats.getParameterAveragingMasterBroadcastCreateTimesMs();
            assertEquals(numberOfAveragings, broadcastCreateTime.length);
            assertGreaterEqZero(broadcastCreateTime);

            int[] fitTimes = masterStats.getParameterAveragingMasterFitTimesMs();
            assertEquals(1, fitTimes.length);   //i.e., number of times fit(JavaRDD<DataSet>) was called
            assertGreaterZero(fitTimes);

            int[] splitTimes = masterStats.getParameterAveragingMasterSplitTimesMs();
            assertEquals(1, splitTimes.length);     //Splitting of the data set is executed once only (i.e., one fit(JavaRDD<DataSet>) call)
            assertGreaterEqZero(splitTimes);

            // NOTE: "Paramater" spelling is the project API's method name; do not "fix" it here.
            int[] aggregateTimesMs = masterStats.getParamaterAveragingMasterAggregateTimesMs();
            assertEquals(numberOfAveragings, aggregateTimesMs.length);
            assertGreaterEqZero(aggregateTimesMs);

            int[] processParamsTimesMs = masterStats.getParameterAveragingMasterProcessParamsUpdaterTimesMs();
            assertEquals(numberOfAveragings, processParamsTimesMs.length);
            assertGreaterEqZero(processParamsTimesMs);

            //Second: Common spark training stats
            SparkTrainingStats commonStats = masterStats.getNestedTrainingStats();
            assertNotNull("Master stats should nest common stats", commonStats);
            assertTrue("Expected CommonSparkTrainingStats but got " + commonStats.getClass().getName(),
                    commonStats instanceof CommonSparkTrainingStats);
            CommonSparkTrainingStats cStats = (CommonSparkTrainingStats) commonStats;
            int expectedWorkerFlatMapCalls = numberOfAveragings * nWorkers;

            int[] workerFlatMapTotalTimeMs = cStats.getWorkerFlatMapTotalTimeMs();
            assertEquals(expectedWorkerFlatMapCalls, workerFlatMapTotalTimeMs.length);
            assertGreaterZero(workerFlatMapTotalTimeMs);

            int[] workerFlatMapTotalExampleCount = cStats.getWorkerFlatMapTotalExampleCount();
            assertEquals(expectedWorkerFlatMapCalls, workerFlatMapTotalExampleCount.length);
            assertGreaterZero(workerFlatMapTotalExampleCount);

            int[] workerFlatMapGetInitialModelTimeMs = cStats.getWorkerFlatMapGetInitialModelTimeMs();
            assertEquals(expectedWorkerFlatMapCalls, workerFlatMapGetInitialModelTimeMs.length);
            assertGreaterEqZero(workerFlatMapGetInitialModelTimeMs);

            int[] workerFlatMapDataSetGetTimesMs = cStats.getWorkerFlatMapDataSetGetTimesMs();
            int numMinibatchesProcessed = workerFlatMapDataSetGetTimesMs.length;
            int expectedNumMinibatchesProcessed = numberOfAveragings * nWorkers * averagingFrequency;   //1 for every time we get a data set

            //Sometimes random split is just bad - some executors might miss out on getting the expected amount of data,
            //so allow a small shortfall in the number of minibatches processed
            int minibatchShortfallTolerance = 5;
            int minExpectedMinibatches = expectedNumMinibatchesProcessed - minibatchShortfallTolerance;
            assertTrue("Processed " + numMinibatchesProcessed + " minibatches, expected at least " + minExpectedMinibatches,
                    numMinibatchesProcessed >= minExpectedMinibatches);

            int workerFlatMapCountNoDataInstances = cStats.getWorkerFlatMapCountNoDataInstances();
            if (numMinibatchesProcessed == expectedNumMinibatchesProcessed) {
                assertEquals(0, workerFlatMapCountNoDataInstances);
            }
            assertGreaterEqZero(workerFlatMapDataSetGetTimesMs);

            int[] workerFlatMapProcessMiniBatchTimesMs = cStats.getWorkerFlatMapProcessMiniBatchTimesMs();
            assertTrue("Too few minibatch process times: " + workerFlatMapProcessMiniBatchTimesMs.length,
                    workerFlatMapProcessMiniBatchTimesMs.length >= minExpectedMinibatches);
            assertGreaterEqZero(workerFlatMapProcessMiniBatchTimesMs);

            //Third: ParameterAveragingTrainingWorker stats
            SparkTrainingStats paramAvgStats = cStats.getNestedTrainingStats();
            assertNotNull("Common stats should nest worker stats", paramAvgStats);
            assertTrue("Expected ParameterAveragingTrainingWorkerStats but got " + paramAvgStats.getClass().getName(),
                    paramAvgStats instanceof ParameterAveragingTrainingWorkerStats);

            ParameterAveragingTrainingWorkerStats pStats = (ParameterAveragingTrainingWorkerStats) paramAvgStats;
            int[] parameterAveragingWorkerBroadcastGetValueTimeMs = pStats.getParameterAveragingWorkerBroadcastGetValueTimeMs();
            assertEquals(expectedWorkerFlatMapCalls, parameterAveragingWorkerBroadcastGetValueTimeMs.length);
            assertGreaterEqZero(parameterAveragingWorkerBroadcastGetValueTimeMs);

            int[] parameterAveragingWorkerInitTimeMs = pStats.getParameterAveragingWorkerInitTimeMs();
            assertEquals(expectedWorkerFlatMapCalls, parameterAveragingWorkerInitTimeMs.length);
            assertGreaterEqZero(parameterAveragingWorkerInitTimeMs);

            int[] parameterAveragingWorkerFitTimesMs = pStats.getParameterAveragingWorkerFitTimesMs();
            assertTrue("Too few worker fit times: " + parameterAveragingWorkerFitTimesMs.length,
                    parameterAveragingWorkerFitTimesMs.length >= minExpectedMinibatches);
            assertGreaterEqZero(parameterAveragingWorkerFitTimesMs);

            // The worker stats are the innermost level: nothing further is nested.
            assertNull(pStats.getNestedTrainingStats());
        } finally {
            sc.stop();
        }
    }

    /**
     * Reads, via reflection, the static {@code columnNames} field of each stats class whose keys
     * are expected in the collected stats, and concatenates their contents.
     *
     * @return the expected stat names, in class order
     * @throws ReflectiveOperationException if a class lacks the field or it cannot be read
     */
    private static List<String> collectExpectedStatNames() throws ReflectiveOperationException {
        Class<?>[] statsClasses = {
                CommonSparkTrainingStats.class,
                ParameterAveragingTrainingMasterStats.class,
                ParameterAveragingTrainingWorkerStats.class};
        List<String> names = new ArrayList<>();
        for (Class<?> statsClass : statsClasses) {
            Field field = statsClass.getDeclaredField(COLUMN_NAMES_FIELD);
            field.setAccessible(true);
            // Static field, so the receiver is null. Assumes the field holds a Collection<String>
            // (as the original raw cast also required).
            @SuppressWarnings("unchecked")
            Collection<String> columnNames = (Collection<String>) field.get(null);
            names.addAll(columnNames);
        }
        return names;
    }

    /** Asserts every element of {@code array} is {@code >= 0}, reporting the offending index on failure. */
    private static void assertGreaterEqZero(int[] array) {
        for (int i = 0; i < array.length; i++) {
            assertTrue("Expected value >= 0 at index " + i + " but was " + array[i], array[i] >= 0);
        }
    }

    /** Asserts every element of {@code array} is {@code > 0}, reporting the offending index on failure. */
    private static void assertGreaterZero(int[] array) {
        for (int i = 0; i < array.length; i++) {
            assertTrue("Expected value > 0 at index " + i + " but was " + array[i], array[i] > 0);
        }
    }
}

Other Java examples (source code examples)

Here is a short list of links related to this Java TestTrainingStatsCollection.java source code file:



my book on functional programming

 

new blog posts

 

Copyright 1998-2019 Alvin Alexander, alvinalexander.com
All Rights Reserved.

A percentage of advertising revenue from
pages under the /java/jwarehouse URI on this website is
paid back to open source projects.