alvinalexander.com | career | drupal | java | mac | mysql | perl | scala | uml | unix  

Java example source code file (RecordReaderMultiDataSetIteratorTest.java)

This example Java source code file (RecordReaderMultiDataSetIteratorTest.java) is included in the alvinalexander.com "Java Source Code Warehouse" project. The intent of this project is to help you "Learn Java by Example"™.

Learn more about this Java project at its project page.

Java - Java tags/keywords

classpathresource, csvrecordreader, csvsequencerecordreader, exception, filesplit, indarray, multidatasetiterator, numberedfileinputsplit, recordreadermultidatasetiterator, sequencerecordreader, sequencerecordreaderdatasetiterator, should, string, test

The RecordReaderMultiDataSetIteratorTest.java Java example source code

package org.deeplearning4j.datasets.canova;

import org.canova.api.records.reader.RecordReader;
import org.canova.api.records.reader.SequenceRecordReader;
import org.canova.api.records.reader.impl.CSVRecordReader;
import org.canova.api.records.reader.impl.CSVSequenceRecordReader;
import org.canova.api.split.FileSplit;
import org.canova.api.split.NumberedFileInputSplit;
import org.junit.Test;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.api.MultiDataSet;
import org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator;
import org.nd4j.linalg.indexing.NDArrayIndex;
import org.nd4j.linalg.io.ClassPathResource;

import static org.junit.Assert.*;

/**
 * Tests for {@link RecordReaderMultiDataSetIterator}.
 *
 * <p>Most tests build a reference {@code RecordReaderDataSetIterator} or
 * {@code SequenceRecordReaderDataSetIterator} over the same classpath resources and check that the
 * multi-data-set iterator yields equal feature/label (and, where relevant, label mask) arrays,
 * batch by batch.
 */
public class RecordReaderMultiDataSetIteratorTest {

    /** Marker identifying the sequence index in the name of the first numbered resource file. */
    private static final String FIRST_INDEX_MARKER = "_0.";

    /**
     * Resolves the first file of a numbered resource series (e.g. {@code "csvsequence_0.txt"}) and
     * returns its absolute path with the index replaced by {@code %d}, as passed to
     * {@link NumberedFileInputSplit} by the tests below.
     *
     * <p>Only the last occurrence of {@code "_0."} is rewritten. Replacing every {@code '0'} in the
     * absolute path would also corrupt directory names that happen to contain a zero (for example a
     * version-numbered build directory), producing a template with several placeholders.
     *
     * @param firstResourceName classpath name of the index-0 file of the series
     * @return absolute path template containing exactly one {@code %d} introduced by this method
     * @throws Exception if the resource cannot be resolved to a file
     * @throws IllegalStateException if the resolved path does not contain {@code "_0."}
     */
    private static String numberedPathTemplate(String firstResourceName) throws Exception {
        String absolutePath = new ClassPathResource(firstResourceName).getFile().getAbsolutePath();
        int markerStart = absolutePath.lastIndexOf(FIRST_INDEX_MARKER);
        if (markerStart < 0) {
            throw new IllegalStateException(
                    "Expected \"" + FIRST_INDEX_MARKER + "\" in resource path: " + absolutePath);
        }
        String prefix = absolutePath.substring(0, markerStart);
        String suffix = absolutePath.substring(markerStart + FIRST_INDEX_MARKER.length());
        return prefix + "_%d." + suffix;
    }

    /**
     * Single input / single output: compares the multi-data-set iterator against the plain
     * data-set iterators, first for the Iris CSV file and then for CSV sequence files.
     */
    @Test
    public void testsBasic() throws Exception {

        //Load details from CSV files; single input/output -> compare to RecordReaderDataSetIterator
        RecordReader rr = new CSVRecordReader(0, ",");
        rr.initialize(new FileSplit(new ClassPathResource("iris.txt").getFile()));
        RecordReaderDataSetIterator rrdsi = new RecordReaderDataSetIterator(rr, 10, 4, 3);

        RecordReader rr2 = new CSVRecordReader(0, ",");
        rr2.initialize(new FileSplit(new ClassPathResource("iris.txt").getFile()));

        MultiDataSetIterator rrmdsi = new RecordReaderMultiDataSetIterator.Builder(10)
                .addReader("reader", rr2)
                .addInput("reader", 0, 3)
                .addOutputOneHot("reader", 4, 3)
                .build();

        while (rrdsi.hasNext()) {
            DataSet ds = rrdsi.next();
            INDArray fds = ds.getFeatureMatrix();
            INDArray lds = ds.getLabels();

            //Fail with a clear assertion if the iterator under test runs out early
            assertTrue(rrmdsi.hasNext());
            MultiDataSet mds = rrmdsi.next();
            assertEquals(1, mds.getFeatures().length);
            assertEquals(1, mds.getLabels().length);
            assertNull(mds.getFeaturesMaskArrays());
            assertNull(mds.getLabelsMaskArrays());
            INDArray fmds = mds.getFeatures(0);
            INDArray lmds = mds.getLabels(0);

            assertNotNull(fmds);
            assertNotNull(lmds);

            assertEquals(fds, fmds);
            assertEquals(lds, lmds);
        }
        assertFalse(rrmdsi.hasNext());


        //Load time series from CSV sequence files; compare to SequenceRecordReaderDataSetIterator
        String featuresPath = numberedPathTemplate("csvsequence_0.txt");
        String labelsPath = numberedPathTemplate("csvsequencelabels_0.txt");

        SequenceRecordReader featureReader = new CSVSequenceRecordReader(1, ",");
        SequenceRecordReader labelReader = new CSVSequenceRecordReader(1, ",");
        featureReader.initialize(new NumberedFileInputSplit(featuresPath, 0, 2));
        labelReader.initialize(new NumberedFileInputSplit(labelsPath, 0, 2));

        SequenceRecordReaderDataSetIterator iter =
                new SequenceRecordReaderDataSetIterator(featureReader, labelReader, 1, 4, false);

        SequenceRecordReader featureReader2 = new CSVSequenceRecordReader(1, ",");
        SequenceRecordReader labelReader2 = new CSVSequenceRecordReader(1, ",");
        featureReader2.initialize(new NumberedFileInputSplit(featuresPath, 0, 2));
        labelReader2.initialize(new NumberedFileInputSplit(labelsPath, 0, 2));

        MultiDataSetIterator srrmdsi = new RecordReaderMultiDataSetIterator.Builder(1)
                .addSequenceReader("in", featureReader2)
                .addSequenceReader("out", labelReader2)
                .addInput("in")
                .addOutputOneHot("out", 0, 4)
                .build();

        while (iter.hasNext()) {
            DataSet ds = iter.next();
            INDArray fds = ds.getFeatureMatrix();
            INDArray lds = ds.getLabels();

            assertTrue(srrmdsi.hasNext());
            MultiDataSet mds = srrmdsi.next();
            assertEquals(1, mds.getFeatures().length);
            assertEquals(1, mds.getLabels().length);
            assertNull(mds.getFeaturesMaskArrays());
            assertNull(mds.getLabelsMaskArrays());
            INDArray fmds = mds.getFeatures(0);
            INDArray lmds = mds.getLabels(0);

            assertNotNull(fmds);
            assertNotNull(lmds);

            assertEquals(fds, fmds);
            assertEquals(lds, lmds);
        }
        assertFalse(srrmdsi.hasNext());
    }

    /**
     * Splits the Iris CSV columns into two inputs and two outputs and checks each resulting array
     * against the matching slice of the reference data set.
     */
    @Test
    public void testSplittingCSV() throws Exception {
        //Here's the idea: take Iris, and split it up into 2 inputs and 2 output arrays
        //Inputs: columns 0 and 1-2
        //Outputs: columns 3, and 4->OneHot

        RecordReader rr = new CSVRecordReader(0, ",");
        rr.initialize(new FileSplit(new ClassPathResource("iris.txt").getFile()));
        RecordReaderDataSetIterator rrdsi = new RecordReaderDataSetIterator(rr, 10, 4, 3);

        RecordReader rr2 = new CSVRecordReader(0, ",");
        rr2.initialize(new FileSplit(new ClassPathResource("iris.txt").getFile()));

        MultiDataSetIterator rrmdsi = new RecordReaderMultiDataSetIterator.Builder(10)
                .addReader("reader", rr2)
                .addInput("reader", 0, 0)
                .addInput("reader", 1, 2)
                .addOutput("reader", 3, 3)
                .addOutputOneHot("reader", 4, 3)
                .build();

        while (rrdsi.hasNext()) {
            DataSet ds = rrdsi.next();
            INDArray fds = ds.getFeatureMatrix();
            INDArray lds = ds.getLabels();

            assertTrue(rrmdsi.hasNext());
            MultiDataSet mds = rrmdsi.next();
            assertEquals(2, mds.getFeatures().length);
            assertEquals(2, mds.getLabels().length);
            assertNull(mds.getFeaturesMaskArrays());
            assertNull(mds.getLabelsMaskArrays());
            INDArray[] fmds = mds.getFeatures();
            INDArray[] lmds = mds.getLabels();

            assertNotNull(fmds);
            assertNotNull(lmds);
            for (int i = 0; i < fmds.length; i++) {
                assertNotNull(fmds[i]);
            }
            for (int i = 0; i < lmds.length; i++) {
                assertNotNull(lmds[i]);
            }

            //Get the subsets of the original iris data
            INDArray expIn1 = fds.get(NDArrayIndex.all(), NDArrayIndex.point(0));
            INDArray expIn2 = fds.get(NDArrayIndex.all(), NDArrayIndex.interval(1, 2, true));
            INDArray expOut1 = fds.get(NDArrayIndex.all(), NDArrayIndex.point(3));
            INDArray expOut2 = lds;

            assertEquals(expIn1, fmds[0]);
            assertEquals(expIn2, fmds[1]);
            assertEquals(expOut1, lmds[0]);
            assertEquals(expOut2, lmds[1]);
        }
        assertFalse(rrmdsi.hasNext());
    }

    /**
     * Splits the CSV sequence features into two separate inputs while keeping the sequence labels
     * as a single one-hot output, and compares against slices of the reference data set.
     */
    @Test
    public void testSplittingCSVSequence() throws Exception {
        //Idea: take CSV sequences, and split "csvsequence_i.txt" into two separate inputs; keep "csvsequencelabels_i.txt"
        // as standard one-hot output

        String featuresPath = numberedPathTemplate("csvsequence_0.txt");
        String labelsPath = numberedPathTemplate("csvsequencelabels_0.txt");

        SequenceRecordReader featureReader = new CSVSequenceRecordReader(1, ",");
        SequenceRecordReader labelReader = new CSVSequenceRecordReader(1, ",");
        featureReader.initialize(new NumberedFileInputSplit(featuresPath, 0, 2));
        labelReader.initialize(new NumberedFileInputSplit(labelsPath, 0, 2));

        SequenceRecordReaderDataSetIterator iter =
                new SequenceRecordReaderDataSetIterator(featureReader, labelReader, 1, 4, false);

        SequenceRecordReader featureReader2 = new CSVSequenceRecordReader(1, ",");
        SequenceRecordReader labelReader2 = new CSVSequenceRecordReader(1, ",");
        featureReader2.initialize(new NumberedFileInputSplit(featuresPath, 0, 2));
        labelReader2.initialize(new NumberedFileInputSplit(labelsPath, 0, 2));

        MultiDataSetIterator srrmdsi = new RecordReaderMultiDataSetIterator.Builder(1)
                .addSequenceReader("seq1", featureReader2)
                .addSequenceReader("seq2", labelReader2)
                .addInput("seq1", 0, 1)
                .addInput("seq1", 2, 2)
                .addOutputOneHot("seq2", 0, 4)
                .build();

        while (iter.hasNext()) {
            DataSet ds = iter.next();
            INDArray fds = ds.getFeatureMatrix();
            INDArray lds = ds.getLabels();

            assertTrue(srrmdsi.hasNext());
            MultiDataSet mds = srrmdsi.next();
            assertEquals(2, mds.getFeatures().length);
            assertEquals(1, mds.getLabels().length);
            assertNull(mds.getFeaturesMaskArrays());
            assertNull(mds.getLabelsMaskArrays());
            INDArray[] fmds = mds.getFeatures();
            INDArray[] lmds = mds.getLabels();

            assertNotNull(fmds);
            assertNotNull(lmds);
            for (int i = 0; i < fmds.length; i++) {
                assertNotNull(fmds[i]);
            }
            for (int i = 0; i < lmds.length; i++) {
                assertNotNull(lmds[i]);
            }

            INDArray expIn1 = fds.get(NDArrayIndex.all(), NDArrayIndex.interval(0, 1, true), NDArrayIndex.all());
            INDArray expIn2 = fds.get(NDArrayIndex.all(), NDArrayIndex.interval(2, 2, true), NDArrayIndex.all());

            assertEquals(expIn1, fmds[0]);
            assertEquals(expIn2, fmds[1]);
            assertEquals(lds, lmds[0]);
        }
        assertFalse(srrmdsi.hasNext());

    }


    /**
     * Checks that invalid builder configurations are rejected with an exception.
     *
     * <p>Fixture loading happens outside the {@code try} blocks so that a missing or unreadable
     * {@code iris.txt} fails the test instead of being mistaken for the expected validation error.
     *
     * @throws Exception if the {@code iris.txt} fixture cannot be loaded
     */
    @Test
    public void testInputValidation() throws Exception {

        //Test: no readers
        try {
            new RecordReaderMultiDataSetIterator.Builder(1)
                    .addInput("something")
                    .addOutput("something")
                    .build();
            fail("Should have thrown exception");
        } catch (Exception ignored) {
            //Expected: any exception counts; fail() throws an AssertionError, which is not caught here
        }

        //Test: reference to reader that doesn't exist
        RecordReader unknownRefReader = new CSVRecordReader(0, ",");
        unknownRefReader.initialize(new FileSplit(new ClassPathResource("iris.txt").getFile()));
        try {
            new RecordReaderMultiDataSetIterator.Builder(1)
                    .addReader("iris", unknownRefReader)
                    .addInput("thisDoesntExist", 0, 3)
                    .addOutputOneHot("iris", 4, 3)
                    .build();
            fail("Should have thrown exception");
        } catch (Exception ignored) {
            //Expected: any exception counts; fail() throws an AssertionError, which is not caught here
        }

        //Test: no inputs or outputs
        RecordReader noInOutReader = new CSVRecordReader(0, ",");
        noInOutReader.initialize(new FileSplit(new ClassPathResource("iris.txt").getFile()));
        try {
            new RecordReaderMultiDataSetIterator.Builder(1)
                    .addReader("iris", noInOutReader)
                    .build();
            fail("Should have thrown exception");
        } catch (Exception ignored) {
            //Expected: any exception counts; fail() throws an AssertionError, which is not caught here
        }
    }

    /**
     * Variable-length time series: the label files are the "Short" variants, and the results for
     * both {@code ALIGN_START} and {@code ALIGN_END} are compared (features, labels and label mask
     * arrays) against {@code SequenceRecordReaderDataSetIterator} using the same alignment mode.
     */
    @Test
    public void testVariableLengthTS() throws Exception {

        //Set up SequenceRecordReaderDataSetIterators for comparison
        String featuresPath = numberedPathTemplate("csvsequence_0.txt");
        String labelsPath = numberedPathTemplate("csvsequencelabelsShort_0.txt");

        SequenceRecordReader featureReader = new CSVSequenceRecordReader(1, ",");
        SequenceRecordReader labelReader = new CSVSequenceRecordReader(1, ",");
        featureReader.initialize(new NumberedFileInputSplit(featuresPath, 0, 2));
        labelReader.initialize(new NumberedFileInputSplit(labelsPath, 0, 2));

        SequenceRecordReader featureReader2 = new CSVSequenceRecordReader(1, ",");
        SequenceRecordReader labelReader2 = new CSVSequenceRecordReader(1, ",");
        featureReader2.initialize(new NumberedFileInputSplit(featuresPath, 0, 2));
        labelReader2.initialize(new NumberedFileInputSplit(labelsPath, 0, 2));

        SequenceRecordReaderDataSetIterator iterAlignStart =
                new SequenceRecordReaderDataSetIterator(featureReader, labelReader, 1, 4, false,
                        SequenceRecordReaderDataSetIterator.AlignmentMode.ALIGN_START);

        SequenceRecordReaderDataSetIterator iterAlignEnd =
                new SequenceRecordReaderDataSetIterator(featureReader2, labelReader2, 1, 4, false,
                        SequenceRecordReaderDataSetIterator.AlignmentMode.ALIGN_END);


        //Set up the multi-data-set iterators under test, one per alignment mode
        SequenceRecordReader featureReader3 = new CSVSequenceRecordReader(1, ",");
        SequenceRecordReader labelReader3 = new CSVSequenceRecordReader(1, ",");
        featureReader3.initialize(new NumberedFileInputSplit(featuresPath, 0, 2));
        labelReader3.initialize(new NumberedFileInputSplit(labelsPath, 0, 2));

        SequenceRecordReader featureReader4 = new CSVSequenceRecordReader(1, ",");
        SequenceRecordReader labelReader4 = new CSVSequenceRecordReader(1, ",");
        featureReader4.initialize(new NumberedFileInputSplit(featuresPath, 0, 2));
        labelReader4.initialize(new NumberedFileInputSplit(labelsPath, 0, 2));

        RecordReaderMultiDataSetIterator rrmdsiStart = new RecordReaderMultiDataSetIterator.Builder(1)
                .addSequenceReader("in", featureReader3)
                .addSequenceReader("out", labelReader3)
                .addInput("in")
                .addOutputOneHot("out", 0, 4)
                .sequenceAlignmentMode(RecordReaderMultiDataSetIterator.AlignmentMode.ALIGN_START)
                .build();

        RecordReaderMultiDataSetIterator rrmdsiEnd = new RecordReaderMultiDataSetIterator.Builder(1)
                .addSequenceReader("in", featureReader4)
                .addSequenceReader("out", labelReader4)
                .addInput("in")
                .addOutputOneHot("out", 0, 4)
                .sequenceAlignmentMode(RecordReaderMultiDataSetIterator.AlignmentMode.ALIGN_END)
                .build();


        while (iterAlignStart.hasNext()) {
            //All four iterators are advanced in lockstep; make early exhaustion an assertion failure
            assertTrue(iterAlignEnd.hasNext());
            assertTrue(rrmdsiStart.hasNext());
            assertTrue(rrmdsiEnd.hasNext());

            DataSet dsStart = iterAlignStart.next();
            DataSet dsEnd = iterAlignEnd.next();

            MultiDataSet mdsStart = rrmdsiStart.next();
            MultiDataSet mdsEnd = rrmdsiEnd.next();

            //Feature mask arrays are not checked: features data is always longer -> don't need mask arrays for it
            assertEquals(1, mdsStart.getFeatures().length);
            assertEquals(1, mdsStart.getLabels().length);
            assertEquals(1, mdsStart.getLabelsMaskArrays().length);

            assertEquals(1, mdsEnd.getFeatures().length);
            assertEquals(1, mdsEnd.getLabels().length);
            assertEquals(1, mdsEnd.getLabelsMaskArrays().length);


            assertEquals(dsStart.getFeatureMatrix(), mdsStart.getFeatures(0));
            assertEquals(dsStart.getLabels(), mdsStart.getLabels(0));
            assertEquals(dsStart.getLabelsMaskArray(), mdsStart.getLabelsMaskArray(0));

            assertEquals(dsEnd.getFeatureMatrix(), mdsEnd.getFeatures(0));
            assertEquals(dsEnd.getLabels(), mdsEnd.getLabels(0));
            assertEquals(dsEnd.getLabelsMaskArray(), mdsEnd.getLabelsMaskArray(0));
        }
        assertFalse(rrmdsiStart.hasNext());
        assertFalse(rrmdsiEnd.hasNext());
    }

}

Other Java examples (source code examples)

Here is a short list of links related to this Java RecordReaderMultiDataSetIteratorTest.java source code file:

... this post is sponsored by my books ...

#1 New Release!

FP Best Seller

 

new blog posts

 

Copyright 1998-2021 Alvin Alexander, alvinalexander.com
All Rights Reserved.

A percentage of advertising revenue from
pages under the /java/jwarehouse URI on this website is
paid back to open source projects.