alvinalexander.com | career | drupal | java | mac | mysql | perl | scala | uml | unix  

Java example source code file (AsyncDataSetIterator.java)

This example Java source code file (AsyncDataSetIterator.java) is included in the alvinalexander.com "Java Source Code Warehouse" project. The intent of this project is to help you "Learn Java by Example" TM.

Learn more about this Java project at its project page.

Java - Java tags/keywords

asyncdatasetiterator, concurrentmodificationexception, integer, interruptedexception, iteratorrunnable, linkedblockingdeque, next, nosuchelementexception, override, queue, runtimeexception, semaphore, something, thread, threading, threads, util

The AsyncDataSetIterator.java Java example source code

package org.deeplearning4j.datasets.iterator;


import org.nd4j.linalg.dataset.api.DataSetPreProcessor;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.factory.Nd4j;

import java.util.ConcurrentModificationException;
import java.util.List;
import java.util.NoSuchElementException;
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.LinkedBlockingDeque;
import java.util.concurrent.Semaphore;
import java.util.concurrent.TimeUnit;
import org.nd4j.linalg.dataset.DataSet;

/**
 *
 * AsyncDataSetIterator takes an existing DataSetIterator and loads one or more DataSet objects
 * from it using a separate thread.
 * For data sets where DataSetIterator.next() is long running (limited by disk read or processing time
 * for example) this may improve performance by loading the next DataSet asynchronously (i.e., while
 * training is continuing on the previous DataSet). Obviously this may use additional memory.<br>
 * Note however that due to asynchronous loading of data, next(int) is not supported.
 *
 * PLEASE NOTE: If used together with CUDA backend, please use it with caution.
 *
 * @author Alex Black
 */
public class AsyncDataSetIterator implements DataSetIterator {
    private final DataSetIterator baseIterator;
    private final BlockingQueue<DataSet> blockingQueue;
    private Thread thread;
    private IteratorRunnable runnable;

    /**
     *
     * Create an AsyncDataSetIterator with a queue size of 1 (i.e., only load a
     * single additional DataSet)
     * @param baseIterator The DataSetIterator to load data from asynchronously
     */
    public AsyncDataSetIterator(DataSetIterator baseIterator){
        this(baseIterator,1);
    }

    /** Create an AsyncDataSetIterator with a specified queue size.
     * @param baseIterator The DataSetIterator to load data from asynchronously
     * @param queueSize size of the queue (max number of elements to load into queue)
     */
    public AsyncDataSetIterator(DataSetIterator baseIterator, int queueSize) {
        if(queueSize <= 0)
            throw new IllegalArgumentException("Queue size must be > 0");
        this.baseIterator = baseIterator;
        blockingQueue = new LinkedBlockingDeque<>(queueSize);
        runnable = new IteratorRunnable();
        thread = new Thread(runnable);

        Integer deviceId = Nd4j.getAffinityManager().getDeviceForCurrentThread();
        Nd4j.getAffinityManager().attachThreadToDevice(thread, deviceId);

        thread.setDaemon(true);
        thread.start();
    }


    @Override
    public DataSet next(int num) {
        // TODO: why isn't supported? We could just check queue size
        throw new UnsupportedOperationException("Next(int) not supported for AsyncDataSetIterator");
    }

    @Override
    public int totalExamples() {
        return baseIterator.totalExamples();
    }

    @Override
    public int inputColumns() {
        return baseIterator.inputColumns();
    }

    @Override
    public int totalOutcomes() {
        return baseIterator.totalOutcomes();
    }

    @Override
    public synchronized void reset() {
        //Complication here: runnable could be blocking on either baseIterator.next() or blockingQueue.put()
        runnable.killRunnable = true;
        if(runnable.isAlive) thread.interrupt();
        //Wait for runnable to exit, but should only have to wait very short period of time
        //This probably isn't necessary, but is included as a safeguard against race conditions
        try{
            runnable.runCompletedSemaphore.tryAcquire(5, TimeUnit.SECONDS);
        } catch( InterruptedException e ){ }

        //Clear the queue, reset the base iterator, set up a new thread
        blockingQueue.clear();
        baseIterator.reset();
        runnable = new IteratorRunnable();
        thread = new Thread(runnable);

        Integer deviceId = Nd4j.getAffinityManager().getDeviceForCurrentThread();
        Nd4j.getAffinityManager().attachThreadToDevice(thread, deviceId);

        thread.setDaemon(true);
        thread.start();
    }

    @Override
    public int batch() {
        return baseIterator.batch();
    }

    @Override
    public int cursor() {
        return baseIterator.cursor();
    }

    @Override
    public int numExamples() {
        return baseIterator.numExamples();
    }

    @Override
    public void setPreProcessor(DataSetPreProcessor preProcessor) {
        baseIterator.setPreProcessor(preProcessor);
    }

    @Override
    public List<String> getLabels() {
        return baseIterator.getLabels();
    }

    @Override
    public synchronized  boolean hasNext() {
        if(!blockingQueue.isEmpty())
            return true;

        if(runnable.isAlive) {
            //Empty blocking queue, but runnable is alive
            //(a) runnable is blocking on baseIterator.next()
            //(b) runnable is blocking on blockingQueue.put()
            //either way: there's at least 1 more element to come
            return true;
        } else {
            if(!runnable.killRunnable && runnable.exception != null ) throw runnable.exception;   //Something went wrong
            //Runnable has exited, presumably because it has fetched all elements
            return !blockingQueue.isEmpty();
        }
    }

    @Override
    public synchronized DataSet next() {
        if(!hasNext()) throw new NoSuchElementException();
        //If base iterator threw an unchecked exception: rethrow it now
        if(runnable.exception != null) throw runnable.exception;

        if(!blockingQueue.isEmpty()){
            return blockingQueue.poll();    //non-blocking, but returns null if empty
        }

        //Blocking queue is empty, but more to come
        //Possible reasons:
        // (a) runnable died (already handled - runnable.exception != null)
        // (b) baseIterator.next() hasn't returned yet -> wait for it
        try{
            //Normally: just do blockingQueue.take(), but can't do that here
            //Reason: what if baseIterator.next() throws an exception after
            // blockingQueue.take() is called? In this case, next() will never return
            while(runnable.exception == null ){
                DataSet ds = blockingQueue.poll(5,TimeUnit.SECONDS);
                if(ds != null) return ds;
                if(runnable.killRunnable){
                    //should never happen
                    throw new ConcurrentModificationException("Reset while next() is waiting for element?");
                }
            }
            //exception thrown while getting data from base iterator
            throw runnable.exception;
        }catch(InterruptedException e ){
            throw new RuntimeException(e);  //Shouldn't happen under normal circumstances
        }
    }

    /**
     *
     * Shut down the async data set iterator thread
     * This is not typically necessary if using a single AsyncDataSetIterator
     * (thread is a daemon thread and so shouldn't block the JVM from exiting)
     * Behaviour of next(), hasNext() etc methods after shutdown of async iterator is undefined
     */
    public void shutdown() {
        if(thread.isAlive()) {
            runnable.killRunnable = true;
            thread.interrupt();
        }
    }

    private class IteratorRunnable implements Runnable {
        private volatile boolean killRunnable = false;
        private volatile boolean isAlive = true;
        private volatile RuntimeException exception;
        private Semaphore runCompletedSemaphore = new Semaphore(0);
        @Override
        public void run() {
            try {
                while (!killRunnable && baseIterator.hasNext()) {
                    blockingQueue.put(baseIterator.next());
                }
            } catch( InterruptedException e ){
                //thread.interrupt() while put(DataSet) was blocking
                if(killRunnable) return;
                else exception = new RuntimeException("Runnable interrupted unexpectedly",e); //Something else interrupted
            } catch(RuntimeException e ){
                exception = e;
            } finally {
                isAlive = false;
                runCompletedSemaphore.release();
            }
        }
    }

    @Override
    public void remove() {
    }

}

Other Java examples (source code examples)

Here is a short list of links related to this Java AsyncDataSetIterator.java source code file:

... this post is sponsored by my books ...

#1 New Release!

FP Best Seller

 

new blog posts

 

Copyright 1998-2021 Alvin Alexander, alvinalexander.com
All Rights Reserved.

A percentage of advertising revenue from
pages under the /java/jwarehouse URI on this website is
paid back to open source projects.