home | career | drupal | java | mac | mysql | perl | scala | uml | unix  

Java example source code file (TestDeepWalk.java)

This example Java source code file (TestDeepWalk.java) is included in the alvinalexander.com "Java Source Code Warehouse" project. The intent of this project is to help you "Learn Java by Example"™.

Learn more about this Java project at its project page.

Java - Java tags/keywords

classpathresource, deepwalk, edge, file, graph, graphwalkiteratorprovider, igraph, indarray, ioexception, random, string, stringvertexfactory, test, util, weightedrandomwalkgraphiteratorprovider

The TestDeepWalk.java Java example source code

package org.deeplearning4j.graph.models.deepwalk;

import org.apache.commons.io.FilenameUtils;
import org.deeplearning4j.graph.api.Edge;
import org.deeplearning4j.graph.api.IGraph;
import org.deeplearning4j.graph.data.GraphLoader;
import org.deeplearning4j.graph.graph.Graph;
import org.deeplearning4j.graph.iterator.RandomWalkIterator;
import org.deeplearning4j.graph.iterator.GraphWalkIterator;
import org.deeplearning4j.graph.iterator.parallel.GraphWalkIteratorProvider;
import org.deeplearning4j.graph.iterator.parallel.WeightedRandomWalkGraphIteratorProvider;
import org.deeplearning4j.graph.models.GraphVectors;
import org.deeplearning4j.graph.models.loader.GraphVectorSerializer;
import org.deeplearning4j.graph.vertexfactory.StringVertexFactory;
import org.junit.Test;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.io.ClassPathResource;

import java.io.File;
import java.io.IOException;
import java.util.Arrays;
import java.util.Random;

import static org.junit.Assert.assertArrayEquals;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue;

public class TestDeepWalk {

    /**
     * Smoke test: loads the 7-vertex graph from the classpath, initializes DeepWalk on it and
     * fits it several times over a random-walk iterator. After initialization and after each
     * additional pass, every vertex vector is checked to have shape {@code [1, vectorSize]}
     * and is printed.
     *
     * @throws IOException if the graph resource cannot be resolved to a file or read
     */
    @Test
    public void testBasic() throws IOException{
        //Very basic test. Load graph, build tree, call fit, make sure it doesn't throw any exceptions

        // Single source of truth for the vertex count used by the loader and the checks below.
        int nVertices = 7;

        ClassPathResource cpr = new ClassPathResource("testgraph_7vertices.txt");

        Graph<String,String> graph = GraphLoader.loadUndirectedGraphEdgeListFile(cpr.getFile().getAbsolutePath(), nVertices, ",");

        int vectorSize = 5;
        int windowSize = 2;

        // Typed builder (no raw type); the learning rate is set exactly once.
        DeepWalk<String,String> deepWalk = new DeepWalk.Builder<String,String>()
                .learningRate(0.01)
                .vectorSize(vectorSize)
                .windowSize(windowSize)
                .build();
        deepWalk.initialize(graph);

        // Vectors must already have the configured shape straight after initialization.
        for( int i=0; i<nVertices; i++ ){
            INDArray vector = deepWalk.getVertexVector(i);
            assertArrayEquals(new int[]{1,vectorSize},vector.shape());
            System.out.println(Arrays.toString(vector.dup().data().asFloat()));
        }

        GraphWalkIterator<String> iter = new RandomWalkIterator<>(graph,8);

        deepWalk.fit(iter);

        // Re-run fitting on the same iterator; it is reset before each pass.
        for( int t=0; t<5; t++ ) {
            iter.reset();
            deepWalk.fit(iter);
            System.out.println("--------------------");
            for (int i = 0; i < nVertices; i++) {
                INDArray vector = deepWalk.getVertexVector(i);
                assertArrayEquals(new int[]{1, vectorSize}, vector.shape());
                System.out.println(Arrays.toString(vector.dup().data().asFloat()));
            }
        }
    }

    /**
     * Builds a random graph with 1000 vertices and 10 generated edges per vertex, initializes
     * DeepWalk on it and calls {@code fit(graph, 8)}. There are no assertions: the test only
     * fails if something throws.
     */
    @Test
    public void testParallel(){

        IGraph<String,String> graph = generateRandomGraph(1000,10);

        int vectorSize = 20;
        int windowSize = 2;

        // Typed builder (no raw type); the learning rate is set exactly once.
        DeepWalk<String,String> deepWalk = new DeepWalk.Builder<String,String>()
                .learningRate(0.01)
                .vectorSize(vectorSize)
                .windowSize(windowSize)
                .build();
        deepWalk.initialize(graph);

        deepWalk.fit(graph,8);
    }


    /**
     * Builds a graph with {@code nVertices} vertices in which every vertex gets
     * {@code nEdgesPerVertex} edges to targets drawn from a {@link Random} seeded with a fixed
     * value, so the same arguments always produce the same edge sequence.
     * Targets are drawn independently, so self-loops and repeated edges can be generated.
     *
     * @param nVertices       number of vertices in the graph; also the exclusive upper bound
     *                        for randomly chosen edge targets
     * @param nEdgesPerVertex number of edges added from each vertex
     * @return the populated graph; each edge value is the string {@code "from--to"}
     */
    private static Graph<String,String> generateRandomGraph(int nVertices, int nEdgesPerVertex){

        // Fixed seed keeps the generated graph reproducible across runs.
        Random r = new Random(12345);

        // Diamond instead of a raw type, so the assignment is checked by the compiler.
        Graph<String,String> graph = new Graph<>(nVertices,new StringVertexFactory());
        for( int i=0; i<nVertices; i++ ){
            for( int j=0; j<nEdgesPerVertex; j++ ){
                int to = r.nextInt(nVertices);
                Edge<String> edge = new Edge<>(i,to,i+"--"+to,false);
                graph.addEdge(edge);
            }
        }
        return graph;
    }


    /**
     * Checks {@code verticesNearest} against {@code similarity} on a 20-vertex random graph:
     * the similarities of the returned vertices to the query vertex must be non-increasing in
     * the returned order, and no vertex outside the returned set may be more similar to the
     * query vertex than the least similar returned vertex.
     */
    @Test
    public void testVerticesNearest(){

        int nVertices = 20;
        Graph<String,String> graph = generateRandomGraph(nVertices,8);

        int vectorSize = 5;
        int windowSize = 2;
        // Typed builder (no raw type); the learning rate is set exactly once.
        DeepWalk<String,String> deepWalk = new DeepWalk.Builder<String,String>()
                .learningRate(0.01)
                .vectorSize(vectorSize)
                .windowSize(windowSize)
                .build();
        deepWalk.initialize(graph);

        deepWalk.fit(graph,10);

        int topN = 5;
        int nearestTo = 4;
        int[] nearest = deepWalk.verticesNearest(nearestTo,topN);
        double[] cosSim = new double[topN];
        double minSimNearest = 1;
        for( int i=0; i<topN; i++ ){
            cosSim[i] = deepWalk.similarity(nearest[i], nearestTo);
            minSimNearest = Math.min(minSimNearest,cosSim[i]);
            // Results are expected in order of non-increasing similarity.
            if( i > 0 ) {
                assertTrue(cosSim[i] <= cosSim[i-1]);
            }
        }

        for( int i=0; i<nVertices; i++ ){
            if(i == nearestTo) {
                continue;
            }
            // Skip vertices that were part of the returned top-N set.
            boolean skip = false;
            for( int j=0; j<nearest.length; j++ ){
                if(i == nearest[j]){
                    skip = true;
                    break;  // found; no need to scan the rest of the array
                }
            }
            if(skip) {
                continue;
            }

            double sim = deepWalk.similarity(i,nearestTo);
            System.out.println(i + "\t" + nearestTo + "\t" + sim);
            assertTrue(sim <= minSimNearest);
        }
    }

    /**
     * Round-trip test for vector serialization: trains DeepWalk on a 20-vertex random graph,
     * writes the vectors to a temporary text file, loads them back and checks that the vertex
     * count, the vector size and every vector component agree (relative error below 1e-6).
     *
     * @throws IOException if the temporary file cannot be created, written or read
     */
    @Test
    public void testLoadingSaving() throws IOException{
        // Unique temp file: avoids collisions between concurrent runs and is removed afterwards.
        File outFile = File.createTempFile("dl4jdwtestout", ".txt");
        try {
            int nVertices = 20;
            Graph<String,String> graph = generateRandomGraph(nVertices,8);

            int vectorSize = 5;
            int windowSize = 2;
            // Typed builder (no raw type); the learning rate is set exactly once.
            DeepWalk<String,String> deepWalk = new DeepWalk.Builder<String,String>()
                    .learningRate(0.01)
                    .vectorSize(vectorSize)
                    .windowSize(windowSize)
                    .build();
            deepWalk.initialize(graph);

            deepWalk.fit(graph, 10);

            GraphVectorSerializer.writeGraphVectors(deepWalk, outFile.getAbsolutePath());

            GraphVectors<String,String> vectors = (GraphVectors)GraphVectorSerializer.loadTxtVectors(outFile);

            assertEquals(deepWalk.numVertices(),vectors.numVertices());
            assertEquals(deepWalk.getVectorSize(),vectors.getVectorSize());

            for( int i=0; i<nVertices; i++ ){
                INDArray vecDW = deepWalk.getVertexVector(i);
                INDArray vecLoaded = vectors.getVertexVector(i);

                for( int j=0; j<vectorSize; j++ ){
                    double d1 = vecDW.getDouble(j);
                    double d2 = vecLoaded.getDouble(j);
                    // Guard the 0/0 case: two exact zeros are equal, but the plain ratio
                    // would be NaN and fail the comparison below.
                    double denominator = Math.abs(d1) + Math.abs(d2);
                    double relError = (denominator == 0.0) ? 0.0 : Math.abs(d1-d2) / denominator;
                    assertTrue("vertex " + i + ", component " + j + ": saved=" + d1 + ", loaded=" + d2,
                            relError < 1e-6);
                }
            }
        } finally {
            // Best-effort cleanup; fall back to removal at JVM exit if the delete fails now.
            if (!outFile.delete()) {
                outFile.deleteOnExit();
            }
        }
    }

    /**
     * Trains DeepWalk on the 13-vertex graph from {@code graph13.txt} for 200 passes, lowering
     * the learning rate linearly from 0.03 on the first pass, then prints the similarity of
     * vertex 0 to every vertex followed by every vertex vector. There are no assertions: the
     * test only fails if something throws.
     *
     * @throws IOException if the graph resource cannot be resolved to a file or read
     */
    @Test
    public void testDeepWalk13Vertices() throws IOException{

        final int vertexCount = 13;

        ClassPathResource resource = new ClassPathResource("graph13.txt");
        String graphPath = resource.getFile().getAbsolutePath();
        Graph<String,String> g = GraphLoader.loadUndirectedGraphEdgeListFile(graphPath, 13, ",");

        System.out.println(g);

        Nd4j.getRandom().setSeed(12345);

        final int epochCount = 200;

        // Configure the model
        DeepWalk<String,String> model = new DeepWalk.Builder<String,String>()
                .vectorSize(50)
                .windowSize(4)
                .seed(12345)
                .build();

        // Train: 'remaining' counts down from epochCount to 1, giving the decaying rate
        int remaining = epochCount;
        while( remaining > 0 ){
            model.setLearningRate(0.03 / epochCount * remaining);
            model.fit(g, 10);
            remaining--;
        }

        // Report similarity(0, v) for each vertex v
        for( int v=0; v<vertexCount; v++ ){
            double simToFirst = model.similarity(0, v);
            System.out.println(simToFirst);
        }

        // Report the learned vector of each vertex
        for( int v=0; v<vertexCount; v++ ){
            System.out.println(model.getVertexVector(v));
        }
    }

    /**
     * Loads the 9-vertex weighted graph from {@code WeightedGraph.txt}, initializes DeepWalk on
     * it and fits it through a {@link WeightedRandomWalkGraphIteratorProvider} with walk length
     * 5. There are no assertions: the test only fails if something throws.
     *
     * @throws IOException if the graph resource cannot be resolved to a file or read
     */
    @Test
    public void testDeepWalkWeightedParallel() throws IOException {

        //Load graph
        String path = new ClassPathResource("WeightedGraph.txt").getFile().getAbsolutePath();
        int numVertices = 9;
        String delim = ",";
        String[] ignoreLinesStartingWith = new String[]{"//"};  //Comment lines start with "//"
        IGraph<String,Double> graph = GraphLoader.loadWeightedEdgeListFile(path,numVertices,delim,true,ignoreLinesStartingWith);

        //Set up DeepWalk
        int vectorSize = 5;
        int windowSize = 2;
        // Typed builder (no raw type); the learning rate is set exactly once.
        DeepWalk<String,Double> deepWalk = new DeepWalk.Builder<String,Double>()
                .learningRate(0.01)
                .vectorSize(vectorSize)
                .windowSize(windowSize)
                .build();
        deepWalk.initialize(graph);

        //Can't use the following method here: defaults to unweighted random walk
        //deepWalk.fit(graph, 10);  //Unweighted random walk

        //Create GraphWalkIteratorProvider. The GraphWalkIteratorProvider is used to create multiple GraphWalkIterator objects.
        //Here, it is used to create a GraphWalkIterator, one for each thread
        int walkLength = 5;
        // Diamond instead of a raw type, so the assignment is checked by the compiler.
        GraphWalkIteratorProvider<String> iteratorProvider = new WeightedRandomWalkGraphIteratorProvider<>(graph,walkLength);

        //Fit in parallel
        deepWalk.fit(iteratorProvider);

    }
}

Other Java examples (source code examples)

Here is a short list of links related to this Java TestDeepWalk.java source code file:



my book on functional programming

 

new blog posts

 

Copyright 1998-2019 Alvin Alexander, alvinalexander.com
All Rights Reserved.

A percentage of advertising revenue from
pages under the /java/jwarehouse URI on this website is
paid back to open source projects.