
Java example source code file (FeedForwardToRnnPreProcessor.java)

This example Java source code file (FeedForwardToRnnPreProcessor.java) is included in the alvinalexander.com "Java Source Code Warehouse" project. The intent of this project is to help you "Learn Java by Example"™.


Java - Java tags/keywords

clonenotsupportedexception, data, edge, feedforwardtornnpreprocessor, illegalargumentexception, indarray, inputpreprocessor, invalid, noargsconstructor, override, permute, runtimeexception

The FeedForwardToRnnPreProcessor.java Java example source code

package org.deeplearning4j.nn.conf.preprocessor;

import lombok.Data;
import lombok.NoArgsConstructor;

import org.deeplearning4j.nn.conf.InputPreProcessor;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.shape.Shape;

/**A preprocessor to allow RNN and feed-forward network layers to be used together.<br>
 * For example, DenseLayer -> GravesLSTM<br>
 * This does two things:<br>
 * (a) Reshapes activations out of the feed-forward layer (which are 2d, with shape
 * [miniBatchSize*timeSeriesLength,layerSize]) into 3d activations (with shape
 * [miniBatchSize,layerSize,timeSeriesLength]) suitable to feed into RNN layers.<br>
 * (b) Reshapes 3d epsilons (weights*deltas from RNN layer, with shape
 * [miniBatchSize,layerSize,timeSeriesLength]) into 2d epsilons (with shape
 * [miniBatchSize*timeSeriesLength,layerSize]) for use in feed forward layer
 * @author Alex Black
 * @see RnnToFeedForwardPreProcessor for opposite case (i.e., GravesLSTM -> DenseLayer etc)
 */
@Data @NoArgsConstructor
public class FeedForwardToRnnPreProcessor implements InputPreProcessor {

	@Override
	public INDArray preProcess(INDArray input, int miniBatchSize) {
		//Need to reshape FF activations (2d) to 3d (for input into RNN layer)
		if( input.rank() != 2 ) throw new IllegalArgumentException("Invalid input: expect NDArray with rank 2 (i.e., activations for FF layer)");
		if(input.ordering() == 'c') input = Shape.toOffsetZeroCopy(input,'f');

		int[] shape = input.shape();
		INDArray reshaped = input.reshape('f',miniBatchSize,shape[0]/miniBatchSize,shape[1]);
		return reshaped.permute(0,2,1);
	}

	@Override
	public INDArray backprop(INDArray output, int miniBatchSize) {
		//Need to reshape RNN epsilons (3d) to 2d (for use in FF layer backprop calculations)
		if( output.rank() != 3 ) throw new IllegalArgumentException("Invalid input: expect NDArray with rank 3 (i.e., epsilons from RNN layer)");
		if(output.ordering() != 'f') output = output.dup('f');
		int[] shape = output.shape();
		if(shape[0]==1) return output.tensorAlongDimension(0,1,2).permutei(1,0);	//Edge case: miniBatchSize==1
		if(shape[2]==1) return output.tensorAlongDimension(0,1,0);	//Edge case: timeSeriesLength==1
		INDArray permuted = output.permute(0,2,1);	//Permute, so we get correct order after reshaping
		return permuted.reshape('f',shape[0]*shape[2],shape[1]);
	}

	@Override
	public FeedForwardToRnnPreProcessor clone() {
		try {
			FeedForwardToRnnPreProcessor clone = (FeedForwardToRnnPreProcessor) super.clone();
			return clone;
		} catch (CloneNotSupportedException e) {
			throw new RuntimeException(e);
		}
	}
}
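
The class comment describes the two reshapes only in terms of shapes, so here is a minimal plain-array sketch of which element ends up where. It assumes the index mapping implied by the 'f'-order reshape followed by permute(0,2,1) in the listing above: row r of the 2d array belongs to example r % miniBatchSize at time step r / miniBatchSize. The class ReshapeSketch and its methods to3d and to2d are hypothetical names made up for this illustration; they are not part of DL4J or ND4J, and the sketch copies data rather than returning views.

```java
import java.util.Arrays;

/** Plain-array sketch of the 2d <-> 3d index mapping (hypothetical, not part of DL4J). */
final class ReshapeSketch {

    /** [miniBatch*tsLength, layerSize] -> [miniBatch, layerSize, tsLength]. */
    static double[][][] to3d(double[][] in, int miniBatch) {
        if (miniBatch <= 0 || in.length % miniBatch != 0) {
            throw new IllegalArgumentException("Row count must be a multiple of miniBatch");
        }
        int tsLength = in.length / miniBatch;
        int layerSize = in.length == 0 ? 0 : in[0].length;
        double[][][] out = new double[miniBatch][layerSize][tsLength];
        for (int row = 0; row < in.length; row++) {
            // Assumption: in an 'f'-order reshape the first new dimension varies fastest,
            // so consecutive rows cycle through the examples of one time step.
            int example = row % miniBatch;
            int step = row / miniBatch;
            for (int unit = 0; unit < layerSize; unit++) {
                out[example][unit][step] = in[row][unit];
            }
        }
        return out;
    }

    /** [miniBatch, layerSize, tsLength] -> [miniBatch*tsLength, layerSize]; inverse of to3d. */
    static double[][] to2d(double[][][] in) {
        int miniBatch = in.length;
        int layerSize = miniBatch == 0 ? 0 : in[0].length;
        int tsLength = layerSize == 0 ? 0 : in[0][0].length;
        double[][] out = new double[miniBatch * tsLength][layerSize];
        for (int example = 0; example < miniBatch; example++) {
            for (int unit = 0; unit < layerSize; unit++) {
                for (int step = 0; step < tsLength; step++) {
                    out[example + miniBatch * step][unit] = in[example][unit][step];
                }
            }
        }
        return out;
    }

    public static void main(String[] args) {
        // 2 examples x 2 time steps, layer size 2; rows 0 and 2 belong to example 0.
        double[][] ff = {{1, 2}, {3, 4}, {5, 6}, {7, 8}};
        double[][][] rnn = to3d(ff, 2);
        System.out.println(Arrays.deepToString(rnn));
        System.out.println(Arrays.deepEquals(ff, to2d(rnn)));
    }
}
```

With a mini-batch of 2, rows 0 and 2 of the 2d input are the two time steps of example 0, and to2d undoes to3d exactly. The sketch is only meant for reasoning about the layout; it says nothing about how ND4J handles views or memory ordering internally.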

Copyright 1998-2021 Alvin Alexander, alvinalexander.com
All Rights Reserved.
