
Java example source code file (TestCanovaDataSetFunctions.java)

This example Java source code file (TestCanovaDataSetFunctions.java) is included in the alvinalexander.com "Java Source Code Warehouse" project. The intent of this project is to help you "Learn Java by Example"™.

Learn more about this Java project at its project page.

Java tags/keywords

canovasequencepairdatasetfunction, csvsequencerecordreader, dataset, expect, file, indarray, javapairrdd, javardd, list, numberedfileinputsplit, sequencerecordreader, sequencerecordreaderdatasetiterator, spark, string, util

The TestCanovaDataSetFunctions.java Java example source code

package org.deeplearning4j.spark.canova;

import org.apache.hadoop.io.Text;
import org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.input.PortableDataStream;
import org.canova.api.records.reader.RecordReader;
import org.canova.api.records.reader.SequenceRecordReader;
import org.canova.api.records.reader.impl.CSVSequenceRecordReader;
import org.canova.api.split.FileSplit;
import org.canova.api.split.InputSplit;
import org.canova.api.split.NumberedFileInputSplit;
import org.canova.api.util.ClassPathResource;
import org.canova.api.writable.Writable;
import org.canova.image.recordreader.ImageRecordReader;
import org.canova.spark.functions.SequenceRecordReaderFunction;
import org.canova.spark.functions.pairdata.*;
import org.canova.spark.util.CanovaSparkUtil;
import org.deeplearning4j.datasets.canova.RecordReaderDataSetIterator;
import org.deeplearning4j.datasets.canova.SequenceRecordReaderDataSetIterator;
import org.deeplearning4j.spark.BaseSparkTest;
import org.junit.Test;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.MultiDataSet;
import scala.Tuple2;

import java.io.File;
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.List;

import static org.junit.Assert.*;

public class TestCanovaDataSetFunctions extends BaseSparkTest {

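    /**
     * Reads BMP images via Spark (sc.binaryFiles + RecordReaderFunction), converts each record to a
     * DataSet with CanovaDataSetFunction, then loads the same images locally with
     * RecordReaderDataSetIterator and checks that both paths produce the same DataSets,
     * ignoring order.
     */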
    @Test
    public void testCanovaDataSetFunction() throws Exception {
        JavaSparkContext sc = getContext();
        //Test Spark record reader functionality vs. local
        File f = new File("src/test/resources/imagetest/0/a.bmp");
        List<String> labelsList = Arrays.asList("0", "1");   //Need this for Spark: can't infer without init call

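        //Strip the trailing "0/a.bmp" (7 chars) to get the base directory, then glob everything under it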
        String path = f.getPath();
        String folder = path.substring(0, path.length() - 7);
        path = folder + "*";

        JavaPairRDD<String,PortableDataStream> origData = sc.binaryFiles(path);
        assertEquals(4,origData.count());    //4 images

        RecordReader rr = new ImageRecordReader(28,28,1,true,labelsList);
        org.canova.spark.functions.RecordReaderFunction rrf = new org.canova.spark.functions.RecordReaderFunction(rr);
        JavaRDD<Collection<Writable>> rdd = origData.map(rrf);
        JavaRDD<DataSet> data = rdd.map(new CanovaDataSetFunction(1,2,false));
        List<DataSet> collected = data.collect();

        //Load normally (i.e., not via Spark), and check that we get the same results (order notwithstanding)
        InputSplit is = new FileSplit(new File(folder),new String[]{"bmp"}, true);
        ImageRecordReader irr = new ImageRecordReader(28,28,1,true);
        irr.initialize(is);

        RecordReaderDataSetIterator iter = new RecordReaderDataSetIterator(irr,1,1,2);
        List<DataSet> listLocal = new ArrayList<>(4);
        while(iter.hasNext()){
            listLocal.add(iter.next());
        }


        //Compare:
        assertEquals(4,collected.size());
        assertEquals(4,listLocal.size());

        //Check that results are the same (order notwithstanding)
        boolean[] found = new boolean[4];
        for (int i = 0; i < 4; i++) {
            int foundIndex = -1;
            DataSet ds = collected.get(i);
            for (int j = 0; j < 4; j++) {
                if (ds.equals(listLocal.get(j))) {
                    if (foundIndex != -1)
                        fail();    //Already found this value -> this Spark value equals two or more of the local values (shouldn't happen)
                    foundIndex = j;
                    if (found[foundIndex])
                        fail();   //One of the other spark values was equal to this one -> suggests duplicates in Spark list
                    found[foundIndex] = true;   //mark this one as seen before
                }
            }
        }
        int count = 0;
        for (boolean b : found) if (b) count++;
        assertEquals(4, count);  //Expect all 4 and exactly 4 pairwise matches between spark and local versions
    }

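    /**
     * Same comparison for sequence data: reads CSV sequences via Spark with
     * SequenceRecordReaderFunction, converts each to a DataSet with CanovaSequenceDataSetFunction
     * (regression mode), and compares against a local SequenceRecordReaderDataSetIterator.
     */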
    @Test
    public void testCanovaSequenceDataSetFunction() throws Exception {
        JavaSparkContext sc = getContext();
        //Test Spark record reader functionality vs. local

        File f = new File("src/test/resources/csvsequence/csvsequence_0.txt");
        String path = f.getPath();
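        //Strip the file name "csvsequence_0.txt" (17 chars) to get its folder, then glob all files in it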
        String folder = path.substring(0, path.length() - 17);
        path = folder + "*";

        JavaPairRDD<String,PortableDataStream> origData = sc.binaryFiles(path);
        assertEquals(3, origData.count());    //3 CSV sequences


        SequenceRecordReader seqRR = new CSVSequenceRecordReader(1,",");
        SequenceRecordReaderFunction rrf = new SequenceRecordReaderFunction(seqRR);
        JavaRDD<Collection<Collection<Writable>>> rdd = origData.map(rrf);
        JavaRDD<DataSet> data = rdd.map(new CanovaSequenceDataSetFunction(2, -1, true, null, null));
        List<DataSet> collected = data.collect();

        //Load normally (i.e., not via Spark), and check that we get the same results (order notwithstanding)
        InputSplit is = new FileSplit(new File(folder),new String[]{"txt"}, true);
        SequenceRecordReader seqRR2 = new CSVSequenceRecordReader(1,",");
        seqRR2.initialize(is);

        SequenceRecordReaderDataSetIterator iter = new SequenceRecordReaderDataSetIterator(seqRR2,1,-1,2,true);
        List<DataSet> listLocal = new ArrayList<>(3);
        while(iter.hasNext()){
            listLocal.add(iter.next());
        }


        //Compare:
        assertEquals(3,collected.size());
        assertEquals(3,listLocal.size());

        //Check that results are the same (order notwithstanding)
        boolean[] found = new boolean[3];
        for (int i = 0; i < 3; i++) {
            int foundIndex = -1;
            DataSet ds = collected.get(i);
            for (int j = 0; j < 3; j++) {
                if (ds.equals(listLocal.get(j))) {
                    if (foundIndex != -1)
                        fail();    //Already found this value -> this Spark value equals two or more of the local values (shouldn't happen)
                    foundIndex = j;
                    if (found[foundIndex])
                        fail();   //One of the other spark values was equal to this one -> suggests duplicates in Spark list
                    found[foundIndex] = true;   //mark this one as seen before
                }
            }
        }
        int count = 0;
        for (boolean b : found) if (b) count++;
        assertEquals(3, count);  //Expect all 3 and exactly 3 pairwise matches between spark and local versions
    }

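    /**
     * Paired (features, labels) sequences: combines matching files into a Hadoop SequenceFile,
     * reads it back as byte pairs, deserializes both sides with PairSequenceRecordReaderBytesFunction,
     * converts the pairs to DataSets with CanovaSequencePairDataSetFunction, and compares against a
     * local two-reader SequenceRecordReaderDataSetIterator.
     */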
    @Test
    public void testCanovaSequencePairDataSetFunction() throws Exception {
        JavaSparkContext sc = getContext();

        //Convert data to a SequenceFile:
        File f = new File("src/test/resources/csvsequence/csvsequence_0.txt");
        String path = f.getPath();
        String folder = path.substring(0, path.length() - 17);
        path = folder + "*";

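        //Pair up the files at the two paths by key (here: the file name); both paths are the same, so each file is paired with itself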
        PathToKeyConverter pathConverter = new PathToKeyConverterFilename();
        JavaPairRDD<Text,BytesPairWritable> toWrite = CanovaSparkUtil.combineFilesForSequenceFile(sc, path, path, pathConverter);

        Path p = Files.createTempDirectory("dl4j_testSeqPairFn");
        p.toFile().deleteOnExit();
        String outPath = p.toString() + "/out";
        new File(outPath).deleteOnExit();
        toWrite.saveAsNewAPIHadoopFile(outPath, Text.class, BytesPairWritable.class, SequenceFileOutputFormat.class);

        //Load from sequence file:
        JavaPairRDD<Text,BytesPairWritable> fromSeq = sc.sequenceFile(outPath, Text.class, BytesPairWritable.class);

        SequenceRecordReader srr1 = new CSVSequenceRecordReader(1,",");
        SequenceRecordReader srr2 = new CSVSequenceRecordReader(1,",");
        PairSequenceRecordReaderBytesFunction psrbf = new PairSequenceRecordReaderBytesFunction(srr1,srr2);
        JavaRDD<Tuple2<Collection<Collection<Writable>>,Collection<Collection<Writable>>>> writables = fromSeq.map(psrbf);

        //Map to DataSet:
        CanovaSequencePairDataSetFunction pairFn = new CanovaSequencePairDataSetFunction();
        JavaRDD<DataSet> data = writables.map(pairFn);
        List<DataSet> sparkData = data.collect();


        //Now: do the same thing locally (SequenceRecordReaderDataSetIterator) and compare
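        //Replace the "0" in the path with "%d" so NumberedFileInputSplit can enumerate csvsequence_0.txt to csvsequence_2.txt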
        String featuresPath = f.getAbsolutePath().replaceAll("0", "%d");

        SequenceRecordReader featureReader = new CSVSequenceRecordReader(1, ",");
        SequenceRecordReader labelReader = new CSVSequenceRecordReader(1, ",");
        featureReader.initialize(new NumberedFileInputSplit(featuresPath, 0, 2));
        labelReader.initialize(new NumberedFileInputSplit(featuresPath, 0, 2));

        SequenceRecordReaderDataSetIterator iter =
                new SequenceRecordReaderDataSetIterator(featureReader, labelReader, 1, -1, true);

        List<DataSet> localData = new ArrayList<>(3);
        while(iter.hasNext()) localData.add(iter.next());

        assertEquals(3,sparkData.size());
        assertEquals(3,localData.size());

        for( int i=0; i<3; i++ ){
            //Check shapes etc.; DataSet order may differ between Spark and local
            DataSet dsSpark = sparkData.get(i);
            DataSet dsLocal = localData.get(i);

            assertNull(dsSpark.getFeaturesMaskArray());
            assertNull(dsSpark.getLabelsMaskArray());

            INDArray fSpark = dsSpark.getFeatureMatrix();
            INDArray fLocal = dsLocal.getFeatureMatrix();
            INDArray lSpark = dsSpark.getLabels();
            INDArray lLocal = dsLocal.getLabels();

            int[] s = new int[]{1,3,4}; //1 example, 3 values, 4 time steps
            assertArrayEquals(s,fSpark.shape());
            assertArrayEquals(s,fLocal.shape());
            assertArrayEquals(s,lSpark.shape());
            assertArrayEquals(s,lLocal.shape());
        }


        //Check that results are the same (order notwithstanding)
        boolean[] found = new boolean[3];
        for (int i = 0; i < 3; i++) {
            int foundIndex = -1;
            DataSet ds = sparkData.get(i);
            for (int j = 0; j < 3; j++) {
                if (ds.equals(localData.get(j))) {
                    if (foundIndex != -1)
                        fail();    //Already found this value -> this Spark value equals two or more of the local values (shouldn't happen)
                    foundIndex = j;
                    if (found[foundIndex])
                        fail();   //One of the other spark values was equal to this one -> suggests duplicates in Spark list
                    found[foundIndex] = true;   //mark this one as seen before
                }
            }
        }
        int count = 0;
        for (boolean b : found) if (b) count++;
        assertEquals(3, count);  //Expect all 3 and exactly 3 pairwise matches between spark and local versions
    }

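    /**
     * Same as testCanovaSequencePairDataSetFunction(), but with variable-length time series: the
     * label sequences are shorter than the feature sequences, so under ALIGN_END (and then
     * ALIGN_START) the labels are padded and a labels mask array is expected.
     */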
    @Test
    public void testCanovaSequencePairDataSetFunctionVariableLength() throws Exception {
        //Same sort of test as testCanovaSequencePairDataSetFunction() but with variable length time series (labels shorter, align end)

        //Convert data to a SequenceFile:
        File f = new File("src/test/resources/csvsequence/csvsequence_0.txt");
        String pathFeatures = f.getAbsolutePath();
        String folderFeatures = pathFeatures.substring(0, pathFeatures.length() - 17);
        pathFeatures = folderFeatures + "*";

        File f2 = new File("src/test/resources/csvsequencelabels/csvsequencelabelsShort_0.txt");
        String pathLabels = f2.getPath();
        String folderLabels = pathLabels.substring(0, pathLabels.length() - 28);
        pathLabels = folderLabels + "*";


        PathToKeyConverter pathConverter = new PathToKeyConverterNumber();  //Extract a number from the file name
        JavaPairRDD<Text,BytesPairWritable> toWrite = CanovaSparkUtil.combineFilesForSequenceFile(sc, pathFeatures, pathLabels, pathConverter);

        Path p = Files.createTempDirectory("dl4j_testSeqPairFnVarLength");
        p.toFile().deleteOnExit();
        String outPath = p.toString() + "/out";
        new File(outPath).deleteOnExit();
        toWrite.saveAsNewAPIHadoopFile(outPath, Text.class, BytesPairWritable.class, SequenceFileOutputFormat.class);

        //Load from sequence file:
        JavaPairRDD<Text,BytesPairWritable> fromSeq = sc.sequenceFile(outPath, Text.class, BytesPairWritable.class);

        SequenceRecordReader srr1 = new CSVSequenceRecordReader(1,",");
        SequenceRecordReader srr2 = new CSVSequenceRecordReader(1,",");
        PairSequenceRecordReaderBytesFunction psrbf = new PairSequenceRecordReaderBytesFunction(srr1,srr2);
        JavaRDD<Tuple2<Collection<Collection<Writable>>,Collection<Collection<Writable>>>> writables = fromSeq.map(psrbf);

        //Map to DataSet:
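        //numPossibleLabels = 4, regression = false; ALIGN_END pads the shorter label sequences at the start so they end together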
        CanovaSequencePairDataSetFunction pairFn = new CanovaSequencePairDataSetFunction(4,false, CanovaSequencePairDataSetFunction.AlignmentMode.ALIGN_END);
        JavaRDD<DataSet> data = writables.map(pairFn);
        List<DataSet> sparkData = data.collect();


        //Now: do the same thing locally (SequenceRecordReaderDataSetIterator) and compare
        String featuresPath = f.getPath().replaceAll("0", "%d");
        String labelsPath = f2.getPath().replaceAll("0", "%d");

        SequenceRecordReader featureReader = new CSVSequenceRecordReader(1, ",");
        SequenceRecordReader labelReader = new CSVSequenceRecordReader(1, ",");
        featureReader.initialize(new NumberedFileInputSplit(featuresPath, 0, 2));
        labelReader.initialize(new NumberedFileInputSplit(labelsPath, 0, 2));

        SequenceRecordReaderDataSetIterator iter =
                new SequenceRecordReaderDataSetIterator(featureReader, labelReader, 1, 4, false, SequenceRecordReaderDataSetIterator.AlignmentMode.ALIGN_END);

        List<DataSet> localData = new ArrayList<>(3);
        while(iter.hasNext()) localData.add(iter.next());

        assertEquals(3,sparkData.size());
        assertEquals(3,localData.size());

        int[] fShapeExp = new int[]{1,3,4}; //1 example, 3 values, 4 time steps
        int[] lShapeExp = new int[]{1,4,4}; //1 example, 4 values/classes, 4 time steps (after padding)
        for( int i=0; i<3; i++ ){
            //Check shapes etc.; DataSet order may differ between Spark and local
            DataSet dsSpark = sparkData.get(i);
            DataSet dsLocal = localData.get(i);

            assertNull(dsSpark.getFeaturesMaskArray());     //In this particular test: featuresLength > labelLength, therefore no mask array needed for features
            assertNotNull(dsSpark.getLabelsMaskArray());    //Expect mask array for labels

            INDArray fSpark = dsSpark.getFeatureMatrix();
            INDArray fLocal = dsLocal.getFeatureMatrix();
            INDArray lSpark = dsSpark.getLabels();
            INDArray lLocal = dsLocal.getLabels();


            assertArrayEquals(fShapeExp,fSpark.shape());
            assertArrayEquals(fShapeExp,fLocal.shape());
            assertArrayEquals(lShapeExp,lSpark.shape());
            assertArrayEquals(lShapeExp,lLocal.shape());
        }


        //Check that results are the same (order notwithstanding)
        boolean[] found = new boolean[3];
        for (int i = 0; i < 3; i++) {
            int foundIndex = -1;
            DataSet ds = sparkData.get(i);
            for (int j = 0; j < 3; j++) {
                if (ds.equals(localData.get(j))) {
                    if (foundIndex != -1)
                        fail();    //Already found this value -> this Spark value equals two or more of the local values (shouldn't happen)
                    foundIndex = j;
                    if (found[foundIndex])
                        fail();   //One of the other spark values was equal to this one -> suggests duplicates in Spark list
                    found[foundIndex] = true;   //mark this one as seen before
                }
            }
        }
        int count = 0;
        for (boolean b : found) if (b) count++;
        assertEquals(3, count);  //Expect all 3 and exactly 3 pairwise matches between spark and local versions


        //-------------------------------------------------
        //NOW: test same thing, but for align start...
        CanovaSequencePairDataSetFunction pairFnAlignStart = new CanovaSequencePairDataSetFunction(4,false, CanovaSequencePairDataSetFunction.AlignmentMode.ALIGN_START);
        JavaRDD<DataSet> rddDataAlignStart = writables.map(pairFnAlignStart);
        List<DataSet> sparkDataAlignStart = rddDataAlignStart.collect();

        featureReader.initialize(new NumberedFileInputSplit(featuresPath, 0, 2));   //re-initialize to reset
        labelReader.initialize(new NumberedFileInputSplit(labelsPath, 0, 2));
        SequenceRecordReaderDataSetIterator iterAlignStart =
                new SequenceRecordReaderDataSetIterator(featureReader, labelReader, 1, 4, false, SequenceRecordReaderDataSetIterator.AlignmentMode.ALIGN_START);

        List<DataSet> localDataAlignStart = new ArrayList<>(3);
        while(iterAlignStart.hasNext()) localDataAlignStart.add(iterAlignStart.next());

        assertEquals(3,sparkDataAlignStart.size());
        assertEquals(3,localDataAlignStart.size());

        for( int i=0; i<3; i++ ){
            //Check shapes etc.; DataSet order may differ between Spark and local
            DataSet dsSpark = sparkDataAlignStart.get(i);
            DataSet dsLocal = localDataAlignStart.get(i);

            assertNull(dsSpark.getFeaturesMaskArray());     //In this particular test: featuresLength > labelLength, therefore no mask array needed for features
            assertNotNull(dsSpark.getLabelsMaskArray());    //Expect mask array for labels

            INDArray fSpark = dsSpark.getFeatureMatrix();
            INDArray fLocal = dsLocal.getFeatureMatrix();
            INDArray lSpark = dsSpark.getLabels();
            INDArray lLocal = dsLocal.getLabels();


            assertArrayEquals(fShapeExp,fSpark.shape());
            assertArrayEquals(fShapeExp,fLocal.shape());
            assertArrayEquals(lShapeExp,lSpark.shape());
            assertArrayEquals(lShapeExp,lLocal.shape());
        }


        //Check that results are the same (order notwithstanding)
        found = new boolean[3];
        for (int i = 0; i < 3; i++) {
            int foundIndex = -1;
            DataSet ds = sparkDataAlignStart.get(i);
            for (int j = 0; j < 3; j++) {
                if (ds.equals(localDataAlignStart.get(j))) {
                    if (foundIndex != -1)
                        fail();    //Already found this value -> this Spark value equals two or more of the local values (shouldn't happen)
                    foundIndex = j;
                    if (found[foundIndex])
                        fail();   //One of the other spark values was equal to this one -> suggests duplicates in Spark list
                    found[foundIndex] = true;   //mark this one as seen before
                }
            }
        }
        count = 0;
        for (boolean b : found) if (b) count++;
        assertEquals(3, count);  //Expect all 3 and exactly 3 pairwise matches between spark and local versions
    }


}
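Outside of a test, the same classes slot into an ordinary Spark pipeline. The following sketch is not part of the original file; it just rearranges the calls exercised in testCanovaDataSetFunction() above, with an illustrative SparkConf setup and a placeholder input path:

import java.util.Arrays;
import java.util.Collection;
import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.input.PortableDataStream;
import org.canova.api.records.reader.RecordReader;
import org.canova.api.writable.Writable;
import org.canova.image.recordreader.ImageRecordReader;
import org.deeplearning4j.spark.canova.CanovaDataSetFunction;
import org.nd4j.linalg.dataset.DataSet;

public class CanovaSparkPipelineSketch {
    public static void main(String[] args) {
        JavaSparkContext sc = new JavaSparkContext(
                new SparkConf().setAppName("canova-pipeline").setMaster("local[*]"));

        //1. Read the raw image files as (path, byte stream) pairs
        JavaPairRDD<String, PortableDataStream> files = sc.binaryFiles("path/to/images/*");

        //2. Parse each file into a Collection<Writable> with a Canova record reader
        RecordReader reader = new ImageRecordReader(28, 28, 1, true, Arrays.asList("0", "1"));
        JavaRDD<Collection<Writable>> records =
                files.map(new org.canova.spark.functions.RecordReaderFunction(reader));

        //3. Convert each record to a DataSet (label at index 1, 2 classes, not regression)
        JavaRDD<DataSet> data = records.map(new CanovaDataSetFunction(1, 2, false));

        //The resulting JavaRDD<DataSet> can then be fed to DL4J's Spark training classes
        sc.stop();
    }
}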
