|
Java example source code file (TestSerialization.java)
The TestSerialization.java Java example source codepackage org.deeplearning4j.ui; import com.fasterxml.jackson.databind.ObjectMapper; import com.fasterxml.jackson.databind.module.SimpleModule; import org.deeplearning4j.datasets.iterator.impl.IrisDataSetIterator; import org.deeplearning4j.nn.api.Layer; import org.deeplearning4j.nn.api.OptimizationAlgorithm; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.layers.factory.LayerFactories; import org.deeplearning4j.nn.layers.feedforward.autoencoder.AutoEncoder; import org.deeplearning4j.optimize.api.IterationListener; import org.deeplearning4j.optimize.listeners.ScoreIterationListener; import org.deeplearning4j.ui.weights.HistogramIterationListener; import org.deeplearning4j.ui.weights.ModelAndGradient; import org.junit.Ignore; import org.junit.Test; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.dataset.DataSet; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.lossfunctions.LossFunctions; import org.nd4j.serde.jackson.VectorDeSerializer; import org.nd4j.serde.jackson.VectorSerializer; import java.util.Arrays; import static org.junit.Assert.*; /** * @author Adam Gibson */ public class TestSerialization { @Test public void testModelSerde() throws Exception { ObjectMapper mapper = getMapper(); NeuralNetConfiguration conf = new NeuralNetConfiguration.Builder().momentum(0.9f) .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT) .iterations(1000) .learningRate(1e-1f) .layer(new org.deeplearning4j.nn.conf.layers.AutoEncoder.Builder() .nIn(4).nOut(3) .corruptionLevel(0.6) .sparsity(0.5) .lossFunction(LossFunctions.LossFunction.RECONSTRUCTION_CROSSENTROPY).build()) .build(); DataSet d2 = new IrisDataSetIterator(150,150).next(); INDArray input = d2.getFeatureMatrix(); int numParams = LayerFactories.getFactory(conf).initializer().numParams(conf,true); INDArray params = Nd4j.create(1, numParams); AutoEncoder da = 
LayerFactories.getFactory(conf.getLayer()).create(conf, Arrays.<IterationListener>asList(new ScoreIterationListener(1), new HistogramIterationListener(1)),0, params, true); da.setInput(input); ModelAndGradient g = new ModelAndGradient(da); String json = mapper.writeValueAsString(g); ModelAndGradient read = mapper.readValue(json,ModelAndGradient.class); assertEquals(g,read); } public ObjectMapper getMapper() { ObjectMapper mapper = new ObjectMapper(); SimpleModule nd4j = new SimpleModule("nd4j"); nd4j.addDeserializer(INDArray.class, new VectorDeSerializer()); nd4j.addSerializer(INDArray.class, new VectorSerializer()); mapper.registerModule(nd4j); return mapper; } } Other Java examples (source code examples)Here is a short list of links related to this Java TestSerialization.java source code file: |
... this post is sponsored by my books ...
#1 New Release!
FP Best Seller
Copyright 1998-2021 Alvin Alexander, alvinalexander.com
All Rights Reserved.
A percentage of advertising revenue from
pages under the /java/jwarehouse
URI on this website is
paid back to open source projects.