alvinalexander.com | career | drupal | java | mac | mysql | perl | scala | uml | unix  

Java example source code file (CifarDataSetIterator.java)

This example Java source code file (CifarDataSetIterator.java) is included in the alvinalexander.com "Java Source Code Warehouse" project. The intent of this project is to help you "Learn Java by Example"™.

Learn more about this Java project at its project page.

Java - Java tags/keywords

arraylist, bytebuffer, cifardatasetiterator, cifarloader, cv_8uc, exception, imagetransform, indarray, inputstream, ioexception, list, mat, nio, override, recordreaderdatasetiterator, util

The CifarDataSetIterator.java Java example source code

package org.deeplearning4j.datasets.iterator.impl;

import org.canova.image.loader.CifarLoader;
import org.canova.image.transform.ImageTransform;
import org.deeplearning4j.datasets.canova.RecordReaderDataSetIterator;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.util.FeatureUtil;

import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.nio.ByteBuffer;
import java.util.ArrayList;
import java.util.List;

import static org.bytedeco.javacpp.opencv_core.*;

/**
 * Created by nyghtowl on 12/18/15.
 */
public class CifarDataSetIterator extends RecordReaderDataSetIterator {

    protected static int height = 32;
    protected static int width = 32;
    protected static int channels = 3;
    protected static int numPixels = 3073;
    protected static CifarLoader loader;
    protected static InputStream inputStream = null;
    protected int totalExamples = CifarLoader.NUM_TRAIN_IMAGES + CifarLoader.NUM_TRAIN_IMAGES;
    // TODO use maxNumBatches and batchNum instead
    protected int numExamples = totalExamples;
    protected int exampleCount = 0;


    /** Loads images with given  batchSize, numExamples, & version returned by the generator. */
    public CifarDataSetIterator(int batchSize, int numExamples,  boolean train) {
        this(batchSize, numExamples, new int[]{height, width, channels}, CifarLoader.NUM_LABELS, null, 0, train);
    }

    /** Loads images with given  batchSize, numExamples, & imgDim returned by the generator. */
    public CifarDataSetIterator(int batchSize, int numExamples, int[] imgDim) {
        this(batchSize, numExamples, imgDim, CifarLoader.NUM_LABELS, null, 0, true);
    }

    /** Loads images with given  batchSize, numExamples, imgDim & version returned by the generator. */
    public CifarDataSetIterator(int batchSize, int numExamples, int[] imgDim, boolean train) {
        this(batchSize, numExamples, imgDim, CifarLoader.NUM_LABELS, null, 0, train);
    }

    /** Loads images with given  batchSize & numExamples returned by the generator. */
    public CifarDataSetIterator(int batchSize, int numExamples) {
        this(batchSize, numExamples, new int[]{height, width, channels}, CifarLoader.NUM_LABELS, null, 0, true);
    }

    /** Loads images with given  batchSize & imgDim returned by the generator. */
    public CifarDataSetIterator(int batchSize, int[] imgDim)  {
        this(batchSize, CifarLoader.NUM_TRAIN_IMAGES, imgDim, CifarLoader.NUM_LABELS, null, 0, true);
    }

    /** Loads images with given  batchSize, numExamples, imgDim & version returned by the generator. */
    public CifarDataSetIterator(int batchSize, int numExamples, int[] imgDim, int normalizeValue, boolean train) {
        this(batchSize, numExamples, imgDim, CifarLoader.NUM_LABELS, null, normalizeValue, train);
    }

    /**
     * Create Cifar data specific iterator
     * @param batchSize the batch size of the examples
     * @param imgDim an array of height, width and channels
     * @param numExamples the overall number of examples
     * @param imageTransform the transformation to apply to the images
     * @param normalizeValue value to normalize the image data
     * @param train true if use training set and false for test
     * */
    public CifarDataSetIterator(int batchSize, int numExamples, int[] imgDim, int numPossibleLables, ImageTransform imageTransform, int normalizeValue, boolean train) {
        super(null, batchSize, 1, numExamples);
        this.loader = new CifarLoader(imgDim[0], imgDim[1], imgDim[2], imageTransform, normalizeValue, train);
        this.numExamples = numExamples > totalExamples? totalExamples: numExamples;
        this.numPossibleLabels = numPossibleLables;
        this.inputStream  = loader.getInputStream();
    }

    @Override
    public DataSet next(int num) {
        if(useCurrent) {
            useCurrent = false;
            if(preProcessor != null) preProcessor.preProcess(last);
            return last;
        }

        int batchNumCount = 0;
        byte[] byteFeature = new byte[numPixels];
        List<DataSet> dataSets = new ArrayList<>();
        INDArray label; // first value in the 3073 byte array
        Mat image = new Mat(height, width, CV_8UC(channels)); // feature are 3072
        ByteBuffer imageData = image.createBuffer();

        try {
            while((inputStream.read(byteFeature)) != -1 && batchNumCount != num) {
                label = FeatureUtil.toOutcomeVector(byteFeature[0], numPossibleLabels);
                for (int i = 0; i < height * width; i++) {
                    imageData.put(3 * i,     byteFeature[i + 1 + 2 * height * width]); // blue
                    imageData.put(3 * i + 1, byteFeature[i + 1 +     height * width]); // green
                    imageData.put(3 * i + 2, byteFeature[i + 1                     ]); // red
                }
                dataSets.add(new DataSet(loader.asRowVector(image), label));
                batchNumCount++;
            }
            exampleCount += batchSize;
        } catch (IOException e) {
            e.printStackTrace();
        }

        List<INDArray> inputs = new ArrayList<>();
        List<INDArray> labels = new ArrayList<>();

        for (DataSet data : dataSets) {
            inputs.add(data.getFeatureMatrix());
            labels.add(data.getLabels());
        }

        if(inputs.isEmpty() || (maxNumBatches > -1 && batchNum >= maxNumBatches)) {
            notOvershot = false;
            return last;
        }

        DataSet ret =  new DataSet(Nd4j.vstack(inputs.toArray(new INDArray[0])), Nd4j.vstack(labels.toArray(new INDArray[0])));
        last = ret;
        if(preProcessor != null) preProcessor.preProcess(ret);
        if ( loader.getLabels() != null) ret.setLabelNames(loader.getLabels());
        return ret;
    }

    @Override
    public boolean hasNext() {
        try {
            return exampleCount < numExamples;
        } catch (Exception e) {
            e.printStackTrace();
        }
        return false;
    }

    @Override
    public int totalExamples() {
        return totalExamples;
    }

    @Override
    public void reset() {
        exampleCount = 0;
        inputStream = loader.getInputStream();
    }

    @Override
    public List<String> getLabels(){
        return loader.getLabels();
    }


}

Other Java examples (source code examples)

Here is a short list of links related to this Java CifarDataSetIterator.java source code file:

... this post is sponsored by my books ...

#1 New Release!

FP Best Seller

 

new blog posts

 

Copyright 1998-2021 Alvin Alexander, alvinalexander.com
All Rights Reserved.

A percentage of advertising revenue from
pages under the /java/jwarehouse URI on this website is
paid back to open source projects.