alvinalexander.com | career | drupal | java | mac | mysql | perl | scala | uml | unix  

Java example source code file (CanovaDataSetFunction.java)

This example Java source code file (CanovaDataSetFunction.java) is included in the alvinalexander.com "Java Source Code Warehouse" project. The intent of this project is to help you "Learn Java by Example"™.

Learn more about this Java project at its project page.

Java - Java tags/keywords

arraylist, canovadatasetfunction, dataset, datasetpreprocessor, illegalstateexception, indarray, invalid, list, ndarraywritable, override, serializable, unsupportedoperationexception, util, writableconverter, writableconverterexception

The CanovaDataSetFunction.java Java example source code

package org.deeplearning4j.spark.canova;

import org.apache.spark.api.java.function.Function;
import org.canova.api.io.WritableConverter;
import org.canova.api.io.converters.WritableConverterException;
import org.canova.api.writable.Writable;
import org.canova.common.data.NDArrayWritable;
import org.deeplearning4j.datasets.canova.RecordReaderDataSetIterator;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.api.DataSetPreProcessor;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.util.FeatureUtil;

import java.io.Serializable;
import java.util.ArrayList;
import java.util.Collection;
import java.util.List;

/**Map {@code Collection<Writable>} objects (out of a canova-spark record reader function) to DataSet objects for Spark training.
 * Analogous to {@link RecordReaderDataSetIterator}, but in the context of Spark.
 * @author Alex Black
 */
public class CanovaDataSetFunction implements Function<Collection, Serializable {

    private final int labelIndex;
    private final int numPossibleLabels;
    private final boolean regression;
    private final DataSetPreProcessor preProcessor;
    private final WritableConverter converter;
    protected int batchSize = -1;

    public CanovaDataSetFunction(int labelIndex, int numPossibleLabels, boolean regression){
        this(labelIndex, numPossibleLabels, regression, null, null);
    }

    /**
     * @param labelIndex Index of the label column
     * @param numPossibleLabels Number of classes for classification  (not used if regression = true)
     * @param regression False for classification, true for regression
     * @param preProcessor DataSetPreprocessor (may be null)
     * @param converter WritableConverter (may be null)
     */
    public CanovaDataSetFunction(int labelIndex, int numPossibleLabels, boolean regression,
                                 DataSetPreProcessor preProcessor, WritableConverter converter){
        this.labelIndex = labelIndex;
        this.numPossibleLabels = numPossibleLabels;
        this.regression = regression;
        this.preProcessor = preProcessor;
        this.converter = converter;
    }

    @Override
    public DataSet call(Collection<Writable> writables) throws Exception {
        List<Writable> list;
        if(writables instanceof List) list = (List<Writable>)writables;
        else list = new ArrayList<>(writables);

        //allow people to specify label index as -1 and infer the last possible label
        int labelIndex = this.labelIndex;
        if (numPossibleLabels >= 1 && labelIndex < 0) {
            labelIndex = list.size() - 1;
        }

        INDArray label = null;
        INDArray featureVector = null;
        int featureCount = 0;
        for (int j = 0; j < list.size(); j++) {
            Writable current = list.get(j);
            if(converter != null) current = converter.convert(current);
            if (labelIndex >= 0 && j == labelIndex) {
                //Current value is the label
                if (converter != null) {
                    try {
                        current = converter.convert(current);
                    } catch (WritableConverterException e) {
                        e.printStackTrace();
                    }
                }
                if (numPossibleLabels < 1)
                    throw new IllegalStateException("Number of possible labels invalid, must be >= 1");

                if (regression) {
                    label = Nd4j.scalar(current.toDouble());
                } else {
                    //Convert to one-hot vector for
                    int curr = current.toInt();
                    if (curr >= numPossibleLabels)
                        throw new IllegalStateException("Invalid input: class label is " + curr
                            + " with numPossibleLables = " + numPossibleLabels + " (class label must be 0 <= labelIdx < numPossibleLabels)");
                    label = FeatureUtil.toOutcomeVector(curr, numPossibleLabels);
                }
            } else {
                //Current value is not the label
                try {
                    double value = current.toDouble();
                    if (featureVector == null) {
                        featureVector = Nd4j.create(labelIndex >= 0 ? list.size() - 1 : list.size());
                    }
                    featureVector.putScalar(featureCount++, value);
                } catch (UnsupportedOperationException e) {
                    // This isn't a scalar, so check if we got an array already
                    if (current instanceof NDArrayWritable) {
                        assert featureVector == null;
                        featureVector = ((NDArrayWritable)current).get();
                    } else {
                        throw e;
                    }
                }
            }
        }

        DataSet ds = new DataSet(featureVector, (labelIndex >= 0 ? label : featureVector) );
        if(preProcessor != null) preProcessor.preProcess(ds);
        return ds;
    }
}

Other Java examples (source code examples)

Here is a short list of links related to this Java CanovaDataSetFunction.java source code file:

... this post is sponsored by my books ...

#1 New Release!

FP Best Seller

 

new blog posts

 

Copyright 1998-2021 Alvin Alexander, alvinalexander.com
All Rights Reserved.

A percentage of advertising revenue from
pages under the /java/jwarehouse URI on this website is
paid back to open source projects.