|
Java example source code file (BaseLayerTest.java)
The BaseLayerTest.java Java example source codepackage org.deeplearning4j.nn.layers; import org.deeplearning4j.nn.api.Layer; import org.deeplearning4j.nn.conf.MultiLayerConfiguration; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.layers.ConvolutionLayer; import org.deeplearning4j.nn.conf.layers.DenseLayer; import org.deeplearning4j.nn.conf.layers.OutputLayer; import org.deeplearning4j.nn.layers.factory.LayerFactories; import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; import org.junit.Before; import org.junit.Test; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; import java.util.HashMap; import java.util.Map; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertNotEquals; /** * Created by nyghtowl on 11/15/15. */ public class BaseLayerTest { protected INDArray weight = Nd4j.create(new double[]{0.10, -0.20, -0.15, 0.05}, new int[]{2, 2}); protected INDArray bias = Nd4j.create(new double[] {0.5,0.5}, new int[] {1,2}); protected Map<String, INDArray> paramTable; @Before public void doBefore(){ paramTable = new HashMap<>(); paramTable.put("W", weight); paramTable.put("b", bias); } @Test public void testSetExistingParamsConvolutionSingleLayer() { Layer layer = configureSingleLayer(); assertNotEquals(paramTable, layer.paramTable()); layer.setParamTable(paramTable); assertEquals(paramTable, layer.paramTable()); } @Test public void testSetExistingParamsDenseMultiLayer() { MultiLayerNetwork net = configureMultiLayer(); for(Layer layer: net.getLayers()) { assertNotEquals(paramTable, layer.paramTable()); layer.setParamTable(paramTable); assertEquals(paramTable, layer.paramTable()); } } public Layer configureSingleLayer(){ int nIn = 2; int nOut = 2; NeuralNetConfiguration conf = new NeuralNetConfiguration.Builder() .layer(new ConvolutionLayer.Builder() .nIn(nIn).nOut(nOut) .build()) .build(); int numParams = LayerFactories.getFactory(conf).initializer().numParams(conf,true); INDArray params = Nd4j.create(1, numParams); return LayerFactories.getFactory(conf).create(conf, null, 0, params, true); } public MultiLayerNetwork configureMultiLayer(){ int nIn = 2; int nOut = 2; MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() .list() .layer(0, new DenseLayer.Builder() .nIn(nIn).nOut(nOut) .build()) .layer(1, new OutputLayer.Builder() .nIn(nIn).nOut(nOut) .build()) .build(); MultiLayerNetwork net = new MultiLayerNetwork(conf); net.init(); return net; } } Other Java examples (source code examples)Here is a short list of links related to this Java BaseLayerTest.java source code file: |
... this post is sponsored by my books ... | |
#1 New Release! |
FP Best Seller |
Copyright 1998-2021 Alvin Alexander, alvinalexander.com
All Rights Reserved.
A percentage of advertising revenue from
pages under the /java/jwarehouse
URI on this website is
paid back to open source projects.