Java example source code file (WeightInitUtilTest.java)
The WeightInitUtilTest.java Java example source codepackage org.deeplearning4j.nn.weights; import org.apache.commons.math3.util.FastMath; import org.deeplearning4j.nn.conf.distribution.Distributions; import org.deeplearning4j.nn.conf.distribution.GaussianDistribution; import org.junit.Before; import org.junit.Test; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.rng.distribution.Distribution; import org.nd4j.linalg.factory.Nd4j; import static org.junit.Assert.*; /** * Created by nyghtowl on 11/14/15. */ public class WeightInitUtilTest { protected int[] shape = new int[]{2, 2}; protected Distribution dist = Distributions.createDistribution(new GaussianDistribution(0.0, 0.1)); @Before public void doBefore(){ Nd4j.getRandom().setSeed(123); } @Test public void testDistribution(){ INDArray params = Nd4j.create(shape,'f'); INDArray weightsActual = WeightInitUtil.initWeights(shape, WeightInit.DISTRIBUTION, dist, params); // expected calculation Nd4j.getRandom().setSeed(123); INDArray weightsExpected = dist.sample(shape); assertEquals(weightsExpected, weightsActual); } @Test public void testNormalize(){ INDArray params = Nd4j.create(shape,'f'); INDArray weightsActual = WeightInitUtil.initWeights(shape, WeightInit.NORMALIZED, dist, params); // expected calculation Nd4j.getRandom().setSeed(123); INDArray weightsExpected = Nd4j.rand('f',shape); weightsExpected.subi(0.5).divi(shape[0]); assertEquals(weightsExpected, weightsActual); } @Test public void testRelu(){ INDArray params = Nd4j.create(shape,'f'); INDArray weightsActual = WeightInitUtil.initWeights(shape, WeightInit.RELU, dist,params); // expected calculation Nd4j.getRandom().setSeed(123); INDArray weightsExpected = Nd4j.randn('f',shape).muli(FastMath.sqrt(2.0 / shape[0])); assertEquals(weightsExpected, weightsActual); } @Test public void testSize(){ INDArray params = Nd4j.create(shape,'f'); INDArray weightsActual = WeightInitUtil.initWeights(shape, WeightInit.SIZE, dist, params); // expected 
calculation Nd4j.getRandom().setSeed(123); double min = -4.0 * Math.sqrt(6.0 / (double) (shape[0] + shape[1])); double max = 4.0 * Math.sqrt(6.0 / (double) (shape[0] + shape[1])); INDArray weightsExpected = Nd4j.rand(shape, Nd4j.getDistributions().createUniform(min,max)); assertEquals(weightsExpected, weightsActual); } @Test public void testUniform(){ INDArray params = Nd4j.create(shape,'f'); INDArray weightsActual = WeightInitUtil.initWeights(shape, WeightInit.UNIFORM, dist, params); // expected calculation Nd4j.getRandom().setSeed(123); double a = 1/(double) shape[0]; INDArray weightsExpected = Nd4j.rand('f',shape).muli(2*a).subi(a); assertEquals(weightsExpected, weightsActual); } @Test public void testVI(){ INDArray params = Nd4j.create(shape,'f'); INDArray weightsActual = WeightInitUtil.initWeights(shape, WeightInit.VI, dist, params); // expected calculation Nd4j.getRandom().setSeed(123); INDArray weightsExpected = Nd4j.rand('f',shape); int numValues = shape[0] + shape[1]; double r = Math.sqrt(6) / Math.sqrt(numValues + 1); weightsExpected.muli(2).muli(r).subi(r); assertEquals(weightsExpected, weightsActual); } @Test public void testXavier(){ INDArray params = Nd4j.create(shape,'f'); INDArray weightsActual = WeightInitUtil.initWeights(shape, WeightInit.XAVIER, dist, params); // expected calculation Nd4j.getRandom().setSeed(123); INDArray weightsExpected = Nd4j.randn('f',shape); weightsExpected.divi(FastMath.sqrt(shape[0] + shape[1])); assertEquals(weightsExpected, weightsActual); } @Test public void testZero(){ INDArray params = Nd4j.create(shape,'f'); INDArray weightsActual = WeightInitUtil.initWeights(shape, WeightInit.ZERO, dist, params); // expected calculation INDArray weightsExpected = Nd4j.create(shape,'f'); assertEquals(weightsExpected, weightsActual); } } Other Java examples (source code examples)Here is a short list of links related to this Java WeightInitUtilTest.java source code file: |
... this post is sponsored by my books ...
#1 New Release!
FP Best Seller
Copyright 1998-2021 Alvin Alexander, alvinalexander.com
All Rights Reserved.
A percentage of advertising revenue from
pages under the /java/jwarehouse
URI on this website is
paid back to open source projects.