
Java example source code file (VocabConstructor.java)

This example Java source code file (VocabConstructor.java) is included in the alvinalexander.com "Java Source Code Warehouse" project. The intent of this project is to help you "Learn Java by Example" TM.

Learn more about this Java project at its project page.

Java tags/keywords

arraylist, atomiclong, builder, current, illegalstateexception, invertedindex, linkedblockingqueue, list, sequenceelement, sequenceiterator, string, threading, threads, util, vocabcache, vocabconstructor

The VocabConstructor.java Java example source code

package org.deeplearning4j.models.word2vec.wordstore;

import lombok.Data;
import lombok.NonNull;
import org.deeplearning4j.models.embeddings.inmemory.InMemoryLookupTable;
import org.deeplearning4j.models.embeddings.wordvectors.WordVectors;
import org.deeplearning4j.models.sequencevectors.interfaces.SequenceIterator;
import org.deeplearning4j.models.sequencevectors.sequence.Sequence;
import org.deeplearning4j.models.sequencevectors.sequence.SequenceElement;
import org.deeplearning4j.models.embeddings.WeightLookupTable;
import org.deeplearning4j.models.word2vec.Huffman;
import org.deeplearning4j.models.word2vec.wordstore.inmemory.AbstractCache;
import org.deeplearning4j.text.invertedindex.InvertedIndex;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.atomic.AtomicLong;

/**
 *
 * This class can be used to build a joint vocabulary from special sources that should be treated separately,
 * e.g. words from one source should have minWordFrequency set to 1, while the rest of the corpus uses minWordFrequency of 5.
 * This class provides a way to handle that.
 *
 * It can also be used to simply build a vocabulary out of an arbitrary number of Sequences derived from an arbitrary number of SequenceIterators.
 *
 * @author raver119@gmail.com
 */
public class VocabConstructor<T extends SequenceElement> {
    private List<VocabSource<T>> sources = new ArrayList<>();
    private VocabCache<T> cache;
    private List<String> stopWords;
    private boolean useAdaGrad = false;
    private boolean fetchLabels = false;
    private int limit;
    private AtomicLong seqCount = new AtomicLong(0);
    private InvertedIndex<T> index;

    protected static final Logger log = LoggerFactory.getLogger(VocabConstructor.class);

    private VocabConstructor() {

    }

    /**
     * Placeholder for future implementation
     * @return
     */
    protected WeightLookupTable<T> buildExtendedLookupTable() {
        return null;
    }

    /**
     * Placeholder for future implementation
     * @return
     */
    protected VocabCache<T> buildExtendedVocabulary() {
        return null;
    }

    /**
     * This method transfers existing WordVectors model into current one
     *
     * @param wordVectors
     * @return
     */
    @SuppressWarnings("unchecked") // method is safe, since all calls inside are using generic SequenceElement methods
    public VocabCache<T> buildMergedVocabulary(@NonNull WordVectors wordVectors, boolean fetchLabels) {
        return buildMergedVocabulary((VocabCache<T>) wordVectors.vocab(), fetchLabels);
    }


    /**
     * This method returns total number of sequences passed through VocabConstructor
     *
     * @return
     */
    public long getNumberOfSequences() {
        return seqCount.get();
    }

    /**
     * This method transfers existing vocabulary into current one
     *
     * Please note: this method expects the source vocabulary to have Huffman tree indexes applied
     *
     * @param vocabCache
     * @return
     */
    public VocabCache<T> buildMergedVocabulary(@NonNull VocabCache vocabCache, boolean fetchLabels) {
        if (cache == null) cache = new AbstractCache.Builder<T>().build();
        for (int t = 0; t < vocabCache.numWords(); t++) {
            String label = vocabCache.wordAtIndex(t);
            if (label == null) continue;
            T element = vocabCache.wordFor(label);

            // skip this element if it's a label and the user doesn't want labels to be merged
            if (!fetchLabels && element.isLabel()) continue;

            //element.setIndex(t);
            cache.addToken(element);
            cache.addWordToIndex(element.getIndex(), element.getLabel());

            // backward compatibility code
            cache.putVocabWord(element.getLabel());
        }

        if (cache.numWords() == 0) throw new IllegalStateException("Source VocabCache has no indexes available, transfer is impossible");

        /*
            Now that the vocab has been transferred, we should roll over the iterator and gather labels, if any
         */

        log.info("Vocab size before labels: " + cache.numWords());

        if (fetchLabels) {
            for(VocabSource<T> source: sources) {
                SequenceIterator<T> iterator = source.getIterator();
                iterator.reset();

                while (iterator.hasMoreSequences()) {
                    Sequence<T> sequence = iterator.nextSequence();
                    seqCount.incrementAndGet();

                    for (T label: sequence.getSequenceLabels()) {
                        if (!cache.containsWord(label.getLabel())) {
                            label.markAsLabel(true);
                            label.setSpecial(true);

                            label.setIndex(cache.numWords());

                            cache.addToken(label);
                            cache.addWordToIndex(label.getIndex(), label.getLabel());

                            // backward compatibility code
                            cache.putVocabWord(label.getLabel());

                            log.info("Adding label ["+label.getLabel()+"]: " + cache.wordFor(label.getLabel()));
                        } else log.info("Label ["+label.getLabel()+"] already exists: " + cache.wordFor(label.getLabel()));
                    }
                }
            }
        }

        log.info("Vocab size after labels: " + cache.numWords());

        return cache;
    }


    /**
     * This method scans all sources passed through the builder and returns all words as a vocab.
     * If a target VocabCache was set during instance creation, it will be filled too.
     *
     *
     * @return
     */
    public VocabCache<T> buildJointVocabulary(boolean resetCounters, boolean buildHuffmanTree) {
        if (resetCounters && buildHuffmanTree) throw new IllegalStateException("You can't reset counters and build Huffman tree at the same time!");

        if (cache == null) cache = new AbstractCache.Builder<T>().build();
        log.debug("Target vocab size before building: [" + cache.numWords() + "]");
        final AtomicLong elementsCounter = new AtomicLong(0);

        AbstractCache<T> topHolder = new AbstractCache.Builder()
                .minElementFrequency(0)
                .build();

        int cnt = 0;
        for(VocabSource<T> source: sources) {
            SequenceIterator<T> iterator = source.getIterator();
            iterator.reset();

            log.debug("Trying source iterator: ["+ cnt+"]");
            log.debug("Target vocab size before building: [" + cache.numWords() + "]");
            cnt++;

            AbstractCache<T> tempHolder = new AbstractCache.Builder().build();

            int sequences = 0;
            long counter = 0;
            while (iterator.hasMoreSequences()) {
                Sequence<T> document = iterator.nextSequence();
                seqCount.incrementAndGet();

                tempHolder.incrementTotalDocCount();

                Map<String, AtomicLong> seqMap = new HashMap<>();
              //  log.info("Sequence length: ["+ document.getElements().size()+"]");

                if (fetchLabels) {
                    T labelWord = document.getSequenceLabel();
                    labelWord.setSpecial(true);
                    labelWord.markAsLabel(true);
                    labelWord.setElementFrequency(1);

                    tempHolder.addToken(labelWord);
                }

                List<String> tokens = document.asLabels();
                for (String token: tokens) {
                    if (stopWords !=null && stopWords.contains(token)) continue;
                    if (token == null || token.isEmpty()) continue;

                    if (!tempHolder.containsWord(token)) {
                        T element = document.getElementByLabel(token);
                        element.setElementFrequency(1);
                        tempHolder.addToken(element);
                        elementsCounter.incrementAndGet();
                        counter++;

                        // if there's no such element in tempHolder, it's safe to set seqCount to 1
                        element.setSequencesCount(1);
                        seqMap.put(token, new AtomicLong(0));
                    } else {
                        counter++;
                        tempHolder.incrementWordCount(token);

                        // if the element exists in tempHolder, we should update its seqCount, but only once per sequence
                        if (!seqMap.containsKey(token)) {
                            seqMap.put(token, new AtomicLong(1));
                            T element = tempHolder.wordFor(token);
                            element.incrementSequencesCount();
                        }

                        if (index != null) {
                            if (document.getSequenceLabel() != null) {
                                index.addWordsToDoc(index.numDocuments(), document.getElements(), document.getSequenceLabel());
                            } else {
                                index.addWordsToDoc(index.numDocuments(),document.getElements());
                            }
                        }
                    }
                }

                sequences++;
                if (seqCount.get() % 100000 == 0) log.info("Sequences checked: [" + seqCount.get() +"], Current vocabulary size: [" + elementsCounter.get() +"]");
            }
            // apply minWordFrequency set for this source
            log.debug("Vocab size before truncation: [" + tempHolder.numWords() + "],  NumWords: [" + tempHolder.totalWordOccurrences()+ "], sequences parsed: [" + sequences+ "], counter: ["+counter+"]");
            if (source.getMinWordFrequency() > 0) {
                LinkedBlockingQueue<String> labelsToRemove = new LinkedBlockingQueue<>();
                for (T element : tempHolder.vocabWords()) {
                    if (element.getElementFrequency() < source.getMinWordFrequency() && !element.isSpecial() && !element.isLabel())
                        labelsToRemove.add(element.getLabel());
                }

                for (String label: labelsToRemove) {
                    tempHolder.removeElement(label);
                }
            }

            log.debug("Vocab size after truncation: [" + tempHolder.numWords() + "],  NumWords: [" + tempHolder.totalWordOccurrences()+ "], sequences parsed: [" + sequences+ "], counter: ["+counter+"]");

            // at this moment we're ready to transfer
            topHolder.importVocabulary(tempHolder);
        }

        // at this moment, we have a vocabulary full of words, and we have to reset counters before transferring everything back to the VocabCache

            //topHolder.resetWordCounters();



        cache.importVocabulary(topHolder);

        if (resetCounters) {
            for (T element: cache.vocabWords()) {
                element.setElementFrequency(0);
            }
            cache.updateWordsOccurencies();
        }

        if (buildHuffmanTree) {
            Huffman huffman = new Huffman(cache.vocabWords());
            huffman.build();
            huffman.applyIndexes(cache);
            //topHolder.updateHuffmanCodes();

            if (limit > 0) {
                LinkedBlockingQueue<String> labelsToRemove = new LinkedBlockingQueue<>();
                for (T element : cache.vocabWords()) {
                    if (element.getIndex() > limit && !element.isSpecial() && !element.isLabel())
                        labelsToRemove.add(element.getLabel());
                }

                for (String label: labelsToRemove) {
                    cache.removeElement(label);
                }
            }
        }

        log.info("Sequences checked: [" + seqCount.get() +"], Current vocabulary size: [" + cache.numWords() +"]");
        return cache;
    }

    public static class Builder<T extends SequenceElement> {
        private List<VocabSource<T>> sources = new ArrayList<>();
        private VocabCache<T> cache;
        private List<String> stopWords = new ArrayList<>();
        private boolean useAdaGrad = false;
        private boolean fetchLabels = false;
        private InvertedIndex<T> index;
        private int limit;

        public Builder() {

        }

        /**
         * This method sets the limit on the resulting vocabulary size.
         *
         * PLEASE NOTE: This setting applies only if the Huffman tree is built.
         *
         * @param limit
         * @return
         */
        public Builder<T> setEntriesLimit(int limit) {
            this.limit = limit;
            return this;
        }

        /**
         * Defines whether adaptive gradients should be created during vocabulary construction
         *
         * @param useAdaGrad
         * @return
         */
        protected Builder<T> useAdaGrad(boolean useAdaGrad) {
            this.useAdaGrad = useAdaGrad;
            return this;
        }

        /**
         * After the temporary internal vocabulary is built, it will be transferred to the target VocabCache you pass here
         *
         * @param cache target VocabCache
         * @return
         */
        public Builder<T> setTargetVocabCache(@NonNull VocabCache cache) {
            this.cache = cache;
            return this;
        }

        /**
         * Adds a SequenceIterator for vocabulary construction.
         * Please note: you can add as many sources as you wish.
         *
         * @param iterator SequenceIterator to build vocabulary from
         * @param minElementFrequency elements with frequency below this value will be removed from vocabulary
         * @return
         */
        public Builder<T> addSource(@NonNull SequenceIterator iterator, int minElementFrequency) {
            sources.add(new VocabSource<T>(iterator, minElementFrequency));
            return this;
        }
/*
        public Builder<T> addSource(LabelAwareIterator iterator, int minWordFrequency) {
            sources.add(new VocabSource(iterator, minWordFrequency));
            return this;
        }

        public Builder<T> addSource(SentenceIterator iterator, int minWordFrequency) {
            sources.add(new VocabSource(new SentenceIteratorConverter(iterator), minWordFrequency));
            return this;
        }
        */
/*
        public Builder setTokenizerFactory(@NonNull TokenizerFactory factory) {
            this.tokenizerFactory = factory;
            return this;
        }
*/
        public Builder<T> setStopWords(@NonNull List stopWords) {
            this.stopWords = stopWords;
            return this;
        }

        /**
         * Sets whether labels should be fetched during vocab building
         *
         * @param reallyFetch
         * @return
         */
        public Builder<T> fetchLabels(boolean reallyFetch) {
            this.fetchLabels = reallyFetch;
            return this;
        }

        public Builder<T> setIndex(InvertedIndex index) {
            this.index = index;
            return this;
        }

        public VocabConstructor<T> build() {
            VocabConstructor<T> constructor = new VocabConstructor();
            constructor.sources = this.sources;
            constructor.cache = this.cache;
            constructor.stopWords = this.stopWords;
            constructor.useAdaGrad = this.useAdaGrad;
            constructor.fetchLabels = this.fetchLabels;
            constructor.limit = this.limit;
            constructor.index = this.index;

            return constructor;
        }
    }

    @Data
    private static class VocabSource<T extends SequenceElement> {
        @NonNull private SequenceIterator<T> iterator;
        @NonNull private int minWordFrequency;
    }
}
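
A minimal usage sketch (not part of the original file): it shows how the Builder above might be used to build a joint vocabulary from two corpora with different frequency thresholds, as the class Javadoc describes. The method name buildVocab and its two SequenceIterator parameters are hypothetical placeholders, and VocabWord is assumed as the SequenceElement subtype; only the Builder and buildJointVocabulary calls come from the class shown above.

    public static VocabCache<VocabWord> buildVocab(SequenceIterator<VocabWord> specialSource,
                                                   SequenceIterator<VocabWord> mainCorpus) {
        // target cache that will receive the merged vocabulary
        VocabCache<VocabWord> targetCache = new AbstractCache.Builder<VocabWord>().build();

        VocabConstructor<VocabWord> constructor = new VocabConstructor.Builder<VocabWord>()
                .addSource(specialSource, 1)   // keep every word from the special source
                .addSource(mainCorpus, 5)      // drop main-corpus words seen fewer than 5 times
                .setTargetVocabCache(targetCache)
                .fetchLabels(true)             // also collect sequence labels into the vocab
                .build();

        // scan all sources, fill targetCache, and build the Huffman tree over the merged vocabulary
        return constructor.buildJointVocabulary(false, true);
    }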

