|
Java example source code file (GraphVectorsImpl.java)
This example Java source code file (GraphVectorsImpl.java) is included in the alvinalexander.com
"Java Source Code Warehouse" project. The intent of this project is to help you
"Learn Java by Example"™.
Learn more about this Java project at its project page.
The GraphVectorsImpl.java Java example source code
package org.deeplearning4j.graph.models.embeddings;
import lombok.AllArgsConstructor;
import lombok.NoArgsConstructor;
import org.deeplearning4j.berkeley.Pair;
import org.deeplearning4j.graph.api.IGraph;
import org.deeplearning4j.graph.api.Vertex;
import org.deeplearning4j.graph.models.GraphVectors;
import org.nd4j.linalg.api.blas.Level1;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.ops.transforms.Transforms;
import java.util.Collection;
import java.util.Comparator;
import java.util.PriorityQueue;
/** Base implementation for GraphVectors. Used in DeepWalk, and also when loading
* graph vectors from file.
*/
@AllArgsConstructor @NoArgsConstructor
public class GraphVectorsImpl<V,E> implements GraphVectors {
protected IGraph<V,E> graph;
protected GraphVectorLookupTable lookupTable;
@Override
public IGraph<V, E> getGraph() {
return graph;
}
@Override
public int numVertices() {
return lookupTable.getNumVertices();
}
@Override
public int getVectorSize(){
return lookupTable.vectorSize();
}
@Override
public INDArray getVertexVector(Vertex<V> vertex) {
return lookupTable.getVector(vertex.vertexID());
}
@Override
public INDArray getVertexVector(int vertexIdx) {
return lookupTable.getVector(vertexIdx);
}
@Override
public int[] verticesNearest(int vertexIdx, int top) {
INDArray vec = lookupTable.getVector(vertexIdx).dup();
double norm2 = vec.norm2Number().doubleValue();
PriorityQueue<Pair pq = new PriorityQueue>(lookupTable.getNumVertices(),new PairComparator());
Level1 l1 = Nd4j.getBlasWrapper().level1();
for( int i=0; i<numVertices(); i++ ){
if(i == vertexIdx ) continue;
INDArray other = lookupTable.getVector(i);
double cosineSim = l1.dot(vec.length(),1.0,vec,other) / (norm2 * other.norm2Number().doubleValue());
pq.add(new Pair<>(cosineSim,i));
}
int[] out = new int[top];
for( int i=0; i<top; i++ ){
out[i] = pq.remove().getSecond();
}
return out;
}
private static class PairComparator implements Comparator<Pair {
@Override
public int compare(Pair<Double, Integer> o1, Pair o2) {
return -Double.compare(o1.getFirst(),o2.getFirst());
}
}
/**Returns the cosine similarity of the vector representations of two vertices in the graph
* @return Cosine similarity of two vertices
*/
@Override
public double similarity(Vertex<V> vertex1, Vertex vertex2) {
return similarity(vertex1.vertexID(),vertex2.vertexID());
}
/**Returns the cosine similarity of the vector representations of two vertices in the graph,
* given the indices of these verticies
* @return Cosine similarity of two vertices
*/
@Override
public double similarity(int vertexIdx1, int vertexIdx2) {
if(vertexIdx1 == vertexIdx2) return 1.0;
INDArray vector = Transforms.unitVec(getVertexVector(vertexIdx1));
INDArray vector2 = Transforms.unitVec(getVertexVector(vertexIdx2));
return Nd4j.getBlasWrapper().dot(vector, vector2);
}
}
Other Java examples (source code examples)
Here is a short list of links related to this Java GraphVectorsImpl.java source code file:
|