home | career | drupal | java | mac | mysql | perl | scala | uml | unix  

Java example source code file (Word2Vec.java)

This example Java source code file (Word2Vec.java) is included in the alvinalexander.com "Java Source Code Warehouse" project. The intent of this project is to help you "Learn Java by Example" TM.

Learn more about this Java project at its project page.

Java - Java tags/keywords

atomicinteger, broadcast, builder, hashmap, indarray, javardd, list, long, map, max_exp, object, string, util, vectorsconfiguration, word2vec

The Word2Vec.java Java example source code

/*
 *
 *  * Copyright 2015 Skymind,Inc.
 *  *
 *  *    Licensed under the Apache License, Version 2.0 (the "License");
 *  *    you may not use this file except in compliance with the License.
 *  *    You may obtain a copy of the License at
 *  *
 *  *        http://www.apache.org/licenses/LICENSE-2.0
 *  *
 *  *    Unless required by applicable law or agreed to in writing, software
 *  *    distributed under the License is distributed on an "AS IS" BASIS,
 *  *    WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 *  *    See the License for the specific language governing permissions and
 *  *    limitations under the License.
 *
 */

package org.deeplearning4j.spark.models.embeddings.word2vec;

import lombok.Getter;
import lombok.NonNull;
import org.apache.commons.math3.util.FastMath;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.function.FlatMapFunction;
import org.apache.spark.broadcast.Broadcast;
import org.deeplearning4j.berkeley.Pair;
import org.deeplearning4j.models.embeddings.inmemory.InMemoryLookupTable;
import org.deeplearning4j.models.embeddings.loader.VectorsConfiguration;
import org.deeplearning4j.models.embeddings.reader.ModelUtils;
import org.deeplearning4j.models.embeddings.reader.impl.FlatModelUtils;
import org.deeplearning4j.models.embeddings.wordvectors.WordVectorsImpl;
import org.deeplearning4j.models.word2vec.Huffman;
import org.deeplearning4j.models.word2vec.VocabWord;
import org.deeplearning4j.models.word2vec.wordstore.VocabCache;
import org.deeplearning4j.spark.text.functions.CountCumSum;
import org.deeplearning4j.spark.text.functions.TextPipeline;
import org.deeplearning4j.text.tokenization.tokenizer.Tokenizer;
import org.deeplearning4j.text.tokenization.tokenizerfactory.DefaultTokenizerFactory;
import org.deeplearning4j.text.tokenization.tokenizerfactory.TokenizerFactory;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.heartbeat.reports.Environment;
import org.nd4j.linalg.heartbeat.reports.Event;
import org.nd4j.linalg.heartbeat.utils.EnvironmentUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.io.Serializable;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicLong;

/**
 * Spark version of word2vec
 *
 * @author Adam Gibson
 * @author raver119@gmail.com
 */
public class Word2Vec extends WordVectorsImpl<VocabWord> implements Serializable  {

    private INDArray trainedSyn1;
    private static Logger log = LoggerFactory.getLogger(Word2Vec.class);
    private int MAX_EXP = 6;
    @Getter private double[] expTable;
    @Getter protected VectorsConfiguration configuration;

    // Input by user only via setters
    private int nGrams = 1;
    private String tokenizer = "org.deeplearning4j.text.tokenization.tokenizerfactory.DefaultTokenizerFactory";
    private String tokenPreprocessor = "org.deeplearning4j.text.tokenization.tokenizer.preprocessor.CommonPreprocessor";
    private boolean removeStop = false;
    private long seed = 42L;
    private boolean useUnknown = false;

    // Constructor to take InMemoryLookupCache table from an already trained model
    protected Word2Vec(INDArray trainedSyn1) {
        this.trainedSyn1 = trainedSyn1;
        this.expTable = initExpTable();
    }

    protected Word2Vec() {
        this.expTable = initExpTable();
    }

    protected double[] initExpTable() {
        double[] expTable = new double[100000];
        for (int i = 0; i < expTable.length; i++) {
            double tmp = FastMath.exp((i / (double) expTable.length * 2 - 1) * MAX_EXP);
            expTable[i] = tmp / (tmp + 1.0);
        }
        return expTable;
    }

    public Map<String, Object> getTokenizerVarMap() {
        return new HashMap<String, Object>() {{
            put("numWords", minWordFrequency);
            put("nGrams", nGrams);
            put("tokenizer", tokenizer);
            put("tokenPreprocessor", tokenPreprocessor);
            put("removeStop", removeStop);
            put("stopWords", stopWords);
            put("useUnk", useUnknown);
            put("vectorsConfiguration", configuration);
        }};
    }

    public Map<String, Object> getWord2vecVarMap() {
        return new HashMap<String, Object>() {{
            put("vectorLength", layerSize);
            put("useAdaGrad", useAdeGrad);
            put("negative", negative);
            put("window", window);
            put("alpha", learningRate.get());
            put("minAlpha", minLearningRate);
            put("iterations", numIterations);
            put("seed", seed);
            put("maxExp", MAX_EXP);
            put("batchSize", batchSize);
        }};
    }

    /**
     *  Training word2vec model on a given text corpus
     *
     * @param corpusRDD training corpus
     * @throws Exception
     */
    public void train(JavaRDD<String> corpusRDD) throws Exception {
        log.info("Start training ...");

        // SparkContext
        final JavaSparkContext sc = new JavaSparkContext(corpusRDD.context());

        // Pre-defined variables
        Map<String, Object> tokenizerVarMap = getTokenizerVarMap();
        Map<String, Object> word2vecVarMap = getWord2vecVarMap();

        // Variables to fill in train
        final JavaRDD<AtomicLong> sentenceWordsCountRDD;
        final JavaRDD<List vocabWordListRDD;
        final JavaPairRDD<List vocabWordListSentenceCumSumRDD;
        final VocabCache<VocabWord> vocabCache;
        final JavaRDD<Long> sentenceCumSumCountRDD;
        int maxRep = 1;

        // Start Training //
        //////////////////////////////////////
        log.info("Tokenization and building VocabCache ...");
        // Processing every sentence and make a VocabCache which gets fed into a LookupCache
        Broadcast<Map broadcastTokenizerVarMap = sc.broadcast(tokenizerVarMap);
        TextPipeline pipeline = new TextPipeline(corpusRDD, broadcastTokenizerVarMap);
        pipeline.buildVocabCache();
        pipeline.buildVocabWordListRDD();

        // Get total word count and put into word2vec variable map
        word2vecVarMap.put("totalWordCount", pipeline.getTotalWordCount());

        // 2 RDDs: (vocab words list) and (sentence Count).Already cached
        sentenceWordsCountRDD = pipeline.getSentenceCountRDD();
        vocabWordListRDD = pipeline.getVocabWordListRDD();

        // Get vocabCache and broad-casted vocabCache
        Broadcast<VocabCache vocabCacheBroadcast = pipeline.getBroadCastVocabCache();
        vocabCache = vocabCacheBroadcast.getValue();

        //////////////////////////////////////
        log.info("Building Huffman Tree ...");
        // Building Huffman Tree would update the code and point in each of the vocabWord in vocabCache
/*
        We don't need to build tree here, since it was built earlier, at TextPipeline.buildVocabCache() call.

        Huffman huffman = new Huffman(vocabCache.vocabWords());
        huffman.build();
        huffman.applyIndexes(vocabCache);
*/
        //////////////////////////////////////
        log.info("Calculating cumulative sum of sentence counts ...");
        sentenceCumSumCountRDD =  new CountCumSum(sentenceWordsCountRDD).buildCumSum();

        //////////////////////////////////////
        log.info("Mapping to RDD(vocabWordList, cumulative sentence count) ...");
        vocabWordListSentenceCumSumRDD = vocabWordListRDD.zip(sentenceCumSumCountRDD)
                .setName("vocabWordListSentenceCumSumRDD").cache();

        /////////////////////////////////////
        log.info("Broadcasting word2vec variables to workers ...");
        Broadcast<Map word2vecVarMapBroadcast = sc.broadcast(word2vecVarMap);
        Broadcast<double[]> expTableBroadcast = sc.broadcast(expTable);



        /////////////////////////////////////
        log.info("Training word2vec sentences ...");
        FlatMapFunction firstIterFunc = new FirstIterationFunction(word2vecVarMapBroadcast, expTableBroadcast, vocabCacheBroadcast);
        @SuppressWarnings("unchecked")
        JavaRDD< Pair indexSyn0UpdateEntryRDD =
                vocabWordListSentenceCumSumRDD.mapPartitions(firstIterFunc).map(new MapToPairFunction());

        // Get all the syn0 updates into a list in driver
        List<Pair syn0UpdateEntries = indexSyn0UpdateEntryRDD.collect();

        // Instantiate syn0
        INDArray syn0 = Nd4j.zeros(vocabCache.numWords(), layerSize);

        // Updating syn0 first pass: just add vectors obtained from different nodes
        Map<VocabWord, AtomicInteger> updates = new HashMap<>();
        Map<Long, Long> updaters = new HashMap<>();
        for (Pair<VocabWord, INDArray> syn0UpdateEntry : syn0UpdateEntries) {
            syn0.getRow(syn0UpdateEntry.getFirst().getIndex()).addi(syn0UpdateEntry.getSecond());

            // for proper averaging we need to divide resulting sums later, by the number of additions
            if (updates.containsKey(syn0UpdateEntry.getFirst())) {
                updates.get(syn0UpdateEntry.getFirst()).incrementAndGet();
            } else updates.put(syn0UpdateEntry.getFirst(), new AtomicInteger(1));

            if (!updaters.containsKey(syn0UpdateEntry.getFirst().getVocabId())) {
                updaters.put(syn0UpdateEntry.getFirst().getVocabId(), syn0UpdateEntry.getFirst().getAffinityId());
            }
        }

        // Updating syn0 second pass: average obtained vectors
        for (Map.Entry<VocabWord, AtomicInteger> entry: updates.entrySet()) {
            if (entry.getValue().get() > 1) {
                if (entry.getValue().get() > maxRep) maxRep = entry.getValue().get();
                syn0.getRow(entry.getKey().getIndex()).divi(entry.getValue().get());
            }
        }

        long totals = 0;
        for (Long up: updaters.values()) {
            totals += up;
        }



        vocab = vocabCache;
        InMemoryLookupTable<VocabWord> inMemoryLookupTable = new InMemoryLookupTable();
        Environment env = EnvironmentUtils.buildEnvironment();
        env.setNumCores(maxRep);
        env.setAvailableMemory(totals);
        update(env, Event.SPARK);
        inMemoryLookupTable.setVocab(vocabCache);
        inMemoryLookupTable.setVectorLength(layerSize);
        inMemoryLookupTable.setSyn0(syn0);
        lookupTable = inMemoryLookupTable;
        modelUtils.init(lookupTable);
    }



    public static class Builder {
        protected int nGrams = 1;
        protected int numIterations = 1;
        protected int minWordFrequency = 1;
        protected int numEpochs = 1;
        protected double learningRate = 0.025;
        protected double minLearningRate = 0.001;
        protected int windowSize = 5;
        protected double negative = 0;
        protected double sampling = 1e-5;
        protected long seed = 42L;
        protected boolean useAdaGrad = false;
        protected TokenizerFactory tokenizerFactory = new DefaultTokenizerFactory();
        protected VectorsConfiguration configuration = new VectorsConfiguration();
        protected int layerSize;
        protected List<String> stopWords = new ArrayList<>();
        protected int batchSize = 100;
        protected boolean useUnk = false;
        private String tokenizer = "";
        private String tokenPreprocessor = "";

        /**
         * Creates Builder instance with default parameters set.
         */
        public Builder() {
            this(new VectorsConfiguration());
        }

        /**
         * Uses VectorsConfiguration bean to initialize Word2Vec model parameters
         *
         * @param configuration
         */
        public Builder(VectorsConfiguration configuration) {
            this.configuration = configuration;
            this.numIterations = configuration.getIterations();
            this.numEpochs = configuration.getEpochs();
            this.minLearningRate = configuration.getMinLearningRate();
            this.learningRate = configuration.getLearningRate();
            this.sampling = configuration.getSampling();
            this.negative = configuration.getNegative();
            this.minWordFrequency = configuration.getMinWordFrequency();
            this.seed = configuration.getSeed();
//            this.stopWords = configuration.get

            //  TODO: investigate this
            //this.hugeModelExpected = configuration.isHugeModelExpected();

            this.batchSize = configuration.getBatchSize();
            this.layerSize = configuration.getLayersSize();

          //  this.learningRateDecayWords = configuration.getLearningRateDecayWords();
            this.useAdaGrad = configuration.isUseAdaGrad();
            this.windowSize = configuration.getWindow();

            if (configuration.getStopList() != null) this.stopWords.addAll(configuration.getStopList());
        }

        /**
         * Specifies window size
         *
         * @param windowSize
         * @return
         */
        public Builder windowSize(int windowSize) {
            this.windowSize = windowSize;
            return this;
        }

        /**
         * Specifies negative sampling
         * @param negative
         * @return
         */
        public Builder negative(int negative) {
            this.negative = negative;
            return this;
        }

        /**
         * Specifies subsamplng value
         *
         * @param sampling
         * @return
         */
        public Builder sampling(double sampling) {
            this.sampling = sampling;
            return this;
        }

        /**
         * This method specifies initial learning rate for model
         *
         * @param lr
         * @return
         */
        public Builder learningRate(double lr) {
            this.learningRate = lr;
            return this;
        }

        /**
         * This method specifies bottom threshold for learning rate decay
         *
         * @param mlr
         * @return
         */
        public Builder minLearningRate(double mlr) {
            this.minLearningRate = mlr;
            return this;
        }

        /**
         * This method specifies number of iterations over batch on each node
         *
         * @param numIterations
         * @return
         */
        public Builder iterations(int numIterations) {
            this.numIterations = numIterations;
            return this;
        }

        /**
         * This method specifies number of epochs done over whole corpus
         *
         * PLEASE NOTE: NOT IMPLEMENTED
         *
         * @param numEpochs
         * @return
         */
        public Builder epochs(int numEpochs) {
            // TODO: implement epochs imitation for spark w2v
            this.numEpochs = numEpochs;
            return this;
        }

        /**
         * This method specifies minimum word frequency threshold. All words below this threshold will be ignored.
         *
         * @param minWordFrequency
         * @return
         */
        public Builder minWordFrequency(int minWordFrequency) {
            this.minWordFrequency = minWordFrequency;
            return this;
        }

        /**
         * This method specifies, if adaptive gradients should be used during model training
         *
         * @param reallyUse
         * @return
         */
        public Builder useAdaGrad(boolean reallyUse) {
            this.useAdaGrad = reallyUse;
            return this;
        }

        /**
         * Specifies random seed to be used during weights initialization;
         *
         * @param seed
         * @return
         */
        public Builder seed(long seed) {
            this.seed = seed;
            return this;
        }

        /**
         * Specifies TokenizerFactory to be used for tokenization
         *
         * PLEASE NOTE: You can't use anonymous implementation here
         *
         * @param factory
         * @return
         */
        public Builder tokenizerFactory(@NonNull TokenizerFactory factory) {
            this.tokenizer = factory.getClass().getCanonicalName();

            if (factory.getTokenPreProcessor() != null) {
                this.tokenPreprocessor = factory.getTokenPreProcessor().getClass().getCanonicalName();
            } else {
                this.tokenPreprocessor = "";
            }

            return this;
        }

        /**
         * Specifies TokenizerFactory class to be used for tokenization
         *
         *
         * @param tokenizer class name for tokenizerFactory
         * @return
         */
        public Builder tokenizerFactory(@NonNull String tokenizer) {
            this.tokenizer = tokenizer;
            return this;
        }

        /**
         * Specifies TokenPreProcessor class to be used during tokenization
         *
         *
         * @param tokenPreprocessor class name for tokenPreProcessor
         * @return
         */
        public Builder tokenPreprocessor(@NonNull String tokenPreprocessor) {
            this.tokenPreprocessor = tokenPreprocessor;
            return this;
        }

        /**
         * Specifies output vector's dimensions
         *
         * @param layerSize
         * @return
         */
        public Builder layerSize(int layerSize) {
            this.layerSize = layerSize;
            return this;
        }

        /**
         * Specifies N of n-Grams :)
         *
         * @param nGrams
         * @return
         */
        public Builder setNGrams(int nGrams) {
            this.nGrams = 1;
            return this;
        }

        /**
         * This method defines list of stop-words, that are to be ignored during vocab building and training
         *
         * @param stopWords
         * @return
         */
        public Builder stopWords(@NonNull List<String> stopWords) {
            for (String word: stopWords) {
                if (!this.stopWords.contains(word)) this.stopWords.add(word);
            }
            return this;
        }

        /**
         * Specifies the size of mini-batch, used in single iteration during training
         *
         * @param batchSize
         * @return
         */
        public Builder batchSize(int batchSize) {
            this.batchSize = batchSize;
            return this;
        }

        /**
         * Specifies, if UNK word should be used instead of words that are absent in vocab
         *
         * @param reallyUse
         * @return
         */
        public Builder useUnknown(boolean reallyUse) {
            this.useUnk = reallyUse;
            return this;
        }

        public Word2Vec build() {
            Word2Vec ret = new Word2Vec();

            this.configuration.setLearningRate(this.learningRate);
            this.configuration.setLayersSize(layerSize);
            this.configuration.setWindow(windowSize);
            this.configuration.setMinWordFrequency(minWordFrequency);
            this.configuration.setIterations(numIterations);
            this.configuration.setSeed(seed);
            this.configuration.setMinLearningRate(minLearningRate);
            this.configuration.setSampling(this.sampling);
            this.configuration.setUseAdaGrad(useAdaGrad);
            this.configuration.setNegative(negative);
            this.configuration.setEpochs(this.numEpochs);
            this.configuration.setBatchSize(this.batchSize);
            this.configuration.setStopList(this.stopWords);

            ret.configuration = this.configuration;

            ret.numEpochs = this.numEpochs;
            ret.numIterations = this.numIterations;
            ret.minWordFrequency = this.minWordFrequency;
            ret.learningRate.set(this.learningRate);
            ret.minLearningRate = this.minLearningRate;
            ret.sampling = this.sampling;
            ret.negative = this.negative;
            ret.layerSize = this.layerSize;
            ret.window = this.windowSize;
            ret.useAdeGrad = this.useAdaGrad;
            ret.stopWords = this.stopWords;
            ret.batchSize = this.batchSize;
            ret.useUnknown = this.useUnk;

            ret.tokenizer = this.tokenizer;
            ret.tokenPreprocessor = this.tokenPreprocessor;

            return ret;
        }
    }
}

Other Java examples (source code examples)

Here is a short list of links related to this Java Word2Vec.java source code file:



my book on functional programming

 

new blog posts

 

Copyright 1998-2019 Alvin Alexander, alvinalexander.com
All Rights Reserved.

A percentage of advertising revenue from
pages under the /java/jwarehouse URI on this website is
paid back to open source projects.