|
Java example source code file (IteratorMultiDataSetIterator.java)
The IteratorMultiDataSetIterator.java Java example source codepackage org.deeplearning4j.datasets.iterator; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.dataset.api.MultiDataSet; import org.nd4j.linalg.dataset.api.MultiDataSetPreProcessor; import org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator; import org.nd4j.linalg.indexing.NDArrayIndex; import java.util.*; /** * A DataSetIterator that works on an Iterator<DataSet>, combining and splitting the input DataSet objects as * required to get a consistent batch size. * * Typically used in Spark training, but may be used elsewhere. * NOTE: some methods are not supported here. */ public class IteratorMultiDataSetIterator implements MultiDataSetIterator { private final Iterator<MultiDataSet> iterator; private final int batchSize; private final LinkedList<MultiDataSet> queued; //Used when splitting larger examples than we want to return in a batch private MultiDataSetPreProcessor preProcessor; public IteratorMultiDataSetIterator(Iterator<MultiDataSet> iterator, int batchSize){ this.iterator = iterator; this.batchSize = batchSize; this.queued = new LinkedList<>(); } @Override public boolean hasNext() { return !queued.isEmpty() || iterator.hasNext(); } @Override public MultiDataSet next() { return next(batchSize); } @Override public MultiDataSet next(int num) { if(!hasNext()) throw new NoSuchElementException(); List<MultiDataSet> list = new ArrayList<>(); int countSoFar = 0; while((!queued.isEmpty() || iterator.hasNext()) && countSoFar < batchSize){ MultiDataSet next; if(!queued.isEmpty()){ next = queued.removeFirst(); } else { next = iterator.next(); } int nExamples = next.getFeatures(0).size(0); if( countSoFar + nExamples <= batchSize ){ //Add the entire MultiDataSet as-is list.add(next); } else { //Split the MultiDataSet int nFeatures = next.numFeatureArrays(); int nLabels = next.numLabelsArrays(); INDArray[] fToKeep = new INDArray[nFeatures]; INDArray[] lToKeep = new INDArray[nLabels]; INDArray[] fToCache = new INDArray[nFeatures]; INDArray[] lToCache = new INDArray[nLabels]; INDArray[] fMaskToKeep = (next.getFeaturesMaskArrays() != null ? new INDArray[nFeatures] : null); INDArray[] lMaskToKeep = (next.getLabelsMaskArrays() != null ? new INDArray[nLabels] : null); INDArray[] fMaskToCache = (next.getFeaturesMaskArrays() != null ? new INDArray[nFeatures] : null); INDArray[] lMaskToCache = (next.getLabelsMaskArrays() != null ? new INDArray[nLabels] : null); for( int i=0; i<nFeatures; i++ ){ INDArray fi = next.getFeatures(i); INDArray li = next.getFeatures(i); fToKeep[i] = getRange(fi,0,batchSize-countSoFar); fToCache[i] = getRange(fi,batchSize-countSoFar, nExamples); lToKeep[i] = getRange(li,0,batchSize-countSoFar); lToCache[i] = getRange(li,batchSize-countSoFar, nExamples); if(fMaskToKeep != null){ INDArray fmi = next.getFeaturesMaskArray(i); fMaskToKeep[i] = getRange(fmi,0,batchSize-countSoFar); fMaskToCache[i] = getRange(fmi,batchSize-countSoFar, nExamples); } if(lMaskToKeep != null){ INDArray lmi = next.getLabelsMaskArray(i); lMaskToKeep[i] = getRange(lmi,0,batchSize-countSoFar); lMaskToCache[i] = getRange(lmi,batchSize-countSoFar, nExamples); } } MultiDataSet toKeep = new org.nd4j.linalg.dataset.MultiDataSet(fToKeep,lToKeep, fMaskToKeep, lMaskToKeep); MultiDataSet toCache = new org.nd4j.linalg.dataset.MultiDataSet(fToCache,lToCache, fMaskToCache, lMaskToCache); list.add(toKeep); queued.add(toCache); } countSoFar += nExamples; } MultiDataSet out; if(list.size() == 1){ out = list.get(0); } else { out = org.nd4j.linalg.dataset.MultiDataSet.merge(list); } if(preProcessor != null) preProcessor.preProcess(out); return out; } private static INDArray getRange(INDArray arr, int exampleFrom, int exampleToExclusive){ if(arr == null) return null; int rank = arr.rank(); switch(rank){ case 2: return arr.get(NDArrayIndex.interval(exampleFrom, exampleToExclusive), NDArrayIndex.all()); case 3: return arr.get(NDArrayIndex.interval(exampleFrom, exampleToExclusive), NDArrayIndex.all(), NDArrayIndex.all()); case 4: return arr.get(NDArrayIndex.interval(exampleFrom, exampleToExclusive), NDArrayIndex.all(), NDArrayIndex.all(), NDArrayIndex.all()); default: throw new RuntimeException("Invalid rank: " + rank); } } @Override public void reset() { throw new UnsupportedOperationException("Reset not supported"); } @Override public void setPreProcessor(MultiDataSetPreProcessor preProcessor) { this.preProcessor = preProcessor; } @Override public void remove() { throw new UnsupportedOperationException("Not supported"); } } Other Java examples (source code examples)Here is a short list of links related to this Java IteratorMultiDataSetIterator.java source code file: |
... this post is sponsored by my books ... | |
#1 New Release! |
FP Best Seller |
Copyright 1998-2021 Alvin Alexander, alvinalexander.com
All Rights Reserved.
A percentage of advertising revenue from
pages under the /java/jwarehouse
URI on this website is
paid back to open source projects.