|
Java example source code file (ModelSerializerTest.java)
// The ModelSerializerTest.java Java example source code
package org.deeplearning4j.util;

import org.deeplearning4j.nn.api.OptimizationAlgorithm;
import org.deeplearning4j.nn.conf.ComputationGraphConfiguration;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.weights.WeightInit;
import org.junit.Test;
import org.nd4j.linalg.lossfunctions.LossFunctions;

import java.io.File;
import java.io.FileInputStream;
import java.io.IOException;

import static org.junit.Assert.assertEquals;

/**
 * Round-trip tests for {@link ModelSerializer}: each test writes a freshly
 * initialised model to a temporary file, restores it (either from the file or
 * from an input stream opened on that file) and checks that the restored model
 * matches the original.
 *
 * @author raver119@gmail.com
 */
public class ModelSerializerTest {

    /** Writes a MultiLayerNetwork to a file and restores it directly from that file. */
    @Test
    public void testWriteMLNModel() throws Exception {
        MultiLayerNetwork net = newInitializedMultiLayerNetwork();
        File tempFile = newTempModelFile();

        // third argument is presumably the "save updater" flag -- confirm against ModelSerializer
        ModelSerializer.writeModel(net, tempFile, true);

        MultiLayerNetwork network = ModelSerializer.restoreMultiLayerNetwork(tempFile);

        assertSameMultiLayerNetwork(net, network);
    }

    /**
     * Writes a MultiLayerNetwork to a file and restores it through an input stream.
     *
     * <p>Previously this test restored from the {@link File} and was therefore an exact
     * duplicate of {@link #testWriteMLNModel()}; it now exercises the stream overload,
     * mirroring {@link #testWriteCGModelInputStream()}.
     */
    @Test
    public void testWriteMlnModelInputStream() throws Exception {
        MultiLayerNetwork net = newInitializedMultiLayerNetwork();
        File tempFile = newTempModelFile();

        ModelSerializer.writeModel(net, tempFile, true);

        MultiLayerNetwork network;
        // try-with-resources so the file handle is released even if restoring fails
        try (FileInputStream fis = new FileInputStream(tempFile)) {
            network = ModelSerializer.restoreMultiLayerNetwork(fis);
        }

        assertSameMultiLayerNetwork(net, network);
    }

    /** Writes a ComputationGraph to a file and restores it directly from that file. */
    @Test
    public void testWriteCGModel() throws Exception {
        ComputationGraph cg = newInitializedComputationGraph();
        File tempFile = newTempModelFile();

        ModelSerializer.writeModel(cg, tempFile, true);

        ComputationGraph network = ModelSerializer.restoreComputationGraph(tempFile);

        assertSameComputationGraph(cg, network);
    }

    /** Writes a ComputationGraph to a file and restores it through an input stream. */
    @Test
    public void testWriteCGModelInputStream() throws Exception {
        ComputationGraph cg = newInitializedComputationGraph();
        File tempFile = newTempModelFile();

        ModelSerializer.writeModel(cg, tempFile, true);

        ComputationGraph network;
        // the original left this stream open; close it deterministically
        try (FileInputStream fis = new FileInputStream(tempFile)) {
            network = ModelSerializer.restoreComputationGraph(fis);
        }

        assertSameComputationGraph(cg, network);
    }

    /** Builds and initialises the three-layer (dense, dense, MSE output) network used by the MLN tests. */
    private static MultiLayerNetwork newInitializedMultiLayerNetwork() {
        int nIn = 5;
        int nOut = 6;

        MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
                .seed(12345)
                .regularization(true).l1(0.01).l2(0.01)
                .learningRate(0.1).activation("tanh").weightInit(WeightInit.XAVIER)
                .list()
                .layer(0, new DenseLayer.Builder().nIn(nIn).nOut(20).build())
                .layer(1, new DenseLayer.Builder().nIn(20).nOut(30).build())
                .layer(2, new OutputLayer.Builder()
                        .lossFunction(LossFunctions.LossFunction.MSE).nIn(30).nOut(nOut).build())
                .build();

        MultiLayerNetwork net = new MultiLayerNetwork(conf);
        net.init();
        return net;
    }

    /** Builds and initialises the "in" -> "dense" -> "out" graph used by the CG tests. */
    private static ComputationGraph newInitializedComputationGraph() {
        ComputationGraphConfiguration config = new NeuralNetConfiguration.Builder()
                .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
                .learningRate(0.1)
                .graphBuilder()
                .addInputs("in")
                .addLayer("dense", new DenseLayer.Builder().nIn(4).nOut(2).build(), "in")
                .addLayer("out", new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT)
                        .nIn(2).nOut(3).build(), "dense")
                .setOutputs("out")
                .pretrain(false).backprop(true)
                .build();

        ComputationGraph cg = new ComputationGraph(config);
        cg.init();
        return cg;
    }

    /** Creates the temporary model file and registers it for deletion when the JVM exits. */
    private static File newTempModelFile() throws IOException {
        File tempFile = File.createTempFile("tsfs", "fdfsdf");
        tempFile.deleteOnExit();
        return tempFile;
    }

    /** Asserts that configuration JSON, parameters and updater of the restored network match the original. */
    private static void assertSameMultiLayerNetwork(MultiLayerNetwork expected, MultiLayerNetwork actual) {
        assertEquals(expected.getLayerWiseConfigurations().toJson(),
                actual.getLayerWiseConfigurations().toJson());
        assertEquals(expected.params(), actual.params());
        assertEquals(expected.getUpdater(), actual.getUpdater());
    }

    /** Asserts that configuration JSON and parameters of the restored graph match the original. */
    private static void assertSameComputationGraph(ComputationGraph expected, ComputationGraph actual) {
        assertEquals(expected.getConfiguration().toJson(), actual.getConfiguration().toJson());
        assertEquals(expected.params(), actual.params());
        // TODO: the original author noted "updater breaks equality? huh?" and left the
        // updater comparison disabled; re-enable once that is understood:
        // assertEquals(expected.getUpdater(), actual.getUpdater());
    }
}
// Other Java examples (source code examples)
// Here is a short list of links related to this Java ModelSerializerTest.java source code file: |
... this post is sponsored by my books ... | |
#1 New Release! |
FP Best Seller |
Copyright 1998-2021 Alvin Alexander, alvinalexander.com
All Rights Reserved.
A percentage of advertising revenue from
pages under the /java/jwarehouse
URI on this website is
paid back to open source projects.