home | career | drupal | java | mac | mysql | perl | scala | uml | unix  

Java example source code file (InMemoryGraphLookupTable.java)

This example Java source code file (InMemoryGraphLookupTable.java) is included in the alvinalexander.com "Java Source Code Warehouse" project. The intent of this project is to help you "Learn Java by Example"™.

Learn more about this Java project at its project page.

Java - Java tags/keywords

binarytree, full, graphvectorlookuptable, indarray, inmemorygraphlookuptable, l-1, level1, max_exp, override, specifically

The InMemoryGraphLookupTable.java Java example source code

package org.deeplearning4j.graph.models.embeddings;

import org.apache.commons.math3.util.FastMath;
import org.deeplearning4j.graph.models.BinaryTree;
import org.nd4j.linalg.api.blas.Level1;
import org.nd4j.linalg.api.buffer.DataBuffer;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

/**
 * In-memory {@link GraphVectorLookupTable} holding vector representations for the vertices of a graph.
 * <p>
 * Two weight arrays are kept: one row per vertex ("input" vectors) and one row per inner node of the
 * supplied {@link BinaryTree} ("output" vectors). The probability of a vertex given another is computed
 * as a product of sigmoid terms, one per inner node on the tree path to the target vertex.
 *
 * @author Alex Black
 */
public class InMemoryGraphLookupTable implements GraphVectorLookupTable {

    /** Number of vertices, i.e. number of rows in {@link #vertexVectors}. */
    protected int nVertices;
    /** Length of every vertex vector and every inner node vector. */
    protected int vectorSize;
    /** Binary tree supplying, per vertex, the code, code length and inner nodes on its path. */
    protected BinaryTree tree;
    /** 'Input' vectors: row i is the vector for vertex i. */
    protected INDArray vertexVectors;
    /** 'Output' vectors: row j is the vector for inner node j of the binary tree. */
    protected INDArray outWeights;
    /** Step size applied to gradients in {@link #iterate(int, int)}. */
    protected double learningRate;

    /**
     * Table of logistic-function values filled in by the constructor.
     * NOTE(review): no method in this class reads it; kept because it is visible to subclasses.
     */
    protected double[] expTable;
    /** Half-width of the input range covered by {@link #expTable}. */
    protected static double MAX_EXP = 6;

    /**
     * Creates the table, initializes both weight arrays via {@link #resetWeights()} and then fills
     * {@link #expTable}.
     *
     * @param nVertices    number of vertices in the graph
     * @param vectorSize   length of each vector
     * @param tree         binary tree over the vertices; stored as-is, not validated here
     * @param learningRate initial learning rate
     */
    public InMemoryGraphLookupTable(int nVertices, int vectorSize, BinaryTree tree, double learningRate ){
        this.nVertices = nVertices;
        this.vectorSize = vectorSize;
        this.tree = tree;
        this.learningRate = learningRate;
        // Weights are initialized before the table is built; keep this order.
        resetWeights();

        // Entry i holds exp(x) / (exp(x) + 1) for x = (2 * i / size - 1) * MAX_EXP
        final int tableSize = 1000;
        expTable = new double[tableSize];
        for (int i = 0; i < expTable.length; i++) {
            double expValue = FastMath.exp((i / (double) expTable.length * 2 - 1) * MAX_EXP);
            expTable[i] = expValue / (expValue + 1.0);
        }
    }

    /** @return the array of vertex ('input') vectors itself, not a copy */
    public INDArray getVertexVectors(){
        return this.vertexVectors;
    }

    /** @return the array of inner node ('output') vectors itself, not a copy */
    public INDArray getOutWeights(){
        return this.outWeights;
    }

    /** @return length of each vector in this table */
    @Override
    public int vectorSize() {
        return this.vectorSize;
    }

    /**
     * Replaces both weight arrays with newly generated random values: each entry is a value from
     * {@code Nd4j.rand}, minus 0.5, divided by the vector size.
     */
    @Override
    public void resetWeights() {
        // One row per vertex
        this.vertexVectors = Nd4j.rand(nVertices, vectorSize)
                .subi(0.5)
                .divi(vectorSize);
        // A full binary tree with L leaves has L-1 inner nodes, hence nVertices-1 rows
        this.outWeights = Nd4j.rand(nVertices - 1, vectorSize)
                .subi(0.5)
                .divi(vectorSize);
    }

    /**
     * Performs one gradient step for the vertex pair: every vector returned by
     * {@link #vectorsAndGradients(int, int)} is updated in place as {@code v = v - learningRate * gradient}.
     *
     * @param first  first (input) vertex index
     * @param second second (output) vertex index
     */
    @Override
    public void iterate(int first, int second) {
        // pairs[0][k] is a vector, pairs[1][k] is the gradient for that same vector
        INDArray[][] pairs = vectorsAndGradients(first, second);

        Level1 blas = Nd4j.getBlasWrapper().level1();
        INDArray[] vectors = pairs[0];
        for (int k = 0; k < vectors.length; k++) {
            // axpy: vectors[k] += (-learningRate) * gradient
            blas.axpy(vectors[k].length(), -learningRate, pairs[1][k], vectors[k]);
        }
    }

    /**
     * Returns the vertex vector and the inner node vectors involved in predicting {@code second} from
     * {@code first}, together with a gradient for each of them.<br>
     * out[0] holds vectors, out[1] holds the gradients for the corresponding vectors.<br>
     * out[0][0] is the vector for the first vertex; out[1][0] is the gradient for that vertex vector.<br>
     * out[0][i] (i&gt;0) is the i-th inner node vector along the path to the second vertex; out[1][i] is the
     * gradient for that inner node vector.<br>
     * This layout is used primarily to aid in testing (numerical gradient checks).
     *
     * @param first  first (input) vertex index
     * @param second second (output) vertex index
     * @return 2 x (number of path inner nodes + 1) array of vectors and gradients as described above
     */
    public INDArray[][] vectorsAndGradients(int first, int second){
        // Vector of the input vertex, and the tree path description of the output vertex
        INDArray inputVector = vertexVectors.getRow(first);
        int pathLength = tree.getCodeLength(second);
        long pathCode = tree.getCode(second);
        int[] pathNodes = tree.getPathInnerNodes(second);

        INDArray[][] result = new INDArray[2][pathNodes.length + 1];

        Level1 blas = Nd4j.getBlasWrapper().level1();
        // The input vertex gradient is accumulated from the contributions of each inner node
        INDArray inputGradient = Nd4j.create(inputVector.shape());
        for (int step = 0; step < pathLength; step++) {
            int nodeIdx = pathNodes[step];
            boolean bitSet = getBit(pathCode, step);  // which branch is taken at this inner node

            INDArray nodeVector = outWeights.getRow(nodeIdx);
            double activation = sigmoid(Nd4j.getBlasWrapper().dot(nodeVector, inputVector));

            // Shared scalar factor: (sigmoid - 1) when the code bit is set, sigmoid otherwise
            double factor = bitSet ? activation - 1 : activation;
            INDArray nodeGradient = inputVector.mul(factor);
            blas.axpy(inputVector.length(), factor, nodeVector, inputGradient);

            result[0][step + 1] = nodeVector;
            result[1][step + 1] = nodeGradient;
        }

        result[0][0] = inputVector;
        result[1][0] = inputGradient;
        return result;
    }

    /**
     * Calculates the probability of the second vertex given the first vertex, P(v_second | v_first),
     * as the product of one sigmoid term per inner node on the tree path to the second vertex.
     *
     * @param first  index of the first vertex
     * @param second index of the second vertex
     * @return probability, P(v_second | v_first)
     */
    public double calculateProb(int first, int second){
        INDArray inputVector = vertexVectors.getRow(first);
        int pathLength = tree.getCodeLength(second);
        long pathCode = tree.getCode(second);
        int[] pathNodes = tree.getPathInnerNodes(second);

        double probability = 1.0;
        for (int step = 0; step < pathLength; step++) {
            boolean bitSet = getBit(pathCode, step);
            INDArray nodeVector = outWeights.getRow(pathNodes[step]);

            double dot = Nd4j.getBlasWrapper().dot(nodeVector, inputVector);

            // Probability of taking the coded branch at this inner node
            probability *= sigmoid(bitSet ? dot : -dot);
        }
        return probability;
    }

    /**
     * Calculates the score for a vertex pair: -log P(v_second | v_first).
     *
     * @param first  index of the first vertex
     * @param second index of the second vertex
     * @return negative log of {@link #calculateProb(int, int)}
     */
    public double calculateScore(int first, int second){
        return -FastMath.log(calculateProb(first, second));
    }

    /** @return the binary tree passed to the constructor */
    public BinaryTree getTree(){
        return this.tree;
    }

    /**
     * @param innerNode index of an inner node of the binary tree
     * @return the row of the output weights for that inner node
     */
    public INDArray getInnerNodeVector(int innerNode){
        return this.outWeights.getRow(innerNode);
    }

    /**
     * @param idx vertex index
     * @return the row of the vertex vectors for that vertex
     */
    @Override
    public INDArray getVector(int idx) {
        return this.vertexVectors.getRow(idx);
    }

    /** @param learningRate new learning rate used by subsequent {@link #iterate(int, int)} calls */
    @Override
    public void setLearningRate(double learningRate) {
        this.learningRate = learningRate;
    }

    /** @return the number of vertices given at construction */
    @Override
    public int getNumVertices() {
        return this.nVertices;
    }

    /** Logistic function: 1 / (1 + exp(-in)). */
    private static double sigmoid(double in){
        double expNeg = FastMath.exp(-in);
        return 1.0 / (1.0 + expNeg);
    }

    /** @return true if bit number {@code bitNum} (0 = least significant) of {@code in} is set */
    private boolean getBit(long in, int bitNum){
        return ((in >>> bitNum) & 1L) != 0L;
    }

    /**
     * Replaces the vertex vectors with the given array; the reference is stored without copying or
     * shape checks.
     *
     * @param vertexVectors new vertex vectors
     */
    public void setVertexVectors(INDArray vertexVectors){
        this.vertexVectors = vertexVectors;
    }
}

Other Java examples (source code examples)

Here is a short list of links related to this Java InMemoryGraphLookupTable.java source code file:



my book on functional programming

 

new blog posts

 

Copyright 1998-2019 Alvin Alexander, alvinalexander.com
All Rights Reserved.

A percentage of advertising revenue from
pages under the /java/jwarehouse URI on this website is
paid back to open source projects.