alvinalexander.com | career | drupal | java | mac | mysql | perl | scala | uml | unix  

Java example source code file (RecordReaderDataSetiteratorTest.java)

This example Java source code file (RecordReaderDataSetiteratorTest.java) is included in the alvinalexander.com "Java Source Code Warehouse" project. The intent of this project is to help you "Learn Java by Example"™.

Learn more about this Java project at its project page.

Java - Java tags/keywords

arraylist, classpathresource, collection, csvrecordreader, csvsequencerecordreader, exception, filesplit, indarray, intwritable, numberedfileinputsplit, sequencerecordreader, sequencerecordreaderdatasetiterator, string, test, util

The RecordReaderDataSetiteratorTest.java Java example source code

/*
 *
 *  * Copyright 2015 Skymind,Inc.
 *  *
 *  *    Licensed under the Apache License, Version 2.0 (the "License");
 *  *    you may not use this file except in compliance with the License.
 *  *    You may obtain a copy of the License at
 *  *
 *  *        http://www.apache.org/licenses/LICENSE-2.0
 *  *
 *  *    Unless required by applicable law or agreed to in writing, software
 *  *    distributed under the License is distributed on an "AS IS" BASIS,
 *  *    WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 *  *    See the License for the specific language governing permissions and
 *  *    limitations under the License.
 *
 */

package org.deeplearning4j.datasets.canova;

import org.apache.commons.io.FilenameUtils;
import org.canova.api.io.data.IntWritable;
import org.canova.api.records.reader.RecordReader;
import org.canova.api.records.reader.SequenceRecordReader;
import org.canova.api.records.reader.impl.CSVRecordReader;
import org.canova.api.records.reader.impl.CSVSequenceRecordReader;
import org.canova.api.records.reader.impl.CollectionRecordReader;
import org.canova.api.records.reader.impl.CollectionSequenceRecordReader;
import org.canova.api.split.FileSplit;
import org.canova.api.split.NumberedFileInputSplit;
import org.canova.api.writable.ArrayWritable;
import org.canova.api.writable.Writable;
import org.canova.common.data.NDArrayWritable;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.deeplearning4j.datasets.iterator.ReconstructionDataSetIterator;
import org.junit.Test;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.io.ClassPathResource;

import java.io.*;
import java.util.*;

import static org.junit.Assert.*;

/**
 * Created by agibsonccc on 3/6/15.
 */
public class RecordReaderDataSetiteratorTest {

    /**
     * Reads {@code csv-example.csv} through a {@link CSVRecordReader} and checks that one
     * call to {@code next()} on an iterator built with a batch size of 34 returns a data
     * set holding 34 examples.
     */
    @Test
    public void testRecordReader() throws Exception {
        int batchSize = 34;
        File csvFile = new ClassPathResource("csv-example.csv").getFile();

        RecordReader reader = new CSVRecordReader();
        reader.initialize(new FileSplit(csvFile));

        DataSetIterator iterator = new RecordReaderDataSetIterator(reader, batchSize);
        DataSet firstBatch = iterator.next();

        assertEquals(batchSize, firstBatch.numExamples());
    }


    /**
     * Checks that the iterator stops after its batch limit: it is built with a batch size
     * of 10 and a trailing limit argument of 2, and must report exhaustion once exactly
     * two batches have been consumed.
     */
    @Test
    public void testRecordReaderMaxBatchLimit() throws Exception {
        RecordReader recordReader = new CSVRecordReader();
        FileSplit csv = new FileSplit(new ClassPathResource("csv-example.csv").getFile());
        recordReader.initialize(csv);
        DataSetIterator iter = new RecordReaderDataSetIterator(recordReader, 10, -1, -1, 2);

        // Assert availability before each call so that the final assertFalse can only be
        // satisfied by the limit kicking in, not by the iterator having been empty all along.
        assertTrue(iter.hasNext());
        iter.next();
        assertTrue(iter.hasNext());
        iter.next();
        assertFalse(iter.hasNext());
    }

    /**
     * Multi-column regression: columns 3..4 of {@code iris.txt} are requested as labels
     * (regression flag set), so the first batch of 3 rows must split into a 3x3 feature
     * matrix and a 3x2 label matrix with the expected values.
     */
    @Test
    public void testRecordReaderMultiRegression() throws Exception {
        int batchSize = 3;
        int labelIdxFrom = 3;
        int labelIdxTo = 4;

        File irisFile = new ClassPathResource("iris.txt").getFile();
        RecordReader irisReader = new CSVRecordReader();
        irisReader.initialize(new FileSplit(irisFile));

        DataSetIterator iterator =
                new RecordReaderDataSetIterator(irisReader, batchSize, labelIdxFrom, labelIdxTo, true);
        DataSet firstBatch = iterator.next();
        INDArray features = firstBatch.getFeatureMatrix();
        INDArray labels = firstBatch.getLabels();

        //Shapes: 3 examples, 3 feature columns, 2 label columns
        assertArrayEquals(new int[]{3, 3}, features.shape());
        assertArrayEquals(new int[]{3, 2}, labels.shape());

        //Values of the first three rows
        INDArray expectedFeatures = Nd4j.create(new double[][]{
                {5.1, 3.5, 1.4},
                {4.9, 3.0, 1.4},
                {4.7, 3.2, 1.3}});
        INDArray expectedLabels = Nd4j.create(new double[][]{
                {0.2, 0},
                {0.2, 0},
                {0.2, 0}});

        assertEquals(expectedFeatures, features);
        assertEquals(expectedLabels, labels);
    }

    /**
     * Reads three numbered feature files and three numbered label files with a pair of
     * {@link CSVSequenceRecordReader}s and checks, for classification with 4 classes, the
     * shape and exact content of every returned data set.
     */
    @Test
    public void testSequenceRecordReader() throws Exception {
        // Build the "%d" pattern from the file name only. Replacing every "0" in the
        // absolute path would also rewrite directory names that happen to contain a zero.
        File featuresFile = new ClassPathResource("csvsequence_0.txt").getFile().getAbsoluteFile();
        String featuresPath = new File(featuresFile.getParentFile(), "csvsequence_%d.txt").getAbsolutePath();
        File labelsFile = new ClassPathResource("csvsequencelabels_0.txt").getFile().getAbsoluteFile();
        String labelsPath = new File(labelsFile.getParentFile(), "csvsequencelabels_%d.txt").getAbsolutePath();

        SequenceRecordReader featureReader = new CSVSequenceRecordReader(1, ",");
        SequenceRecordReader labelReader = new CSVSequenceRecordReader(1, ",");
        featureReader.initialize(new NumberedFileInputSplit(featuresPath, 0, 2));
        labelReader.initialize(new NumberedFileInputSplit(labelsPath, 0, 2));

        SequenceRecordReaderDataSetIterator iter =
                new SequenceRecordReaderDataSetIterator(featureReader, labelReader, 1, 4, false);

        assertEquals(3, iter.inputColumns());
        assertEquals(4, iter.totalOutcomes());

        List<DataSet> dsList = new ArrayList<>();
        while (iter.hasNext()) {
            dsList.add(iter.next());
        }

        assertEquals(3, dsList.size());  //3 files
        for (int i = 0; i < 3; i++) {
            DataSet ds = dsList.get(i);
            INDArray features = ds.getFeatureMatrix();
            INDArray labels = ds.getLabels();
            assertEquals(1, features.size(0));   //1 example in mini-batch
            assertEquals(1, labels.size(0));
            assertEquals(3, features.size(1));   //3 values per line/time step
            assertEquals(4, labels.size(1));     //1 value per line, but 4 possible values -> one-hot vector
            assertEquals(4, features.size(2));   //sequence length = 4
            assertEquals(4, labels.size(2));
        }

        //Check features vs. expected: value = 100*file + 10*timeStep + column
        for (int file = 0; file < 3; file++) {
            INDArray expF = Nd4j.create(1, 3, 4);
            for (int t = 0; t < 4; t++) {
                double base = 100 * file + 10 * t;
                expF.tensorAlongDimension(t, 1).assign(Nd4j.create(new double[]{base, base + 1, base + 2}));
            }
            assertEquals(expF, dsList.get(file).getFeatureMatrix());
        }

        //Check labels vs. expected: expected class index per time step, one row per file
        int[][] expClasses = {
                {0, 1, 2, 3},
                {3, 2, 1, 0},
                {1, 0, 3, 2}};
        for (int file = 0; file < 3; file++) {
            INDArray expL = Nd4j.create(1, 4, 4);
            for (int t = 0; t < 4; t++) {
                double[] oneHot = new double[4];
                oneHot[expClasses[file][t]] = 1;
                expL.tensorAlongDimension(t, 1).assign(Nd4j.create(oneHot));
            }
            assertEquals(expL, dsList.get(file).getLabels());
        }
    }

    /**
     * Regression mode with sequence readers. First a feature reader and a label reader
     * are pointed at the same three files, so features and labels must come back
     * identical; then a single reader is reused (after reset) with a label column taken
     * from the same record, and the iterator's own reset is exercised.
     */
    @Test
    public void testSequenceRecordReaderRegression() throws Exception{
        // Build the "%d" pattern from the file name only. Replacing every "0" in the
        // absolute path would also rewrite directory names that happen to contain a zero.
        File seqFile = new ClassPathResource("csvsequence_0.txt").getFile().getAbsoluteFile();
        String featuresPath = new File(seqFile.getParentFile(), "csvsequence_%d.txt").getAbsolutePath();
        //Labels are deliberately read from the same files as the features
        String labelsPath = featuresPath;

        SequenceRecordReader featureReader = new CSVSequenceRecordReader(1, ",");
        SequenceRecordReader labelReader = new CSVSequenceRecordReader(1, ",");
        featureReader.initialize(new NumberedFileInputSplit(featuresPath, 0, 2));
        labelReader.initialize(new NumberedFileInputSplit(labelsPath, 0, 2));

        SequenceRecordReaderDataSetIterator iter =
                new SequenceRecordReaderDataSetIterator(featureReader, labelReader, 1, 0, true);

        assertEquals(3, iter.inputColumns());
        assertEquals(3, iter.totalOutcomes());

        List<DataSet> dsList = new ArrayList<>();
        while (iter.hasNext()) {
            dsList.add(iter.next());
        }

        assertEquals(3, dsList.size());  //3 files
        for (int i = 0; i < 3; i++) {
            DataSet ds = dsList.get(i);
            INDArray features = ds.getFeatureMatrix();
            INDArray labels = ds.getLabels();
            assertArrayEquals(new int[]{1,3,4},features.shape());   //1 examples, 3 values, 4 time steps
            assertArrayEquals(new int[]{1,3,4},labels.shape());

            assertEquals(features,labels);
        }

        //Also test regression + reset from a single reader:
        //each data set must then expose 2 feature values and 1 label value per time step
        featureReader.reset();
        iter = new SequenceRecordReaderDataSetIterator(featureReader, 1, 0, 2, true);
        int count = 0;
        while(iter.hasNext()){
            DataSet ds = iter.next();
            assertEquals(2,ds.getFeatureMatrix().size(1));
            assertEquals(1,ds.getLabels().size(1));
            count++;
        }
        assertEquals(3,count);

        //After reset() the iterator must serve all three sequences again
        iter.reset();
        count = 0;
        while(iter.hasNext()){
            iter.next();
            count++;
        }
        assertEquals(3,count);
    }

    /**
     * Resets a two-reader sequence iterator five times and checks that every pass yields
     * three data sets with feature shape [1,3,4] and label shape [1,4,4].
     */
    @Test
    public void testSequenceRecordReaderReset() throws Exception {
        // Build the "%d" pattern from the file name only. Replacing every "0" in the
        // absolute path would also rewrite directory names that happen to contain a zero.
        File featuresFile = new ClassPathResource("csvsequence_0.txt").getFile().getAbsoluteFile();
        String featuresPath = new File(featuresFile.getParentFile(), "csvsequence_%d.txt").getAbsolutePath();
        File labelsFile = new ClassPathResource("csvsequencelabels_0.txt").getFile().getAbsoluteFile();
        String labelsPath = new File(labelsFile.getParentFile(), "csvsequencelabels_%d.txt").getAbsolutePath();

        SequenceRecordReader featureReader = new CSVSequenceRecordReader(1, ",");
        SequenceRecordReader labelReader = new CSVSequenceRecordReader(1, ",");
        featureReader.initialize(new NumberedFileInputSplit(featuresPath, 0, 2));
        labelReader.initialize(new NumberedFileInputSplit(labelsPath, 0, 2));

        SequenceRecordReaderDataSetIterator iter =
                new SequenceRecordReaderDataSetIterator(featureReader, labelReader, 1, 4, false);

        assertEquals(3, iter.inputColumns());
        assertEquals(4, iter.totalOutcomes());

        int nResets = 5;
        for( int i=0; i<nResets; i++ ) {
            //reset() is called before the first pass too, i.e. on a fresh iterator
            iter.reset();
            int count = 0;
            while (iter.hasNext()) {
                DataSet ds = iter.next();
                INDArray features = ds.getFeatureMatrix();
                INDArray labels = ds.getLabels();
                assertArrayEquals(new int[]{1,3,4},features.shape());
                assertArrayEquals(new int[]{1,4,4},labels.shape());
                count++;
            }
            assertEquals(3,count);
        }
    }



    /**
     * Writes a random CSV (label in column 0, then 5 features) via
     * {@code makeRandomCSV}, reads it back in regression mode in mini-batches of 10 and
     * compares every label and feature value with the data that was written.
     */
    @Test
    public void testCSVLoadingRegression() throws Exception {
        int nLines = 30;
        int nFeatures = 5;
        int miniBatchSize = 10;
        int labelIdx = 0;

        String tmpDir = System.getProperty("java.io.tmpdir");
        String path = FilenameUtils.concat(tmpDir, "rr_csv_test_rand.csv");
        double[][] data = makeRandomCSV(path, nLines, nFeatures);

        RecordReader testReader = new CSVRecordReader();
        testReader.initialize(new FileSplit(new File(path)));
        DataSetIterator iter =
                new RecordReaderDataSetIterator(testReader, null, miniBatchSize, labelIdx, 1, true);

        int batchesSeen = 0;
        while (iter.hasNext()) {
            DataSet batch = iter.next();
            INDArray features = batch.getFeatureMatrix();
            INDArray labels = batch.getLabels();
            assertArrayEquals(new int[]{miniBatchSize, nFeatures}, features.shape());
            assertArrayEquals(new int[]{miniBatchSize, 1}, labels.shape());

            //Row offset of this mini-batch within the generated data
            int firstRow = batchesSeen * miniBatchSize;
            for (int row = 0; row < miniBatchSize; row++) {
                double[] expectedRow = data[firstRow + row];

                double labelExp = expectedRow[labelIdx];
                double labelAct = labels.getDouble(row);
                assertEquals(labelExp, labelAct, 1e-5f);

                //Walk all CSV columns, skipping the label column; featureCol tracks the
                //position inside the (narrower) feature matrix
                int featureCol = 0;
                for (int col = 0; col <= nFeatures; col++) {
                    if (col != labelIdx) {
                        double featureExp = expectedRow[col];
                        double featureAct = features.getDouble(row, featureCol);
                        featureCol++;
                        assertEquals(featureExp, featureAct, 1e-5f);
                    }
                }
            }

            batchesSeen++;
        }
        assertEquals(nLines / miniBatchSize, batchesSeen);
    }


    /**
     * Writes a comma-separated file of pseudo-random doubles and returns the written values.
     * Each line holds {@code nFeatures + 1} values: the label first, then the features.
     * The generator uses a fixed seed, so repeated calls with the same sizes produce the
     * same data. The file is registered for deletion when the JVM exits.
     *
     * @param tempFile  path of the file to (over)write
     * @param nLines    number of lines to write
     * @param nFeatures number of feature columns per line (in addition to the label column)
     * @return the generated values, indexed {@code [line][column]}, column 0 being the label
     * @throws IllegalStateException if the file cannot be opened or written; previously the
     *         error was only printed and the caller received data for a file that did not exist
     */
    public static double[][] makeRandomCSV(String tempFile, int nLines, int nFeatures) {
        File temp = new File(tempFile);
        temp.deleteOnExit();
        Random rand = new Random(12345);    //Fixed seed -> reproducible content

        double[][] dArr = new double[nLines][nFeatures+1];

        try (PrintWriter out = new PrintWriter(new BufferedWriter(new FileWriter(temp)))) {
            for (int i = 0; i < nLines; i++) {
                //Column 0 is the label, columns 1..nFeatures are the features
                for (int j = 0; j <= nFeatures; j++) {
                    dArr[i][j] = rand.nextDouble();
                    if (j > 0) {
                        out.print(",");
                    }
                    out.print(dArr[i][j]);
                }
                out.println();
            }
            //PrintWriter never throws on write; it only records the failure, so ask explicitly
            if (out.checkError()) {
                throw new IllegalStateException("Error while writing CSV test data to " + temp);
            }
        } catch (IOException e) {
            throw new IllegalStateException("Could not write CSV test data to " + temp, e);
        }

        return dArr;
    }

    /**
     * Feature sequences of length 4 are paired with shorter label sequences, once with
     * {@code ALIGN_START} and once with {@code ALIGN_END}. Checks shapes, feature values,
     * label values and both mask arrays for all three files under each alignment mode.
     */
    @Test
    public void testVariableLengthSequence() throws Exception{
        // Build the "%d" pattern from the file name only. Replacing every "0" in the
        // absolute path would also rewrite directory names that happen to contain a zero.
        File featuresFile = new ClassPathResource("csvsequence_0.txt").getFile().getAbsoluteFile();
        String featuresPath = new File(featuresFile.getParentFile(), "csvsequence_%d.txt").getAbsolutePath();
        File labelsFile = new ClassPathResource("csvsequencelabelsShort_0.txt").getFile().getAbsoluteFile();
        String labelsPath = new File(labelsFile.getParentFile(), "csvsequencelabelsShort_%d.txt").getAbsolutePath();

        SequenceRecordReader featureReader = new CSVSequenceRecordReader(1, ",");
        SequenceRecordReader labelReader = new CSVSequenceRecordReader(1, ",");
        featureReader.initialize(new NumberedFileInputSplit(featuresPath, 0, 2));
        labelReader.initialize(new NumberedFileInputSplit(labelsPath, 0, 2));

        SequenceRecordReader featureReader2 = new CSVSequenceRecordReader(1, ",");
        SequenceRecordReader labelReader2 = new CSVSequenceRecordReader(1, ",");
        featureReader2.initialize(new NumberedFileInputSplit(featuresPath, 0, 2));
        labelReader2.initialize(new NumberedFileInputSplit(labelsPath, 0, 2));

        SequenceRecordReaderDataSetIterator iterAlignStart =
                new SequenceRecordReaderDataSetIterator(featureReader, labelReader, 1, 4, false,
                        SequenceRecordReaderDataSetIterator.AlignmentMode.ALIGN_START);

        SequenceRecordReaderDataSetIterator iterAlignEnd =
                new SequenceRecordReaderDataSetIterator(featureReader2, labelReader2, 1, 4, false,
                        SequenceRecordReaderDataSetIterator.AlignmentMode.ALIGN_END);

        assertEquals(3, iterAlignStart.inputColumns());
        assertEquals(4, iterAlignStart.totalOutcomes());

        assertEquals(3, iterAlignEnd.inputColumns());
        assertEquals(4, iterAlignEnd.totalOutcomes());

        List<DataSet> dsListAlignStart = new ArrayList<>();
        while (iterAlignStart.hasNext()) {
            dsListAlignStart.add(iterAlignStart.next());
        }

        List<DataSet> dsListAlignEnd = new ArrayList<>();
        while (iterAlignEnd.hasNext()) {
            dsListAlignEnd.add(iterAlignEnd.next());
        }

        assertEquals(3, dsListAlignStart.size());  //3 files
        assertEquals(3, dsListAlignEnd.size());  //3 files

        //Shapes are the same regardless of alignment mode
        for (int i = 0; i < 3; i++) {
            for (DataSet ds : Arrays.asList(dsListAlignStart.get(i), dsListAlignEnd.get(i))) {
                INDArray features = ds.getFeatureMatrix();
                INDArray labels = ds.getLabels();
                assertEquals(1, features.size(0));   //1 example in mini-batch
                assertEquals(1, labels.size(0));
                assertEquals(3, features.size(1));   //3 values per line/time step
                assertEquals(4, labels.size(1));     //1 value per line, but 4 possible values -> one-hot vector
                assertEquals(4, features.size(2));   //sequence length = 4
                assertEquals(4, labels.size(2));
            }
        }

        //Check features vs. expected: value = 100*file + 10*timeStep + column
        //Here: features are never shorter than the labels -> same features for align start and align end
        for (int file = 0; file < 3; file++) {
            INDArray expF = Nd4j.create(1, 3, 4);
            for (int t = 0; t < 4; t++) {
                double base = 100 * file + 10 * t;
                expF.tensorAlongDimension(t, 1).assign(Nd4j.create(new double[]{base, base + 1, base + 2}));
            }
            assertEquals(expF, dsListAlignStart.get(file).getFeatureMatrix());
            assertEquals(expF, dsListAlignEnd.get(file).getFeatureMatrix());
        }

        //Check features mask array:
        INDArray featuresMaskExpected = Nd4j.ones(1,4); //1 example, 4 values: same for both start/end align here
        for( int i=0; i<3; i++ ){
            INDArray featuresMaskStart = dsListAlignStart.get(i).getFeaturesMaskArray();
            INDArray featuresMaskEnd = dsListAlignEnd.get(i).getFeaturesMaskArray();
            assertEquals(featuresMaskExpected, featuresMaskStart);
            assertEquals(featuresMaskExpected, featuresMaskEnd);
        }

        //Check labels vs. expected:
        //Expected class index per label time step, one row per file (lengths 2, 1 and 3).
        //Align start places them at time steps 0..len-1, align end at 4-len..3; all other
        //time steps stay zero.
        int[][] shortLabels = {
                {0, 1},
                {1},
                {3, 2, 1}};
        for (int file = 0; file < 3; file++) {
            int[] classes = shortLabels[file];
            int endOffset = 4 - classes.length;
            INDArray expLStart = Nd4j.create(1, 4, 4);
            INDArray expLEnd = Nd4j.create(1, 4, 4);
            for (int t = 0; t < classes.length; t++) {
                double[] oneHot = new double[4];
                oneHot[classes[t]] = 1;
                expLStart.tensorAlongDimension(t, 1).assign(Nd4j.create(oneHot));
                expLEnd.tensorAlongDimension(endOffset + t, 1).assign(Nd4j.create(oneHot));
            }
            assertEquals(expLStart, dsListAlignStart.get(file).getLabels());
            assertEquals(expLEnd, dsListAlignEnd.get(file).getLabels());
        }

        //Check labels mask array
        INDArray[] labelsMaskExpectedStart = new INDArray[]{
                Nd4j.create(new float[]{1,1,0,0},new int[]{1,4}),
                Nd4j.create(new float[]{1,0,0,0},new int[]{1,4}),
                Nd4j.create(new float[]{1,1,1,0},new int[]{1,4})};
        INDArray[] labelsMaskExpectedEnd = new INDArray[]{
                Nd4j.create(new float[]{0,0,1,1},new int[]{1,4}),
                Nd4j.create(new float[]{0,0,0,1},new int[]{1,4}),
                Nd4j.create(new float[]{0,1,1,1},new int[]{1,4})};

        for( int i=0; i<3; i++ ){
            INDArray labelsMaskStart = dsListAlignStart.get(i).getLabelsMaskArray();
            INDArray labelsMaskEnd = dsListAlignEnd.get(i).getLabelsMaskArray();
            assertEquals(labelsMaskExpectedStart[i], labelsMaskStart);
            assertEquals(labelsMaskExpectedEnd[i], labelsMaskEnd);
        }
    }

    /**
     * Features and label come from the same sequence files through a single reader, with
     * column 0 as the label. Two iterators over the same files are checked: one in
     * classification mode (3 classes, one-hot labels) and one in regression mode (raw
     * label value). Neither may produce mask arrays.
     */
    @Test
    public void testSequenceRecordReaderSingleReader() throws Exception{
        // Build the "%d" pattern from the file name only. Replacing every "0" in the
        // absolute path would also rewrite directory names that happen to contain a zero.
        File seqFile = new ClassPathResource("csvsequenceSingle_0.txt").getFile().getAbsoluteFile();
        String path = new File(seqFile.getParentFile(), "csvsequenceSingle_%d.txt").getAbsolutePath();

        SequenceRecordReader reader = new CSVSequenceRecordReader(1, ",");
        reader.initialize(new NumberedFileInputSplit(path, 0, 2));
        SequenceRecordReaderDataSetIterator iteratorClassification = new SequenceRecordReaderDataSetIterator(reader, 1, 3, 0, false);

        SequenceRecordReader reader2 = new CSVSequenceRecordReader(1, ",");
        reader2.initialize(new NumberedFileInputSplit(path, 0, 2));
        SequenceRecordReaderDataSetIterator iteratorRegression = new SequenceRecordReaderDataSetIterator(reader2, 1, 3, 0, true);

        //Expected label value per time step, one row per file. The same values serve as
        //class index (classification) and as raw output (regression).
        int[][] labelValues = {
                {0, 1, 2, 0},
                {1, 2, 0, 2},
                {1, 0, 1, 2}};

        INDArray[] expF = new INDArray[3];
        INDArray[] expOutClassification = new INDArray[3];
        INDArray[] expOutRegression = new INDArray[3];
        for (int file = 0; file < 3; file++) {
            expF[file] = Nd4j.create(1, 2, 4);
            expOutClassification[file] = Nd4j.create(1, 3, 4);
            expOutRegression[file] = Nd4j.create(1, 1, 4);
            for (int t = 0; t < 4; t++) {
                //Expected features: 100*file + 10*timeStep + {1, 2}
                double base = 100 * file + 10 * t;
                expF[file].tensorAlongDimension(t, 1).assign(Nd4j.create(new double[]{base + 1, base + 2}));

                int label = labelValues[file][t];
                double[] oneHot = new double[3];
                oneHot[label] = 1;
                expOutClassification[file].tensorAlongDimension(t, 1).assign(Nd4j.create(oneHot));
                expOutRegression[file].tensorAlongDimension(t, 1).assign(Nd4j.create(new double[]{label}));
            }
        }

        int countC = 0;
        while(iteratorClassification.hasNext()){
            DataSet ds = iteratorClassification.next();
            INDArray f = ds.getFeatures();
            INDArray l = ds.getLabels();
            assertNull(ds.getFeaturesMaskArray());
            assertNull(ds.getLabelsMaskArray());

            assertArrayEquals(new int[]{1, 2, 4}, f.shape());
            assertArrayEquals(new int[]{1, 3, 4}, l.shape()); //One-hot representation

            assertEquals(expF[countC], f);
            assertEquals(expOutClassification[countC++],l);
        }
        assertEquals(3,countC);
        assertEquals(3,iteratorClassification.totalOutcomes());

        int countR = 0;
        while(iteratorRegression.hasNext()){
            DataSet ds = iteratorRegression.next();
            INDArray f = ds.getFeatures();
            INDArray l = ds.getLabels();
            assertNull(ds.getFeaturesMaskArray());
            assertNull(ds.getLabelsMaskArray());

            assertArrayEquals(new int[]{1, 2, 4}, f.shape());
            assertArrayEquals(new int[]{1, 1, 4}, l.shape());   //Regression (single output)

            assertEquals(expF[countR], f);
            assertEquals(expOutRegression[countR++], l);
        }
        assertEquals(3,countR);
        assertEquals(1,iteratorRegression.totalOutcomes());
    }



    @Test
    public void testSeqRRDSIArrayWritableOneReader(){

        Collection<Collection sequence1 = new ArrayList<>();
        sequence1.add(Arrays.asList((Writable)new NDArrayWritable(Nd4j.create(new double[]{1,2,3})), new IntWritable(0)));
        sequence1.add(Arrays.asList((Writable)new NDArrayWritable(Nd4j.create(new double[]{4,5,6})), new IntWritable(1)));
        Collection<Collection sequence2 = new ArrayList<>();
        sequence2.add(Arrays.asList((Writable)new NDArrayWritable(Nd4j.create(new double[]{7,8,9})), new IntWritable(2)));
        sequence2.add(Arrays.asList((Writable)new NDArrayWritable(Nd4j.create(new double[]{10,11,12})), new IntWritable(3)));


        SequenceRecordReader rr = new CollectionSequenceRecordReader(Arrays.asList(sequence1,sequence2));

        SequenceRecordReaderDataSetIterator iter = new SequenceRecordReaderDataSetIterator(rr,2,4,1,false);

        DataSet ds = iter.next();

        INDArray expFeatures = Nd4j.create(2,3,2);  //2 examples, 3 values per time step, 2 time steps
        expFeatures.tensorAlongDimension(0,1,2).assign(Nd4j.create(new double[][]{{1,4},{2,5},{3,6}}));
        expFeatures.tensorAlongDimension(1,1,2).assign(Nd4j.create(new double[][]{{7,10},{8,11},{9,12}}));

        INDArray expLabels = Nd4j.create(2,4,2);
        expLabels.tensorAlongDimension(0,1,2).assign(Nd4j.create(new double[][]{{1,0},{0,1},{0,0},{0,0}}));
        expLabels.tensorAlongDimension(1,1,2).assign(Nd4j.create(new double[][]{{0,0},{0,0},{1,0},{0,1}}));

        assertEquals(expFeatures, ds.getFeatureMatrix());
        assertEquals(expLabels, ds.getLabels());
    }

    @Test
    public void testSeqRRDSIArrayWritableOneReaderRegression(){
        //Regression, where the output is an array writable
        Collection<Collection sequence1 = new ArrayList<>();
        sequence1.add(Arrays.asList((Writable)new NDArrayWritable(Nd4j.create(new double[]{1,2,3})), new NDArrayWritable(Nd4j.create(new double[]{100,200,300}))));
        sequence1.add(Arrays.asList((Writable)new NDArrayWritable(Nd4j.create(new double[]{4,5,6})), new NDArrayWritable(Nd4j.create(new double[]{400,500,600}))));
        Collection<Collection sequence2 = new ArrayList<>();
        sequence2.add(Arrays.asList((Writable)new NDArrayWritable(Nd4j.create(new double[]{7,8,9})), new NDArrayWritable(Nd4j.create(new double[]{700,800,900}))));
        sequence2.add(Arrays.asList((Writable)new NDArrayWritable(Nd4j.create(new double[]{10,11,12})), new NDArrayWritable(Nd4j.create(new double[]{1000,1100,1200}))));


        SequenceRecordReader rr = new CollectionSequenceRecordReader(Arrays.asList(sequence1,sequence2));

        SequenceRecordReaderDataSetIterator iter = new SequenceRecordReaderDataSetIterator(rr,2,-1,1,true);

        DataSet ds = iter.next();

        INDArray expFeatures = Nd4j.create(2,3,2);  //2 examples, 3 values per time step, 2 time steps
        expFeatures.tensorAlongDimension(0,1,2).assign(Nd4j.create(new double[][]{{1,4},{2,5},{3,6}}));
        expFeatures.tensorAlongDimension(1,1,2).assign(Nd4j.create(new double[][]{{7,10},{8,11},{9,12}}));

        INDArray expLabels = Nd4j.create(2,3,2);
        expLabels.tensorAlongDimension(0,1,2).assign(Nd4j.create(new double[][]{{100,400},{200,500},{300,600}}));
        expLabels.tensorAlongDimension(1,1,2).assign(Nd4j.create(new double[][]{{700,1000},{800,1100},{900,1200}}));

        assertEquals(expFeatures, ds.getFeatureMatrix());
        assertEquals(expLabels, ds.getLabels());
    }

    @Test
    public void testSeqRRDSIMultipleArrayWritablesOneReader(){
        //Input with multiple array writables:

        Collection<Collection sequence1 = new ArrayList<>();
        sequence1.add(Arrays.asList((Writable)new NDArrayWritable(Nd4j.create(new double[]{1,2,3})), new NDArrayWritable(Nd4j.create(new double[]{100,200,300})), new IntWritable(0)));
        sequence1.add(Arrays.asList((Writable)new NDArrayWritable(Nd4j.create(new double[]{4,5,6})), new NDArrayWritable(Nd4j.create(new double[]{400,500,600})), new IntWritable(1)));
        Collection<Collection sequence2 = new ArrayList<>();
        sequence2.add(Arrays.asList((Writable)new NDArrayWritable(Nd4j.create(new double[]{7,8,9})), new NDArrayWritable(Nd4j.create(new double[]{700,800,900})), new IntWritable(2)));
        sequence2.add(Arrays.asList((Writable)new NDArrayWritable(Nd4j.create(new double[]{10,11,12})), new NDArrayWritable(Nd4j.create(new double[]{1000,1100,1200})), new IntWritable(3)));


        SequenceRecordReader rr = new CollectionSequenceRecordReader(Arrays.asList(sequence1,sequence2));

        SequenceRecordReaderDataSetIterator iter = new SequenceRecordReaderDataSetIterator(rr,2,4,2,false);

        DataSet ds = iter.next();

        INDArray expFeatures = Nd4j.create(2,6,2);  //2 examples, 6 values per time step, 2 time steps
        expFeatures.tensorAlongDimension(0,1,2).assign(Nd4j.create(new double[][]{{1,4},{2,5},{3,6},{100,400},{200,500},{300,600}}));
        expFeatures.tensorAlongDimension(1,1,2).assign(Nd4j.create(new double[][]{{7,10},{8,11},{9,12},{700,1000},{800,1100},{900,1200}}));

        INDArray expLabels = Nd4j.create(2,4,2);
        expLabels.tensorAlongDimension(0,1,2).assign(Nd4j.create(new double[][]{{1,0},{0,1},{0,0},{0,0}}));
        expLabels.tensorAlongDimension(1,1,2).assign(Nd4j.create(new double[][]{{0,0},{0,0},{1,0},{0,1}}));

        assertEquals(expFeatures, ds.getFeatureMatrix());
        assertEquals(expLabels, ds.getLabels());
    }

    @Test
    public void testSeqRRDSIArrayWritableTwoReaders(){
        Collection<Collection sequence1 = new ArrayList<>();
        sequence1.add(Arrays.asList((Writable)new NDArrayWritable(Nd4j.create(new double[]{1,2,3})), new IntWritable(100)));
        sequence1.add(Arrays.asList((Writable)new NDArrayWritable(Nd4j.create(new double[]{4,5,6})), new IntWritable(200)));
        Collection<Collection sequence2 = new ArrayList<>();
        sequence2.add(Arrays.asList((Writable)new NDArrayWritable(Nd4j.create(new double[]{7,8,9})), new IntWritable(300)));
        sequence2.add(Arrays.asList((Writable)new NDArrayWritable(Nd4j.create(new double[]{10,11,12})), new IntWritable(400)));
        SequenceRecordReader rrFeatures = new CollectionSequenceRecordReader(Arrays.asList(sequence1,sequence2));

        Collection<Collection sequence1L = new ArrayList<>();
        sequence1L.add(Arrays.asList((Writable)new NDArrayWritable(Nd4j.create(new double[]{100,200,300})), new IntWritable(101)));
        sequence1L.add(Arrays.asList((Writable)new NDArrayWritable(Nd4j.create(new double[]{400,500,600})), new IntWritable(201)));
        Collection<Collection sequence2L = new ArrayList<>();
        sequence2L.add(Arrays.asList((Writable)new NDArrayWritable(Nd4j.create(new double[]{700,800,900})), new IntWritable(301)));
        sequence2L.add(Arrays.asList((Writable)new NDArrayWritable(Nd4j.create(new double[]{1000,1100,1200})), new IntWritable(401)));
        SequenceRecordReader rrLabels = new CollectionSequenceRecordReader(Arrays.asList(sequence1L,sequence2L));

        SequenceRecordReaderDataSetIterator iter = new SequenceRecordReaderDataSetIterator(rrFeatures,rrLabels,2,-1,true);

        INDArray expFeatures = Nd4j.create(2,4,2);  //2 examples, 4 values per time step, 2 time steps
        expFeatures.tensorAlongDimension(0,1,2).assign(Nd4j.create(new double[][]{{1,4},{2,5},{3,6},{100,200}}));
        expFeatures.tensorAlongDimension(1,1,2).assign(Nd4j.create(new double[][]{{7,10},{8,11},{9,12},{300,400}}));

        INDArray expLabels = Nd4j.create(2,4,2);
        expLabels.tensorAlongDimension(0,1,2).assign(Nd4j.create(new double[][]{{100,400},{200,500},{300,600},{101,201}}));
        expLabels.tensorAlongDimension(1,1,2).assign(Nd4j.create(new double[][]{{700,1000},{800,1100},{900,1200},{301,401}}));

        DataSet ds = iter.next();
        assertEquals(expFeatures, ds.getFeatureMatrix());
        assertEquals(expLabels, ds.getLabels());
    }


}

Other Java examples (source code examples)

Here is a short list of links related to this Java RecordReaderDataSetiteratorTest.java source code file:

... this post is sponsored by my books ...

#1 New Release!

FP Best Seller

 

new blog posts

 

Copyright 1998-2021 Alvin Alexander, alvinalexander.com
All Rights Reserved.

A percentage of advertising revenue from
pages under the /java/jwarehouse URI on this website is
paid back to open source projects.