|
Java example source code file (BaseGraphVertex.java)
The BaseGraphVertex.java Java example source code/* * * * Copyright 2016 Skymind,Inc. * * * * Licensed under the Apache License, Version 2.0 (the "License"); * * you may not use this file except in compliance with the License. * * You may obtain a copy of the License at * * * * http://www.apache.org/licenses/LICENSE-2.0 * * * * Unless required by applicable law or agreed to in writing, software * * distributed under the License is distributed on an "AS IS" BASIS, * * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * * See the License for the specific language governing permissions and * * limitations under the License. * */ package org.deeplearning4j.nn.graph.vertex; import lombok.Data; import org.deeplearning4j.nn.graph.ComputationGraph; import org.nd4j.linalg.api.ndarray.INDArray; import java.io.Serializable; import java.util.Arrays; /** BaseGraphVertex defines a set of common functionality for GraphVertex instances. */ @Data public abstract class BaseGraphVertex implements GraphVertex { protected ComputationGraph graph; protected String vertexName; /** The index of this vertex */ protected int vertexIndex; /**A representation of the vertices that are inputs to this vertex (inputs during forward pass) * Specifically, if inputVertices[X].getVertexIndex() = Y, and inputVertices[X].getVertexEdgeNumber() = Z * then the Zth output of vertex Y is the Xth input to this vertex */ protected VertexIndices[] inputVertices; /**A representation of the vertices that this vertex is connected to (outputs duing forward pass) * Specifically, if outputVertices[X].getVertexIndex() = Y, and outputVertices[X].getVertexEdgeNumber() = Z * then the output of this vertex (there is only one output) is connected to the Zth input of vertex Y */ protected VertexIndices[] outputVertices; protected INDArray[] inputs; protected INDArray[] epsilons; protected BaseGraphVertex(ComputationGraph graph, String name, int vertexIndex, VertexIndices[] inputVertices, VertexIndices[] outputVertices){ this.graph = graph; this.vertexName = name; this.vertexIndex = vertexIndex; this.inputVertices = inputVertices; this.outputVertices = outputVertices; this.inputs = new INDArray[(inputVertices != null ? inputVertices.length : 0)]; this.epsilons = new INDArray[(outputVertices != null ? outputVertices.length : 0)]; } @Override public String getVertexName(){ return vertexName; } @Override public int getVertexIndex(){ return vertexIndex; } @Override public int getNumInputArrays(){ return (inputVertices == null ? 0 : inputVertices.length); } @Override public int getNumOutputConnections(){ return (outputVertices == null ? 0 : outputVertices.length); } /**A representation of the vertices that are inputs to this vertex (inputs duing forward pass)<br> * Specifically, if inputVertices[X].getVertexIndex() = Y, and inputVertices[X].getVertexEdgeNumber() = Z * then the Zth output of vertex Y is the Xth input to this vertex */ @Override public VertexIndices[] getInputVertices(){ return inputVertices; } @Override public void setInputVertices(VertexIndices[] inputVertices){ this.inputVertices = inputVertices; this.inputs = new INDArray[(inputVertices != null ? inputVertices.length : 0)]; } /**A representation of the vertices that this vertex is connected to (outputs duing forward pass) * Specifically, if outputVertices[X].getVertexIndex() = Y, and outputVertices[X].getVertexEdgeNumber() = Z * then the Xth output of this vertex is connected to the Zth input of vertex Y */ @Override public VertexIndices[] getOutputVertices(){ return outputVertices; } @Override public void setOutputVertices(VertexIndices[] outputVertices){ this.outputVertices = outputVertices; this.epsilons = new INDArray[(outputVertices != null ? outputVertices.length : 0)]; } @Override public boolean isInputVertex(){ return false; } @Override public void setInput(int inputNumber, INDArray input){ if(inputNumber >= getNumInputArrays()) throw new IllegalArgumentException("Invalid input number"); inputs[inputNumber] = input; } @Override public void setError(int errorNumber, INDArray error){ if(errorNumber >= getNumOutputConnections() ){ throw new IllegalArgumentException("Invalid error number: " + errorNumber + ", numOutputEdges = " + (outputVertices != null ? outputVertices.length : 0) ); } epsilons[errorNumber] = error; } @Override public void clear(){ for (int i = 0; i < inputs.length; i++) inputs[i] = null; for (int i = 0; i < epsilons.length; i++) epsilons[i] = null; } @Override public boolean canDoForward(){ for (INDArray input : inputs) if (input == null) return false; return true; } @Override public boolean canDoBackward(){ for (INDArray input : inputs) if (input == null) return false; for (INDArray epsilon : epsilons) if (epsilon == null) return false; return true; } @Override public INDArray[] getErrors(){ return epsilons; } @Override public void setErrors(INDArray... errors){ this.epsilons = errors; } @Override public abstract String toString(); } Other Java examples (source code examples)Here is a short list of links related to this Java BaseGraphVertex.java source code file: |
... this post is sponsored by my books ... | |
#1 New Release! |
FP Best Seller |
Copyright 1998-2021 Alvin Alexander, alvinalexander.com
All Rights Reserved.
A percentage of advertising revenue from
pages under the /java/jwarehouse
URI on this website is
paid back to open source projects.