alvinalexander.com | career | drupal | java | mac | mysql | perl | scala | uml | unix  

Java example source code file (AbstractCoOccurrences.java)

This example Java source code file (AbstractCoOccurrences.java) is included in the alvinalexander.com "Java Source Code Warehouse" project. The intent of this project is to help you "Learn Java by Example" TM.

Learn more about this Java project at its project page.

Java - Java tags/keywords

abstractcooccurrences, atomicboolean, builder, cooccurrencescalculatorthread, countmap, double, exception, file, override, pair, runtimeexception, sequenceiterator, string, util

The AbstractCoOccurrences.java Java example source code

package org.deeplearning4j.models.glove;

import lombok.NonNull;
import org.deeplearning4j.berkeley.Pair;
import org.deeplearning4j.models.glove.count.*;
import org.deeplearning4j.models.sequencevectors.interfaces.SequenceIterator;
import org.deeplearning4j.models.sequencevectors.iterators.FilteredSequenceIterator;
import org.deeplearning4j.models.sequencevectors.iterators.SynchronizedSequenceIterator;
import org.deeplearning4j.models.sequencevectors.sequence.Sequence;
import org.deeplearning4j.models.sequencevectors.sequence.SequenceElement;
import org.deeplearning4j.models.word2vec.wordstore.VocabCache;
import org.deeplearning4j.text.sentenceiterator.BasicLineIterator;
import org.deeplearning4j.text.sentenceiterator.PrefetchingSentenceIterator;
import org.deeplearning4j.text.sentenceiterator.SentenceIterator;
import org.deeplearning4j.text.sentenceiterator.SynchronizedSentenceIterator;
import org.nd4j.linalg.factory.Nd4j;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.io.*;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicLong;
import java.util.concurrent.locks.ReentrantReadWriteLock;

/**
 * This class implements building cooccurrence map for abstract training corpus.
 * However it's performance rather low, due to exsessive IO that happens in ShadowCopyThread
 *
 * PLEASE NOTE: Current implementation involves massive IO, and it should be rewritter as soon as ND4j gets sparse arrays support
 *
 * @author raver119@gmail.com
 */
public class AbstractCoOccurrences<T extends SequenceElement> implements Serializable {

    protected boolean symmetric;
    protected int windowSize;
    protected VocabCache<T> vocabCache;
    protected SequenceIterator<T> sequenceIterator;

    // please note, we need enough room for ShadowCopy thread, that's why -1 there
    protected int workers = Math.max(Runtime.getRuntime().availableProcessors() - 1, 1);

    // target file, where text with cooccurrencies should be saved
    protected File targetFile;

    protected ReentrantReadWriteLock lock = new ReentrantReadWriteLock();

    protected long memory_threshold = 0;

    private ShadowCopyThread shadowThread;

//    private Counter<Integer> sentenceOccurrences = Util.parallelCounter();
    //private CounterMap<T, T> coOccurrenceCounts = Util.parallelCounterMap();
    private volatile CountMap<T> coOccurrenceCounts = new CountMap<>();
    //private Counter<Integer> occurrenceAllocations = Util.parallelCounter();
    //private List<Pair coOccurrences;
    private AtomicLong processedSequences = new AtomicLong(0);


    protected static final Logger logger = LoggerFactory.getLogger(AbstractCoOccurrences.class);

    // this method should be private, to avoid non-configured instantiation
    private AbstractCoOccurrences() {
        ;
    }

    /**
     * This method returns cooccurrence distance weights for two SequenceElements
     *
     * @param element1
     * @param element2
     * @return distance weight
     */
    public double getCoOccurrenceCount(@NonNull T element1, @NonNull T element2) {
        return coOccurrenceCounts.getCount(element1, element2);
    }

    /**
     * This method returns estimated memory footrpint, based on current CountMap content
     * @return
     */
    protected long getMemoryFootprint() {
        // TODO: implement this method. It should return approx. memory used by appropriate CountMap
        try {
            lock.readLock().lock();
            return ((long) coOccurrenceCounts.size()) * 24L * 5L;
        } finally {
            lock.readLock().unlock();
        }
    }

    /**
     * This memory returns memory threshold, defined as 1/2 of memory allowed for allocation
     * @return
     */
    protected long getMemoryThreshold() {
        return memory_threshold / 2L;
    }

    public void fit() {
        shadowThread = new ShadowCopyThread();
        shadowThread.start();

        // we should reset iterator before counting cooccurrences
        sequenceIterator.reset();

        List<CoOccurrencesCalculatorThread> threads = new ArrayList<>();
        for (int x = 0; x < workers; x++) {
            threads.add(x, new CoOccurrencesCalculatorThread(x, new FilteredSequenceIterator<T>(new SynchronizedSequenceIterator(sequenceIterator), vocabCache), processedSequences));
            threads.get(x).start();
        }

        for (int x = 0; x < workers; x++) {
            try {
                threads.get(x).join();
            } catch (Exception e) {
                throw new RuntimeException(e);
            }
        }

        shadowThread.finish();
        logger.info("CoOccurrences map was built.");
    }

    /**
     *
     *  This method returns iterator with elements pairs and their weights. Resulting iterator is safe to use in multi-threaded environment.
     *
     * Developer's note: thread safety on received iterator is delegated to PrefetchedSentenceIterator
     * @return
     */
    public Iterator<Pair> iterator() {
        final SentenceIterator iterator;

        try {
            iterator = new SynchronizedSentenceIterator(new PrefetchingSentenceIterator.Builder(new BasicLineIterator(targetFile))
                    .setFetchSize(500000)
                    .build());

        } catch (Exception e) {
            logger.error("Target file was not found on last stage!");
            throw new RuntimeException(e);
        }
        return new Iterator<Pair>() {
            /*
                    iterator should be built on top of current text file with all pairs
             */

            @Override
            public boolean hasNext() {
                return iterator.hasNext();
            }

            @Override
            public Pair<Pair next() {
                String line = iterator.nextSentence();
                String[] strings = line.split(" ");

                T element1 = vocabCache.elementAtIndex(Integer.valueOf(strings[0]));
                T element2 = vocabCache.elementAtIndex(Integer.valueOf(strings[1]));
                Double weight = Double.valueOf(strings[2]);

                return new Pair<>(new Pair(element1, element2), weight);
            }

            @Override
            public void remove() {
                throw new UnsupportedOperationException("remove() method can't be supported on read-only interface");
            }
        };
    }

    public static class Builder<T extends SequenceElement> {

        protected boolean symmetric;
        protected int windowSize = 5;
        protected VocabCache<T> vocabCache;
        protected SequenceIterator<T> sequenceIterator;
        protected int workers = Runtime.getRuntime().availableProcessors();
        protected File target;
        protected long maxmemory = Runtime.getRuntime().maxMemory();

        public Builder() {

        }

        public Builder<T> symmetric(boolean reallySymmetric) {
            this.symmetric = reallySymmetric;
            return this;
        }

        public Builder<T> windowSize(int windowSize) {
            this.windowSize = windowSize;
            return this;
        }

        public Builder<T> vocabCache(@NonNull VocabCache cache) {
            this.vocabCache = cache;
            return this;
        }

        public Builder<T> iterate(@NonNull SequenceIterator iterator) {
            this.sequenceIterator = new SynchronizedSequenceIterator<T>(iterator);
            return this;
        }

        public Builder<T> workers(int numWorkers) {
            this.workers = numWorkers;
            return this;
        }

        /**
         * This method allows you to specify maximum memory available for CoOccurrence map builder.
         *
         * Please note: this option can be considered a debugging method. In most cases setting proper -Xmx argument set to JVM is enough to limit this algorithm.
         * Please note: this option won't override -Xmx JVM value.
         *
         * @param gbytes memory available, in GigaBytes
         * @return
         */
        public Builder<T> maxMemory(int gbytes) {
            if (gbytes > 0) this.maxmemory = Math.max(gbytes - 1, 1) * 1024 * 1024 * 1024L;

            return this;
        }

        /**
         * Path to save cooccurrence map after construction.
         * If targetFile is not specified, temporary file will be used.
         *
         * @param path
         * @return
         */
        public Builder<T> targetFile(@NonNull String path) {
            this.targetFile(new File(path));
            return this;
        }

        /**
         * Path to save cooccurrence map after construction.
         * If targetFile is not specified, temporary file will be used.
         *
         * @param file
         * @return
         */
        public Builder<T> targetFile(@NonNull File file) {
            this.target = file;
            return this;
        }

        public AbstractCoOccurrences<T> build() {
            AbstractCoOccurrences<T> ret = new AbstractCoOccurrences<>();
            ret.sequenceIterator = this.sequenceIterator;
            ret.windowSize = this.windowSize;
            ret.vocabCache = this.vocabCache;
            ret.symmetric = this.symmetric;
            ret.workers = this.workers;

            if (this.maxmemory < 1) this.maxmemory = Runtime.getRuntime().maxMemory();
            ret.memory_threshold = this.maxmemory;


            logger.info("Actual memory limit: ["+ this.maxmemory +"]");

            // use temp file, if no target file was specified
            try {
                if (this.target == null) this.target = File.createTempFile("cooccurrence", "map");
                this.target.deleteOnExit();
            } catch (Exception e) {
                throw new RuntimeException(e);
            }

            ret.targetFile = this.target;

            return ret;
        }
    }

    private class CoOccurrencesCalculatorThread extends Thread implements Runnable {

        private final SequenceIterator<T> iterator;
        private final AtomicLong sequenceCounter;
        private int threadId;

        public CoOccurrencesCalculatorThread(int threadId, @NonNull SequenceIterator<T> iterator, @NonNull AtomicLong sequenceCounter) {
            this.iterator = iterator;
            this.sequenceCounter = sequenceCounter;
            this.threadId = threadId;

            this.setName("CoOccurrencesCalculatorThread " + threadId);
        }

        @Override
        public void run() {
            while (iterator.hasMoreSequences()) {
                Sequence<T> sequence = iterator.nextSequence();

                List<String> tokens = new ArrayList<>(sequence.asLabels());
    //            logger.info("Tokens size: " + tokens.size());
                for (int x = 0; x < sequence.getElements().size(); x++) {
                    int wordIdx = vocabCache.indexOf(tokens.get(x));
                    if (wordIdx < 0) continue;
                    String w1 = vocabCache.wordFor(tokens.get(x)).getLabel();

                    // THIS iS SAFE TO REMOVE, NO CHANCE WE'll HAVE UNK WORD INSIDE SEQUENCE
                    /*if(w1.equals(Glove.UNK))
                        continue;
                    */

                    int windowStop = Math.min(x + windowSize + 1,tokens.size());
                    for(int j = x; j < windowStop; j++) {
                        int otherWord = vocabCache.indexOf(tokens.get(j));
                        if (otherWord < 0) continue;
                        String w2 = vocabCache.wordFor(tokens.get(j)).getLabel();

                        if(w2.equals(Glove.DEFAULT_UNK) || otherWord == wordIdx) {
                            continue;
                        }


                        T tokenX  = vocabCache.wordFor(tokens.get(x));
                        T tokenJ = vocabCache.wordFor(tokens.get(j));
                        double nWeight = 1.0 / (j - x + Nd4j.EPS_THRESHOLD);

                        while (getMemoryFootprint() >= getMemoryThreshold()) {
                            try {
                                shadowThread.invoke();
                                /*lock.readLock().lock();
                                int size = coOccurrenceCounts.size();
                                lock.readLock().unlock();
                                */
                                if (threadId == 0) logger.debug("Memory consuimption > threshold: {footrpint: ["+ getMemoryFootprint()+"], threshold: [" + getMemoryThreshold() +"] }");
                                Thread.sleep(10000);
                            } catch (Exception e) {
                                throw new RuntimeException(e);
                            } finally {

                            }
                        }
                        /*
                        if (getMemoryFootprint() == 0) {
                            logger.info("Zero size!");
                        }
                        */

                        try {
                            lock.readLock().lock();
                            if (wordIdx < otherWord) {
                                coOccurrenceCounts.incrementCount(tokenX, tokenJ, nWeight);
                                if (symmetric) {
                                    coOccurrenceCounts.incrementCount(tokenJ, tokenX, nWeight);
                                }
                            } else {
                                coOccurrenceCounts.incrementCount(tokenJ, tokenX, nWeight);

                                if (symmetric) {
                                    coOccurrenceCounts.incrementCount(tokenX, tokenJ, nWeight);
                                }
                            }
                        } finally {
                            lock.readLock().unlock();
                        }
                    }
                }

                sequenceCounter.incrementAndGet();
            }
        }
    }

    /**
     * This class is designed to provide shadow copy functionality for CoOccurence maps, since with proper corpus size you can't fit such a map into memory
     *
     */
    private class ShadowCopyThread extends Thread implements Runnable {

        private AtomicBoolean isFinished = new AtomicBoolean(false);
        private AtomicBoolean isTerminate = new AtomicBoolean(false);
        private AtomicBoolean isInvoked = new AtomicBoolean(false);
        private AtomicBoolean shouldInvoke = new AtomicBoolean(false);

        // file that contains resuts from previous runs
        private File[] tempFiles;
        private RoundCount counter;

        public ShadowCopyThread() {
            try {

                counter = new RoundCount(1);
                tempFiles = new File[2];

                tempFiles[0] = File.createTempFile("aco", "tmp");
                tempFiles[1] = File.createTempFile("aco", "tmp");

                tempFiles[0].deleteOnExit();
                tempFiles[1].deleteOnExit();
            } catch (Exception e) {
                throw new RuntimeException(e);
            }

            this.setName("ACO ShadowCopy thread");
        }

        @Override
        public void run() {
            /*
                  Basic idea is pretty simple: run quetly, untill memory gets filled up to some high volume.
                  As soon as this happens - execute shadow copy.
            */
            while (!isFinished.get() && !isTerminate.get()) {
                // check used memory. if memory use below threshold - sleep for a while. if above threshold - invoke copier

                if (getMemoryFootprint() > getMemoryThreshold()  || (shouldInvoke.get() && !isInvoked.get())) {
                    // we'll just invoke copier, nothing else
                    shouldInvoke.compareAndSet(true, false);
                    invokeBlocking();
                } else {
                    try {
                        /*
                               commented and left here for future debugging purposes, if needed

                                //lock.readLock().lock();
                                //int size = coOccurrenceCounts.size();
                                //lock.readLock().unlock();
                                //logger.info("Current memory situation: {size: [" +size+ "], footprint: [" + getMemoryFootprint()+"], threshold: ["+ getMemoryThreshold() +"]}");
                         */
                        Thread.sleep(1000);
                    } catch (Exception e) {
                        throw new RuntimeException(e);
                    }
                }
            }
        }

        /**
         * This methods advises shadow copy process to start
         */
        public void invoke() {
            shouldInvoke.compareAndSet(false, true);
        }

        /**
         * This methods dumps cooccurrence map into save file.
         * Please note: this method is synchronized and will block, until complete
         */
        public synchronized void invokeBlocking() {
            if (getMemoryFootprint() < getMemoryThreshold() && !isFinished.get()) return;

            int numberOfLinesSaved = 0;

            isInvoked.set(true);

            logger.debug("Memory purge started.");

            /*
                Basic plan:
                    1. Open temp file
                    2. Read that file line by line
                    3. For each read line do synchronization in memory > new file direction
             */

            counter.tick();

            CountMap<T> localMap;
            try {
                // in any given moment there's going to be only 1 WriteLock, due to invokeBlocking() being synchronized call
                lock.writeLock().lock();



                // obtain local copy of CountMap
                 localMap = coOccurrenceCounts;

                // set new CountMap, and release write lock
                coOccurrenceCounts = new CountMap<T>();
            } catch (Exception e) {
                throw new RuntimeException(e);
            } finally {
                lock.writeLock().unlock();
            }

            try {

                File file = null;
                if (!isFinished.get()) {
                    file = tempFiles[counter.previous()];
                } else file = targetFile;


            //    PrintWriter pw = new PrintWriter(file);

                int linesRead = 0;

                logger.debug("Saving to: ["+ counter.get()+"], Reading from: [" + counter.previous()+"]");
                CoOccurenceReader<T> reader = new BinaryCoOccurrenceReader<>(tempFiles[counter.previous()], vocabCache, localMap);
                CoOccurrenceWriter<T> writer = (isFinished.get()) ? new ASCIICoOccurrenceWriter(targetFile): new BinaryCoOccurrenceWriter(tempFiles[counter.get()]);
                while (reader.hasMoreObjects()) {
                    CoOccurrenceWeight<T> line = reader.nextObject();

                    if (line != null) {
                        writer.writeObject(line);
                        numberOfLinesSaved++;
                        linesRead++;
                    }
                }
                reader.finish();

                logger.debug("Lines read: [" + linesRead + "]");

                //now, we can dump the rest of elements, which were not presented in existing dump
                Iterator<Pair iterator = localMap.getPairIterator();
                while (iterator.hasNext()) {
                    Pair<T, T> pair = iterator.next();
                    double mWeight = localMap.getCount(pair);
                    CoOccurrenceWeight<T> object = new CoOccurrenceWeight<>();
                    object.setElement1(pair.getFirst());
                    object.setElement2(pair.getSecond());
                    object.setWeight(mWeight);

                    writer.writeObject(object);

                    numberOfLinesSaved++;
                    //      if (numberOfLinesSaved % 100000 == 0) logger.info("Lines saved: [" + numberOfLinesSaved +"]");
                }

                writer.finish();

            /*
                SentenceIterator sIterator =  new PrefetchingSentenceIterator.Builder(new BasicLineIterator(tempFiles[counter.get()]))
                        .setFetchSize(500000)
                        .build();


                int linesRead = 0;
                while (sIterator.hasNext()) {
                    //List<Writable> list = new ArrayList<>(reader.next());
                    String sentence = sIterator.nextSentence();
                    if (sentence == null || sentence.isEmpty()) continue;
                    String[] strings = sentence.split(" ");


                    // first two elements are integers - vocab indexes
                    //T element1 = vocabCache.wordFor(vocabCache.wordAtIndex(list.get(0).toInt()));
                    //T element2 = vocabCache.wordFor(vocabCache.wordAtIndex(list.get(1).toInt()));
                    T element1 = vocabCache.elementAtIndex(Integer.valueOf(strings[0]));
                    T element2 = vocabCache.elementAtIndex(Integer.valueOf(strings[1]));

                    // getting third element, previously stored weight
                    double sWeight = Double.valueOf(strings[2]);  // list.get(2).toDouble();

                    // now, since we have both elements ready, we can check this pair against inmemory map
                        double mWeight = localMap.getCount(element1, element2);
                        if (mWeight <= 0) {
                            // this means we have no such pair in memory, so we'll do nothing to sWeight
                        } else {
                            // since we have new weight value in memory, we should update sWeight value before moving it off memory
                            sWeight += mWeight;

                            // original pair can be safely removed from CountMap
                            localMap.removePair(element1,element2);
                        }

                        StringBuilder builder = new StringBuilder().append(element1.getIndex()).append(" ").append(element2.getIndex()).append(" ").append(sWeight);
                        pw.println(builder.toString());
                        numberOfLinesSaved++;
                        linesRead++;

                   // if (numberOfLinesSaved % 100000 == 0) logger.info("Lines saved: [" + numberOfLinesSaved +"]");
                  //  if (linesRead % 100000 == 0) logger.info("Lines read: [" + linesRead +"]");
                }
                */
/*
                logger.info("Lines read: [" + linesRead + "]");

                //now, we can dump the rest of elements, which were not presented in existing dump
                Iterator<Pair iterator = localMap.getPairIterator();
                while (iterator.hasNext()) {
                    Pair<T, T> pair = iterator.next();
                    double mWeight = localMap.getCount(pair);

                    StringBuilder builder = new StringBuilder().append(pair.getFirst().getIndex()).append(" ").append(pair.getFirst().getIndex()).append(" ").append(mWeight);
                    pw.println(builder.toString());
                    numberOfLinesSaved++;

              //      if (numberOfLinesSaved % 100000 == 0) logger.info("Lines saved: [" + numberOfLinesSaved +"]");
                }

                pw.flush();
                pw.close();

*/

                // just a hint for gc
                localMap = null;
                //sIterator.finish();
            } catch (Exception e) {
                throw new RuntimeException(e);
            }

            logger.info("Number of word pairs saved so far: [" + numberOfLinesSaved + "]");
            isInvoked.set(false);
        }

        /**
         * This method provides soft finish ability for shadow copy process.
         * Please note: it's blocking call, since it requires for final merge.
         */
        public void finish() {
            if (this.isFinished.get()) return;

            this.isFinished.set(true);
            invokeBlocking();
        }

        /**
         * This method provides hard fiinish ability for shadow copy process
         */
        public void terminate() {
            this.isTerminate.set(true);
        }
    }
}

Other Java examples (source code examples)

Here is a short list of links related to this Java AbstractCoOccurrences.java source code file:

... this post is sponsored by my books ...

#1 New Release!

FP Best Seller

 

new blog posts

 

Copyright 1998-2021 Alvin Alexander, alvinalexander.com
All Rights Reserved.

A percentage of advertising revenue from
pages under the /java/jwarehouse URI on this website is
paid back to open source projects.