alvinalexander.com | career | drupal | java | mac | mysql | perl | scala | uml | unix  

Java example source code file (DeepWalk.java)

This example Java source code file (DeepWalk.java) is included in the alvinalexander.com "Java Source Code Warehouse" project. The intent of this project is to help you "Learn Java by Example"™.

Learn more about this Java project at its project page.

Java - Java tags/keywords

atomiclong, builder, creating, deepwalk, executorservice, graphwalkiterator, initialization, learningcallable, list, override, runtimeexception, status_update_frequency, thread, threading, threads, unsupportedoperationexception, util

The DeepWalk.java Java example source code

package org.deeplearning4j.graph.models.deepwalk;

import lombok.AllArgsConstructor;
import org.deeplearning4j.graph.api.IGraph;
import org.deeplearning4j.graph.api.NoEdgeHandling;
import org.deeplearning4j.graph.api.Vertex;
import org.deeplearning4j.graph.api.IVertexSequence;
import org.deeplearning4j.graph.iterator.GraphWalkIterator;
import org.deeplearning4j.graph.iterator.parallel.GraphWalkIteratorProvider;
import org.deeplearning4j.graph.iterator.parallel.RandomWalkGraphIteratorProvider;
import org.deeplearning4j.graph.models.GraphVectors;
import org.deeplearning4j.graph.models.embeddings.GraphVectorLookupTable;
import org.deeplearning4j.graph.models.embeddings.GraphVectorsImpl;
import org.deeplearning4j.graph.models.embeddings.InMemoryGraphLookupTable;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.ops.transforms.Transforms;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.List;
import java.util.concurrent.*;
import java.util.concurrent.atomic.AtomicLong;

/**Implementation of the DeepWalk graph vectorization model, based on the paper
 * <i>DeepWalk: Online Learning of Social Representations by Perozzi, Al-Rfou & Skiena (2014),
 * <a href="http://arxiv.org/abs/1403.6652">http://arxiv.org/abs/1403.6652
* Similar to word2vec in nature, DeepWalk is an unsupervised learning algorithm that learns a vector representation * of each vertex in a graph. Vector representations are learned using walks (usually random walks) on the vertices in * the graph.<br> * Once learned, these vector representations can then be used for purposes such as classification, clustering, similarity * search, etc on the graph<br> * @author Alex Black */ public class DeepWalk<V,E> extends GraphVectorsImpl { public static final int STATUS_UPDATE_FREQUENCY = 1000; private Logger log = LoggerFactory.getLogger(DeepWalk.class); private int vectorSize; private int windowSize; private double learningRate; private boolean initCalled = false; private long seed; private ExecutorService executorService; private int nThreads = Runtime.getRuntime().availableProcessors(); private transient AtomicLong walkCounter = new AtomicLong(0); public DeepWalk(){ } public int getVectorSize(){ return vectorSize; } public int getWindowSize(){ return windowSize; } public double getLearningRate(){ return learningRate; } public void setLearningRate(double learningRate){ this.learningRate = learningRate; if(lookupTable != null) lookupTable.setLearningRate(learningRate); } /** Initialize the DeepWalk model with a given graph. 
*/ public void initialize(IGraph<V,E> graph){ int nVertices = graph.numVertices(); int[] degrees = new int[nVertices]; for( int i=0; i<nVertices; i++ ) degrees[i] = graph.getVertexDegree(i); initialize(degrees); } /** Initialize the DeepWalk model with a list of vertex degrees for a graph.<br> * Specifically, graphVertexDegrees[i] represents the vertex degree of the ith vertex<br> * vertex degrees are used to construct a binary (Huffman) tree, which is in turn used in * the hierarchical softmax implementation * @param graphVertexDegrees degrees of each vertex */ public void initialize(int[] graphVertexDegrees){ log.info("Initializing: Creating Huffman tree and lookup table..."); GraphHuffman gh = new GraphHuffman(graphVertexDegrees.length); gh.buildTree(graphVertexDegrees); lookupTable = new InMemoryGraphLookupTable(graphVertexDegrees.length,vectorSize,gh,learningRate); initCalled = true; log.info("Initialization complete"); } /** Fit the model, in parallel. * This creates a set of GraphWalkIterators, which are then distributed one to each thread * @param graph Graph to fit * @param walkLength Length of rangom walks to generate */ public void fit( IGraph<V,E> graph, int walkLength ){ if(!initCalled) initialize(graph); //First: create iterators, one for each thread GraphWalkIteratorProvider<V> iteratorProvider = new RandomWalkGraphIteratorProvider(graph,walkLength,seed, NoEdgeHandling.SELF_LOOP_ON_DISCONNECTED); fit(iteratorProvider); } /** Fit the model, in parallel, using a given GraphWalkIteratorProvider.<br> * This object is used to generate multiple GraphWalkIterators, which can then be distributed to each thread * to do in parallel<br> * Note that {@link #fit(IGraph, int)} will be more convenient in many cases<br> * Note that {@link #initialize(IGraph)} or {@link #initialize(int[])} <em>must be called first. 
* @param iteratorProvider GraphWalkIteratorProvider * @see #fit(IGraph, int) */ public void fit(GraphWalkIteratorProvider<V> iteratorProvider){ if(!initCalled) throw new UnsupportedOperationException("DeepWalk not initialized (call initialize before fit)"); List<GraphWalkIterator iteratorList = iteratorProvider.getGraphWalkIterators(nThreads); executorService = Executors.newFixedThreadPool(nThreads, new ThreadFactory() { @Override public Thread newThread(Runnable r) { Thread t = new Thread(r); t.setDaemon(true); return t; } }); List<Future list = new ArrayList<>(iteratorList.size()); //log.info("Fitting Graph with {} threads", Math.max(nThreads,iteratorList.size())); for( GraphWalkIterator<V> iter : iteratorList ){ LearningCallable c = new LearningCallable(iter); list.add(executorService.submit(c)); } executorService.shutdown(); try{ executorService.awaitTermination(999, TimeUnit.DAYS); }catch(InterruptedException e){ throw new RuntimeException("ExecutorService interrupted",e); } //Don't need to block on futures to get a value out, but we want to re-throw any exceptions encountered for(Future<Void> f : list){ try{ f.get(); }catch(Exception e){ throw new RuntimeException(e); } } } /**Fit the DeepWalk model <b>using a single thread using a given GraphWalkIterator. If parallel fitting is required, * {@link #fit(IGraph, int)} or {@link #fit(GraphWalkIteratorProvider)} should be used.<br> * Note that {@link #initialize(IGraph)} or {@link #initialize(int[])} <em>must be called first. 
* * @param iterator iterator for graph walks */ public void fit(GraphWalkIterator<V> iterator){ if(!initCalled) throw new UnsupportedOperationException("DeepWalk not initialized (call initialize before fit)"); int walkLength = iterator.walkLength(); while(iterator.hasNext()){ IVertexSequence<V> sequence = iterator.next(); //Skipgram model: int[] walk = new int[walkLength+1]; int i=0; while(sequence.hasNext()) walk[i++] = sequence.next().vertexID(); skipGram(walk); long iter = walkCounter.incrementAndGet(); if(iter % STATUS_UPDATE_FREQUENCY == 0 ){ log.info("Processed {} random walks on graph",iter); } } } private void skipGram(int[] walk){ for(int mid = windowSize; mid < walk.length-windowSize; mid++ ){ for( int pos=mid-windowSize; pos<=mid+windowSize; pos++ ){ if(pos == mid) continue; //pair of vertices: walk[mid] -> walk[pos] lookupTable.iterate(walk[mid],walk[pos]); } } } public GraphVectorLookupTable lookupTable(){ return lookupTable; } public static class Builder<V,E> { private int vectorSize = 100; private long seed = System.currentTimeMillis(); private double learningRate = 0.01; private int windowSize = 2; /** Sets the size of the vectors to be learned for each vertex in the graph */ public Builder<V,E> vectorSize(int vectorSize){ this.vectorSize = vectorSize; return this; } /** Set the learning rate */ public Builder<V,E> learningRate(double learningRate){ this.learningRate = learningRate; return this; } /** Sets the window size used in skipgram model */ public Builder<V,E> windowSize(int windowSize){ this.windowSize = windowSize; return this; } /** Seed for random number generation (used for repeatability). 
* Note however that parallel/async gradient descent might result in behaviour that * is not repeatable, in spite of setting seed */ public Builder<V,E> seed(long seed){ this.seed = seed; return this; } public DeepWalk<V,E> build(){ DeepWalk<V,E> dw = new DeepWalk<>(); dw.vectorSize = vectorSize; dw.windowSize = windowSize; dw.learningRate = learningRate; dw.seed = seed; return dw; } } @AllArgsConstructor private class LearningCallable implements Callable<Void> { private final GraphWalkIterator<V> iterator; @Override public Void call() throws Exception { fit(iterator); return null; } } }

Other Java examples (source code examples)

Here is a short list of links related to this Java DeepWalk.java source code file:

... this post is sponsored by my books ...

#1 New Release!

FP Best Seller

 

new blog posts

 

Copyright 1998-2021 Alvin Alexander, alvinalexander.com
All Rights Reserved.

A percentage of advertising revenue from
pages under the /java/jwarehouse URI on this website is
paid back to open source projects.