|
The DeepWalk.java Java example source code
package org.deeplearning4j.graph.models.deepwalk;
import lombok.AllArgsConstructor;
import org.deeplearning4j.graph.api.IGraph;
import org.deeplearning4j.graph.api.NoEdgeHandling;
import org.deeplearning4j.graph.api.Vertex;
import org.deeplearning4j.graph.api.IVertexSequence;
import org.deeplearning4j.graph.iterator.GraphWalkIterator;
import org.deeplearning4j.graph.iterator.parallel.GraphWalkIteratorProvider;
import org.deeplearning4j.graph.iterator.parallel.RandomWalkGraphIteratorProvider;
import org.deeplearning4j.graph.models.GraphVectors;
import org.deeplearning4j.graph.models.embeddings.GraphVectorLookupTable;
import org.deeplearning4j.graph.models.embeddings.GraphVectorsImpl;
import org.deeplearning4j.graph.models.embeddings.InMemoryGraphLookupTable;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.ops.transforms.Transforms;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
import java.util.concurrent.*;
import java.util.concurrent.atomic.AtomicLong;
/**Implementation of the DeepWalk graph vectorization model, based on the paper
 * <i>DeepWalk: Online Learning of Social Representations</i> by Perozzi, Al-Rfou &amp; Skiena (2014),
 * <a href="http://arxiv.org/abs/1403.6652">http://arxiv.org/abs/1403.6652</a><br>
 * Similar to word2vec in nature, DeepWalk is an unsupervised learning algorithm that learns a vector representation
 * of each vertex in a graph. Vector representations are learned using walks (usually random walks) on the vertices in
 * the graph.<br>
 * Once learned, these vector representations can then be used for purposes such as classification, clustering, similarity
 * search, etc. on the graph.<br>
 * @author Alex Black
 */
public class DeepWalk<V,E> extends GraphVectorsImpl {
public static final int STATUS_UPDATE_FREQUENCY = 1000;
private Logger log = LoggerFactory.getLogger(DeepWalk.class);
private int vectorSize;
private int windowSize;
private double learningRate;
private boolean initCalled = false;
private long seed;
private ExecutorService executorService;
private int nThreads = Runtime.getRuntime().availableProcessors();
private transient AtomicLong walkCounter = new AtomicLong(0);
public DeepWalk(){
}
public int getVectorSize(){
return vectorSize;
}
public int getWindowSize(){
return windowSize;
}
public double getLearningRate(){
return learningRate;
}
public void setLearningRate(double learningRate){
this.learningRate = learningRate;
if(lookupTable != null) lookupTable.setLearningRate(learningRate);
}
/** Initialize the DeepWalk model with a given graph. */
public void initialize(IGraph<V,E> graph){
int nVertices = graph.numVertices();
int[] degrees = new int[nVertices];
for( int i=0; i<nVertices; i++ ) degrees[i] = graph.getVertexDegree(i);
initialize(degrees);
}
/** Initialize the DeepWalk model with a list of vertex degrees for a graph.<br>
* Specifically, graphVertexDegrees[i] represents the vertex degree of the ith vertex<br>
* vertex degrees are used to construct a binary (Huffman) tree, which is in turn used in
* the hierarchical softmax implementation
* @param graphVertexDegrees degrees of each vertex
*/
public void initialize(int[] graphVertexDegrees){
log.info("Initializing: Creating Huffman tree and lookup table...");
GraphHuffman gh = new GraphHuffman(graphVertexDegrees.length);
gh.buildTree(graphVertexDegrees);
lookupTable = new InMemoryGraphLookupTable(graphVertexDegrees.length,vectorSize,gh,learningRate);
initCalled = true;
log.info("Initialization complete");
}
/** Fit the model, in parallel.
* This creates a set of GraphWalkIterators, which are then distributed one to each thread
* @param graph Graph to fit
* @param walkLength Length of rangom walks to generate
*/
public void fit( IGraph<V,E> graph, int walkLength ){
if(!initCalled) initialize(graph);
//First: create iterators, one for each thread
GraphWalkIteratorProvider<V> iteratorProvider = new RandomWalkGraphIteratorProvider(graph,walkLength,seed,
NoEdgeHandling.SELF_LOOP_ON_DISCONNECTED);
fit(iteratorProvider);
}
/** Fit the model, in parallel, using a given GraphWalkIteratorProvider.<br>
* This object is used to generate multiple GraphWalkIterators, which can then be distributed to each thread
* to do in parallel<br>
* Note that {@link #fit(IGraph, int)} will be more convenient in many cases<br>
* Note that {@link #initialize(IGraph)} or {@link #initialize(int[])} <em>must be called first.
* @param iteratorProvider GraphWalkIteratorProvider
* @see #fit(IGraph, int)
*/
public void fit(GraphWalkIteratorProvider<V> iteratorProvider){
if(!initCalled) throw new UnsupportedOperationException("DeepWalk not initialized (call initialize before fit)");
List<GraphWalkIterator iteratorList = iteratorProvider.getGraphWalkIterators(nThreads);
executorService = Executors.newFixedThreadPool(nThreads, new ThreadFactory() {
@Override
public Thread newThread(Runnable r) {
Thread t = new Thread(r);
t.setDaemon(true);
return t;
}
});
List<Future list = new ArrayList<>(iteratorList.size());
//log.info("Fitting Graph with {} threads", Math.max(nThreads,iteratorList.size()));
for( GraphWalkIterator<V> iter : iteratorList ){
LearningCallable c = new LearningCallable(iter);
list.add(executorService.submit(c));
}
executorService.shutdown();
try{
executorService.awaitTermination(999, TimeUnit.DAYS);
}catch(InterruptedException e){
throw new RuntimeException("ExecutorService interrupted",e);
}
//Don't need to block on futures to get a value out, but we want to re-throw any exceptions encountered
for(Future<Void> f : list){
try{
f.get();
}catch(Exception e){
throw new RuntimeException(e);
}
}
}
/**Fit the DeepWalk model <b>using a single thread using a given GraphWalkIterator. If parallel fitting is required,
* {@link #fit(IGraph, int)} or {@link #fit(GraphWalkIteratorProvider)} should be used.<br>
* Note that {@link #initialize(IGraph)} or {@link #initialize(int[])} <em>must be called first.
*
* @param iterator iterator for graph walks
*/
public void fit(GraphWalkIterator<V> iterator){
if(!initCalled) throw new UnsupportedOperationException("DeepWalk not initialized (call initialize before fit)");
int walkLength = iterator.walkLength();
while(iterator.hasNext()){
IVertexSequence<V> sequence = iterator.next();
//Skipgram model:
int[] walk = new int[walkLength+1];
int i=0;
while(sequence.hasNext()) walk[i++] = sequence.next().vertexID();
skipGram(walk);
long iter = walkCounter.incrementAndGet();
if(iter % STATUS_UPDATE_FREQUENCY == 0 ){
log.info("Processed {} random walks on graph",iter);
}
}
}
private void skipGram(int[] walk){
for(int mid = windowSize; mid < walk.length-windowSize; mid++ ){
for( int pos=mid-windowSize; pos<=mid+windowSize; pos++ ){
if(pos == mid) continue;
//pair of vertices: walk[mid] -> walk[pos]
lookupTable.iterate(walk[mid],walk[pos]);
}
}
}
public GraphVectorLookupTable lookupTable(){
return lookupTable;
}
public static class Builder<V,E> {
private int vectorSize = 100;
private long seed = System.currentTimeMillis();
private double learningRate = 0.01;
private int windowSize = 2;
/** Sets the size of the vectors to be learned for each vertex in the graph */
public Builder<V,E> vectorSize(int vectorSize){
this.vectorSize = vectorSize;
return this;
}
/** Set the learning rate */
public Builder<V,E> learningRate(double learningRate){
this.learningRate = learningRate;
return this;
}
/** Sets the window size used in skipgram model */
public Builder<V,E> windowSize(int windowSize){
this.windowSize = windowSize;
return this;
}
/** Seed for random number generation (used for repeatability).
* Note however that parallel/async gradient descent might result in behaviour that
* is not repeatable, in spite of setting seed
*/
public Builder<V,E> seed(long seed){
this.seed = seed;
return this;
}
public DeepWalk<V,E> build(){
DeepWalk<V,E> dw = new DeepWalk<>();
dw.vectorSize = vectorSize;
dw.windowSize = windowSize;
dw.learningRate = learningRate;
dw.seed = seed;
return dw;
}
}
@AllArgsConstructor
private class LearningCallable implements Callable<Void> {
private final GraphWalkIterator<V> iterator;
@Override
public Void call() throws Exception {
fit(iterator);
return null;
}
}
}
Other Java examples (source code examples)
Here is a short list of links related to this Java DeepWalk.java source code file:
|