|
The Glove.java Java example source code
/*
*
* * Copyright 2015 Skymind,Inc.
* *
* * Licensed under the Apache License, Version 2.0 (the "License");
* * you may not use this file except in compliance with the License.
* * You may obtain a copy of the License at
* *
* * http://www.apache.org/licenses/LICENSE-2.0
* *
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS,
* * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* * See the License for the specific language governing permissions and
* * limitations under the License.
*
*/
package org.deeplearning4j.spark.models.embeddings.glove;
import org.apache.commons.math3.util.FastMath;
import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.api.java.function.PairFunction;
import org.apache.spark.broadcast.Broadcast;
import org.deeplearning4j.berkeley.CounterMap;
import org.deeplearning4j.berkeley.Pair;
import org.deeplearning4j.berkeley.Triple;
import org.deeplearning4j.models.glove.GloveWeightLookupTable;
import org.deeplearning4j.models.word2vec.VocabWord;
import org.deeplearning4j.models.word2vec.wordstore.VocabCache;
import org.deeplearning4j.spark.models.embeddings.glove.cooccurrences.CoOccurrenceCalculator;
import org.deeplearning4j.spark.models.embeddings.glove.cooccurrences.CoOccurrenceCounts;
import org.deeplearning4j.spark.text.functions.TextPipeline;
import org.deeplearning4j.text.tokenization.tokenizerfactory.DefaultTokenizerFactory;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.learning.AdaGrad;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import scala.Tuple2;
import java.io.Serializable;
import java.util.*;
import java.util.concurrent.atomic.AtomicLong;
import static org.deeplearning4j.spark.models.embeddings.word2vec.Word2VecVariables.*;
/**
* Spark glove
*
* @author Adam Gibson
*/
public class Glove implements Serializable {
private Broadcast<VocabCache vocabCacheBroadcast;
private String tokenizerFactoryClazz = DefaultTokenizerFactory.class.getName();
private boolean symmetric = true;
private int windowSize = 15;
private int iterations = 300;
private static Logger log = LoggerFactory.getLogger(Glove.class);
/**
*
* @param tokenizerFactoryClazz the fully qualified class name of the tokenizer
* @param symmetric whether the co occurrence counts should be symmetric
* @param windowSize the window size for co occurrence
* @param iterations the number of iterations
*/
public Glove(String tokenizerFactoryClazz, boolean symmetric, int windowSize, int iterations) {
this.tokenizerFactoryClazz = tokenizerFactoryClazz;
this.symmetric = symmetric;
this.windowSize = windowSize;
this.iterations = iterations;
}
/**
*
* @param symmetric whether the co occurrence counts should be symmetric
* @param windowSize the window size for co occurrence
* @param iterations the number of iterations
*/
public Glove(boolean symmetric, int windowSize, int iterations) {
this.symmetric = symmetric;
this.windowSize = windowSize;
this.iterations = iterations;
}
private Pair<INDArray,Double> update(
AdaGrad weightAdaGrad
,AdaGrad biasAdaGrad
,INDArray syn0
,INDArray bias
,VocabWord w1
,INDArray wordVector
,INDArray contextVector
,double gradient) {
//gradient for word vectors
INDArray grad1 = contextVector.mul(gradient);
INDArray update = weightAdaGrad.getGradient(grad1,w1.getIndex(),syn0.shape());
wordVector.subi(update);
double w1Bias = bias.getDouble(w1.getIndex());
double biasGradient = biasAdaGrad.getGradient(gradient,w1.getIndex(),bias.shape());
double update2 = w1Bias - biasGradient;
bias.putScalar(w1.getIndex(),bias.getDouble(w1.getIndex()) - update2);
return new Pair<>(update,update2);
}
/**
* Train on the corpus
* @param rdd the rdd to train
* @return the vocab and weights
*/
public Pair<VocabCache train(JavaRDD rdd) throws Exception{
// Each `train()` can use different parameters
final JavaSparkContext sc = new JavaSparkContext(rdd.context());
final SparkConf conf = sc.getConf();
final int vectorLength = assignVar(VECTOR_LENGTH, conf, Integer.class);
final boolean useAdaGrad = assignVar(ADAGRAD, conf, Boolean.class);
final double negative = assignVar(NEGATIVE, conf, Double.class);
final int numWords = assignVar(NUM_WORDS, conf, Integer.class);
final int window = assignVar(WINDOW, conf, Integer.class);
final double alpha = assignVar(ALPHA, conf, Double.class);
final double minAlpha = assignVar(MIN_ALPHA, conf, Double.class);
final int iterations = assignVar(ITERATIONS, conf, Integer.class);
final int nGrams = assignVar(N_GRAMS, conf, Integer.class);
final String tokenizer = assignVar(TOKENIZER, conf, String.class);
final String tokenPreprocessor = assignVar(TOKEN_PREPROCESSOR, conf, String.class);
final boolean removeStop = assignVar(REMOVE_STOPWORDS, conf, Boolean.class);
Map<String, Object> tokenizerVarMap = new HashMap() {{
put("numWords", numWords);
put("nGrams", nGrams);
put("tokenizer", tokenizer);
put("tokenPreprocessor", tokenPreprocessor);
put("removeStop", removeStop);
}};
Broadcast<Map broadcastTokenizerVarMap = sc.broadcast(tokenizerVarMap);
TextPipeline pipeline = new TextPipeline(rdd, broadcastTokenizerVarMap);
pipeline.buildVocabCache();
pipeline.buildVocabWordListRDD();
// Get total word count
Long totalWordCount = pipeline.getTotalWordCount();
VocabCache<VocabWord> vocabCache = pipeline.getVocabCache();
JavaRDD<Pair> sentenceWordsCountRDD = pipeline.getSentenceWordsCountRDD();
final Pair<VocabCache vocabAndNumWords = new Pair<>(vocabCache, totalWordCount);
vocabCacheBroadcast = sc.broadcast(vocabAndNumWords.getFirst());
final GloveWeightLookupTable gloveWeightLookupTable = new GloveWeightLookupTable.Builder()
.cache(vocabAndNumWords.getFirst()).lr(conf.getDouble(GlovePerformer.ALPHA,0.01))
.maxCount(conf.getDouble(GlovePerformer.MAX_COUNT,100)).vectorLength(conf.getInt(GlovePerformer.VECTOR_LENGTH,300))
.xMax(conf.getDouble(GlovePerformer.X_MAX,0.75)).build();
gloveWeightLookupTable.resetWeights();
gloveWeightLookupTable.getBiasAdaGrad().historicalGradient = Nd4j.ones(gloveWeightLookupTable.getSyn0().rows());
gloveWeightLookupTable.getWeightAdaGrad().historicalGradient = Nd4j.ones(gloveWeightLookupTable.getSyn0().shape());
log.info("Created lookup table of size " + Arrays.toString(gloveWeightLookupTable.getSyn0().shape()));
CounterMap<String,String> coOccurrenceCounts = sentenceWordsCountRDD.map(new CoOccurrenceCalculator(symmetric,vocabCacheBroadcast, windowSize))
.fold(new CounterMap<String, String>(), new CoOccurrenceCounts());
Iterator<Pair pair2 = coOccurrenceCounts.getPairIterator();
List<Triple counts = new ArrayList<>();
while(pair2.hasNext()) {
Pair<String,String> next = pair2.next();
if(coOccurrenceCounts.getCount(next.getFirst(),next.getSecond()) > gloveWeightLookupTable.getMaxCount()) {
coOccurrenceCounts.setCount(next.getFirst(),next.getSecond(),gloveWeightLookupTable.getMaxCount());
}
counts.add(new Triple<>(next.getFirst(),next.getSecond(),coOccurrenceCounts.getCount(next.getFirst(),next.getSecond())));
}
log.info("Calculated co occurrences");
JavaRDD<Triple parallel = sc.parallelize(counts);
JavaPairRDD<String, Tuple2 pairs = parallel.mapToPair(new PairFunction, String, Tuple2>() {
@Override
public Tuple2<String, Tuple2 call(Triple stringStringDoubleTriple) throws Exception {
return new Tuple2<>(stringStringDoubleTriple.getFirst(),new Tuple2<>(stringStringDoubleTriple.getSecond(),stringStringDoubleTriple.getThird()));
}
});
JavaPairRDD<VocabWord,Tuple2 pairsVocab = pairs.mapToPair(new PairFunction>, VocabWord, Tuple2>() {
@Override
public Tuple2<VocabWord, Tuple2 call(Tuple2> stringTuple2Tuple2) throws Exception {
VocabWord w1 = vocabCacheBroadcast.getValue().wordFor(stringTuple2Tuple2._1());
VocabWord w2 = vocabCacheBroadcast.getValue().wordFor(stringTuple2Tuple2._2()._1());
return new Tuple2<>(w1, new Tuple2<>(w2, stringTuple2Tuple2._2()._2()));
}
});
for(int i = 0; i < iterations; i++) {
JavaRDD<GloveChange> change = pairsVocab.map(new Function>, GloveChange>() {
@Override
public GloveChange call(Tuple2<VocabWord, Tuple2 vocabWordTuple2Tuple2) throws Exception {
VocabWord w1 = vocabWordTuple2Tuple2._1();
VocabWord w2 = vocabWordTuple2Tuple2._2()._1();
INDArray w1Vector = gloveWeightLookupTable.getSyn0().slice(w1.getIndex());
INDArray w2Vector = gloveWeightLookupTable.getSyn0().slice(w2.getIndex());
INDArray bias = gloveWeightLookupTable.getBias();
double score = vocabWordTuple2Tuple2._2()._2();
double xMax = gloveWeightLookupTable.getxMax();
double maxCount = gloveWeightLookupTable.getMaxCount();
//w1 * w2 + bias
double prediction = Nd4j.getBlasWrapper().dot(w1Vector, w2Vector);
prediction += bias.getDouble(w1.getIndex()) + bias.getDouble(w2.getIndex());
double weight = FastMath.pow(Math.min(1.0, (score / maxCount)), xMax);
double fDiff = score > xMax ? prediction : weight * (prediction - Math.log(score));
if (Double.isNaN(fDiff))
fDiff = Nd4j.EPS_THRESHOLD;
//amount of change
double gradient = fDiff;
Pair<INDArray, Double> w1Update = update(
gloveWeightLookupTable.getWeightAdaGrad()
, gloveWeightLookupTable.getBiasAdaGrad()
, gloveWeightLookupTable.getSyn0()
, gloveWeightLookupTable.getBias()
, w1, w1Vector, w2Vector, gradient);
Pair<INDArray, Double> w2Update = update(
gloveWeightLookupTable.getWeightAdaGrad()
, gloveWeightLookupTable.getBiasAdaGrad()
, gloveWeightLookupTable.getSyn0()
, gloveWeightLookupTable.getBias()
, w2, w2Vector, w1Vector, gradient);
return new GloveChange(
w1, w2
, w1Update.getFirst(), w2Update.getFirst()
, w1Update.getSecond(), w2Update.getSecond()
, fDiff
, gloveWeightLookupTable.getWeightAdaGrad().getHistoricalGradient().slice(w1.getIndex()),
gloveWeightLookupTable.getWeightAdaGrad().getHistoricalGradient().slice(w2.getIndex())
, gloveWeightLookupTable.getBiasAdaGrad().getHistoricalGradient().getDouble(w2.getIndex())
, gloveWeightLookupTable.getBiasAdaGrad().getHistoricalGradient().getDouble(w1.getIndex()));
}
});
List<GloveChange> gloveChanges = change.collect();
double error = 0.0;
for(GloveChange change2 : gloveChanges) {
change2.apply(gloveWeightLookupTable);
error += change2.getError();
}
List l = pairsVocab.collect();
Collections.shuffle(l);
pairsVocab = sc.parallelizePairs(l);
log.info("Error at iteration " + i + " was " + error);
}
return new Pair<>(vocabAndNumWords.getFirst(),gloveWeightLookupTable);
}
}
Other Java examples (source code examples)
Here is a short list of links related to this Java Glove.java source code file:
|