alvinalexander.com | career | drupal | java | mac | mysql | perl | scala | uml | unix  

Java example source code file (Tsne.java)

This example Java source code file (Tsne.java) is included in the alvinalexander.com "Java Source Code Warehouse" project. The intent of this project is to help you "Learn Java by Example"™.

Learn more about this Java project at its project page.

Java - Java tags/keywords

bufferedwriter, builder, indarray, indarrayindex, iteration, pair, string, stringbuffer, tsne, util, value

The Tsne.java Java example source code

package org.deeplearning4j.plot;

import com.google.common.primitives.Ints;
import org.apache.commons.math3.util.FastMath;
import org.deeplearning4j.berkeley.Pair;
import org.deeplearning4j.optimize.api.IterationListener;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.rng.distribution.Distribution;
import org.nd4j.linalg.api.rng.distribution.impl.NormalDistribution;
import org.nd4j.linalg.dimensionalityreduction.PCA;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.BooleanIndexing;
import org.nd4j.linalg.indexing.INDArrayIndex;
import org.nd4j.linalg.indexing.SpecifiedIndex;
import org.nd4j.linalg.indexing.conditions.Conditions;
import org.nd4j.linalg.indexing.functions.Value;
import org.nd4j.linalg.indexing.functions.Zero;
import org.nd4j.linalg.io.ClassPathResource;
import org.nd4j.linalg.learning.AdaGrad;
import org.nd4j.linalg.util.ArrayUtil;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.io.BufferedWriter;
import java.io.File;
import java.io.FileWriter;
import java.io.IOException;
import java.util.Arrays;
import java.util.List;

import static org.nd4j.linalg.factory.Nd4j.*;
import static org.nd4j.linalg.ops.transforms.Transforms.*;

/**
 * dl4j port of original t-sne algorithm described/implemented by van der Maaten and Hinton
 *
 * DECOMPOSED VERSION, DO NOT USE IT EVER
 *
 * @author raver119@gmail.com
 * @author Adam Gibson
 */
public class Tsne {
    // Number of gradient-descent iterations run by calculate().
    protected int maxIter = 1000;
    // Initialised from Nd4j.EPS_THRESHOLD; not read anywhere in this class (subclasses may use it).
    protected double realMin = Nd4j.EPS_THRESHOLD;
    // Momentum used while the iteration index is below switchMomentumIteration.
    protected double initialMomentum = 0.5;
    // Momentum used from switchMomentumIteration onwards.
    protected double finalMomentum = 0.8;
    // Lower bound applied to every entry of the gains matrix in calculate().
    protected  double minGain = 1e-2;
    // Current momentum; calculate() overwrites it on every iteration.
    protected double momentum = initialMomentum;
    // Iteration index at which calculate() switches from initialMomentum to finalMomentum.
    protected int switchMomentumIteration = 100;
    // Passed to PCA.pca when usePca is set; otherwise enables the min/max rescaling and mean-centring of X in calculate().
    protected boolean normalize = true;
    // When true, calculate() first reduces X with PCA.pca to at most 50 columns.
    protected boolean usePca = false;
    // Iteration index from which calculate() divides P by 4 (done once; maxIter / 2 can trigger it earlier).
    protected int stopLyingIteration = 250;
    // Passed to x2p as the acceptable |entropy - log(perplexity)| for the per-row beta search.
    protected double tolerance = 1e-5;
    // Scale factor applied to gains * gradient before it is folded into the update step.
    protected double learningRate = 500;
    // Not referenced anywhere in this class.
    protected AdaGrad adaGrad;
    // Copied from the Builder but not read anywhere in this class.
    protected boolean useAdaGrad = true;
    // Perplexity that plot() hands to calculate(); calculate() itself only uses its own parameter.
    protected double perplexity = 30;
    //protected INDArray gains,yIncs;
    // Low-dimensional embedding: assigned by calculate(), read by plot().
    protected INDArray Y;
    // Not referenced anywhere in this class.
    protected transient IterationListener iterationListener;

    protected static final Logger logger = LoggerFactory.getLogger(Tsne.class);

    protected Tsne() {
        ;
    }

    /**
     * Hook invoked by {@code Builder.build()} after all configuration fields have been copied.
     * Does nothing in this class.
     */
    protected void init() {
        // intentionally empty
    }

    /**
     * Computes the low-dimensional embedding of {@code X} and stores it in the {@link #Y} field.
     *
     * Steps, in order: optional PCA or in-place normalisation of {@code X}; random initialisation
     * of {@code Y}; computation of the input similarities {@code P} via {@code x2p}; then
     * {@link #maxIter} gradient-descent iterations with momentum and per-element gains.
     *
     * NOTE(review): when {@code usePca} is false and {@code normalize} is true the input is
     * rescaled with {@code subi}/{@code divi}/{@code subiRowVector}; these look like nd4j
     * in-place operations, so the caller's array is presumably modified - confirm.
     *
     * @param X                input data, one sample per row ({@code X.rows()} is used as the sample count)
     * @param targetDimensions number of columns of the returned embedding
     * @param perplexity       perplexity handed to {@code x2p}; shadows the {@link #perplexity} field,
     *                         which is not read here
     * @return the {@link #Y} field: an {@code X.rows()} x {@code targetDimensions} matrix
     */
    public INDArray calculate(INDArray X, int targetDimensions, double perplexity) {
        // pca hook
        if (usePca) {
            // keep at most 50 columns
            X = PCA.pca(X, Math.min(50,X.columns()),normalize);
        } else if(normalize) {
            // subtract the global minimum, divide by the global maximum, then subtract X.mean(0) from every row
            X.subi(X.min(Integer.MAX_VALUE));
            X = X.divi(X.max(Integer.MAX_VALUE));
            X = X.subiRowVector(X.mean(0));
        }


        int n = X.rows();
        // FIXME: this is wrong, another distribution required here
        Y =randn(X.rows(),targetDimensions,Nd4j.getRandom());
        // dY: gradient (this initial value is overwritten in the first iteration)
        INDArray dY = Nd4j.zeros(n, targetDimensions);
        // iY: update step carried across iterations (momentum term)
        INDArray iY = Nd4j.zeros(n, targetDimensions);
        // gains: per-element multiplier applied to the gradient
        INDArray gains = Nd4j.ones(n, targetDimensions);

        // becomes true once P has been divided by 4 (see the end of the loop body)
        boolean stopLying = false;
        logger.debug("Y:Shape is = " + Arrays.toString(Y.shape()));

        // compute P-values
        // x2p returns P already multiplied by 4
        INDArray P = x2p(X, tolerance, perplexity);

        // do training
        for (int i = 0; i < maxIter; i++) {
            // per-row sum of squared coordinates of Y
            INDArray sumY =  pow(Y, 2).sum(1).transpose();

            //Student-t distribution
            //also un normalized q
            // also known as num in original implementation
            // qu = 1 / (1 + (-2 * Y * Y^T + sumY added along both axes))
            INDArray qu = Y.mmul(Y.transpose()).muli(-2)
                    .addiRowVector(sumY)
                    .transpose()
                    .addiRowVector(sumY)
                    .addi(1)
                    .rdivi(1);

            // the diagonal of qu is NOT zeroed: the call below is commented out
  //          doAlongDiagonal(qu,new Zero());

            // normalise qu by its total sum, then floor the entries at 1e-12
            INDArray  Q =  qu.div(qu.sumNumber().doubleValue());
            BooleanIndexing.applyWhere(Q, Conditions.lessThan(1e-12), new Value(1e-12));

            // (P - Q) weighted element-wise by the un-normalised qu
            INDArray PQ = P.sub(Q).muli(qu);

            logger.debug("PQ shape is: " + Arrays.toString(PQ.shape()));
            logger.debug("PQ.sum(1) shape is: " + Arrays.toString(PQ.sum(1).shape()));

            // gradient: 4 * (diag(rowSums(PQ)) - PQ) * Y
            dY = diag(PQ.sum(1)).subi(PQ).mmul(Y).muli(4);


            if (i < switchMomentumIteration) {
                momentum = initialMomentum;
            } else {
                momentum = finalMomentum;
            }

            // gains update: (gains + 0.2) where the (dY > 0) and (iY > 0) indicators differ,
            // gains * 0.8 where they agree (presumably cond() yields 0/1 masks - confirm)
            gains = gains.add(.2)
                    .muli(dY.cond(Conditions.greaterThan(0)).neqi(iY.cond(Conditions.greaterThan(0))))
                    .addi(gains.mul(0.8).muli(dY.cond(Conditions.greaterThan(0)).eqi(iY.cond(Conditions.greaterThan(0)))));


            // floor the gains at minGain
            BooleanIndexing.applyWhere(gains, Conditions.lessThan(minGain), new Value(minGain));


            INDArray gradChange = gains.mul(dY);

            gradChange.muli(learningRate);

            // iY = momentum * iY - learningRate * gains * dY
            iY.muli(momentum).subi(gradChange);

            // cost = sum(P * log(P / Q)), i.e. the KL divergence between P and Q; only logged
            double cost = P.mul(log(P.div(Q),false)).sumNumber().doubleValue();
            logger.info("Iteration ["+ i +"] error is: [" + cost +"]");

            Y.addi(iY);
          //  Y.addi(iY).subiRowVector(Y.mean(0));
            // re-centre Y: subtract Y.mean(0), tiled to one copy per row of Y
            INDArray tiled = Nd4j.tile(Y.mean(0), new int[]{Y.rows(), 1});
            Y.subi(tiled);

            // executed exactly once: undo the factor of 4 that x2p applied to P
            if (!stopLying && (i > maxIter / 2 || i >= stopLyingIteration)) {
                P.divi(4);
                stopLying = true;
            }
        }
        return Y;
    }

    /**
     * Builds a square matrix that carries the entries of a vector on its main diagonal.
     *
     * The result has {@code max(ds.rows(), ds.columns())} rows and columns. When {@code ds}
     * has more rows than columns, diagonal entry {@code i} is the first element of
     * {@code ds.slice(i)}; otherwise it is element {@code i} of {@code ds.slice(0)}.
     * All other entries keep whatever {@code Nd4j.create} initialised them to.
     *
     * @param ds the source vector
     * @return a new {@code dim x dim} matrix with the values of {@code ds} on the diagonal
     */
    public INDArray diag(INDArray ds) {
        boolean isLong = ds.rows() > ds.columns();
        INDArray sliceZero = ds.slice(0);
        int dim = Math.max(ds.columns(), ds.rows());
        INDArray result = Nd4j.create(dim, dim);
        // Write the diagonal cell directly instead of scanning every column for j == i;
        // ds.slice(i) is only fetched in the branch that actually reads it.
        for (int i = 0; i < dim; i++) {
            double value = isLong ? ds.slice(i).getDouble(0) : sliceZero.getDouble(i);
            result.slice(i).putScalar(i, value);
        }

        return result;
    }

    /**
     * Runs {@code calculate(matrix, nDims, perplexity)} with the {@code perplexity} field and
     * appends the resulting embedding to a text file.
     *
     * One line is written per row of {@code Y} that has a non-null label, in the form
     * {@code v0,v1,...,label} followed by a space and a newline. Rows whose index is not
     * covered by {@code labels} are not written. The file is opened in append mode.
     *
     * @param matrix the data to embed, passed to {@code calculate}
     * @param nDims  the number of output dimensions, passed to {@code calculate}
     * @param labels label for each row; a {@code null} entry skips that row
     * @param path   path of the file to append to
     * @throws IOException if the file cannot be opened or written
     */
    public void plot(INDArray matrix, int nDims, List<String> labels, String path) throws IOException {

        calculate(matrix, nDims, perplexity);

        // try-with-resources: close() flushes the buffer and releases the file handle
        // even when a write (or an nd4j call inside the loop) throws.
        try (BufferedWriter writer = new BufferedWriter(new FileWriter(new File(path), true))) {
            for (int i = 0; i < Y.rows() && i < labels.size(); i++) {
                String word = labels.get(i);
                if (word == null) {
                    continue;
                }

                StringBuilder line = new StringBuilder();
                INDArray wordVector = Y.getRow(i);
                for (int j = 0; j < wordVector.length(); j++) {
                    line.append(wordVector.getDouble(j));
                    if (j < wordVector.length() - 1) {
                        line.append(",");
                    }
                }

                line.append(",");
                line.append(word);
                line.append(" ");
                line.append("\n");

                writer.write(line.toString());
            }
        }
    }

    /**
     * Evaluates a Gaussian kernel over a vector of distances and reports the entropy of the
     * resulting distribution.
     *
     * The kernel is {@code exp(-beta * d)}. The returned entropy is
     * {@code log(sum(kernel)) + beta * sum(d * kernel) / sum(kernel)}, which is the Shannon
     * entropy (natural log) of the kernel after it has been normalised to sum to one.
     *
     * @param d    the distances for one point (callers in this class pass a row of the squared
     *             distance matrix); not modified
     * @param beta the precision multiplied with the negated distances
     * @return a pair of (entropy, kernel divided by its sum)
     */
    public Pair<Double,INDArray> hBeta(INDArray d,double beta) {
        // neg() is used before the in-place muli, exactly as before, so that d itself is left untouched
        INDArray kernel = exp(d.neg().muli(beta));
        double normalizer = kernel.sumNumber().doubleValue();

        // the entropy must be taken from the un-normalised kernel, i.e. before divi below
        double weightedDistanceSum = d.mul(kernel).sumNumber().doubleValue();
        Double entropy = FastMath.log(normalizer) + ((beta * weightedDistanceSum) / normalizer);

        kernel.divi(normalizer);
        return new Pair<>(entropy, kernel);
    }

    /**
     * Builds the symmetric input-similarity matrix P for the given source data.
     *
     * For every row {@code i} a precision {@code beta[i]} (starting at 1) is searched by
     * doubling/halving and then bisection, for at most 50 steps, until the entropy returned by
     * {@code hBeta} differs from {@code log(perplexity)} by no more than {@code tolerance}.
     * The resulting conditional probabilities are written into row {@code i} of an n x n matrix,
     * skipping column {@code i}. The matrix is then added to its transpose, divided by its
     * total sum (plus 1e-6), multiplied by 4 and floored at 1e-12.
     *
     * @param X          the source data, one sample per row
     * @param tolerance  largest accepted |entropy - log(perplexity)| for the beta search
     * @param perplexity target perplexity; only its natural logarithm is used
     * @return the n x n similarity matrix, already multiplied by 4
     */
    private INDArray x2p(final INDArray X, double tolerance, double perplexity) {
        int n = X.rows();
        final INDArray p = zeros(n, n);
        // one precision value per row, all starting at 1
        final INDArray beta =  ones(n, 1);
        // target entropy for the beta search
        final double logU =  Math.log(perplexity);

        // per-row sum of squared entries of X
        INDArray sumX =  pow(X, 2).sum(1);

        logger.debug("sumX shape: " + Arrays.toString(sumX.shape()));

        // 'times' and 'prodSum' are only used for the debug shape logs below;
        // D is recomputed from scratch further down
        INDArray times = X.mmul(X.transpose()).muli(-2);

        logger.debug("times shape: " + Arrays.toString(times.shape()));

        INDArray prodSum = times.transpose().addiColumnVector(sumX);

        logger.debug("prodSum shape: " + Arrays.toString(prodSum.shape()));

        // D = -2 * X * X^T with sumX added as a column vector and as a row vector,
        // i.e. the pairwise squared Euclidean distances |x_i|^2 + |x_j|^2 - 2 * x_i . x_j
        INDArray D = X.mmul(X.transpose()).mul(-2) // thats times
                .transpose().addColumnVector(sumX) // thats prodSum
                .addRowVector(sumX.transpose()); // thats D

        logger.info("Calculating probabilities of data similarities...");
        logger.debug("Tolerance: " + tolerance);
        for(int i = 0; i < n; i++) {
            if(i % 500 == 0 && i > 0)
                logger.info("Handled [" + i + "] records out of ["+ n +"]");

            // search bracket for beta[i]; an infinite bound means "not found yet"
            double betaMin = Double.NEGATIVE_INFINITY;
            double betaMax = Double.POSITIVE_INFINITY;
            // every column index except i
            int[] vals = Ints.concat(ArrayUtil.range(0,i),ArrayUtil.range(i + 1, n ));
            INDArrayIndex[] range = new INDArrayIndex[]{new SpecifiedIndex(vals)};

            // distances from point i to every other point
            INDArray row = D.slice(i).get(range);
            Pair<Double,INDArray> pair =  hBeta(row,beta.getDouble(i));
            //INDArray hDiff = pair.getFirst().sub(logU);
            // difference between the current entropy and the target entropy
            double hDiff = pair.getFirst() - logU;
            int tries = 0;

            // refine beta[i] until the entropy is within tolerance, giving up after 50 steps
            while(Math.abs(hDiff) > tolerance && tries < 50) {
                // entropy too high: raise beta (double it until an upper bound is known, then bisect)
                if(hDiff > 0) {
                    betaMin = beta.getDouble(i);
                    if(Double.isInfinite(betaMax))
                        beta.putScalar(i,beta.getDouble(i) * 2.0);
                    else
                        beta.putScalar(i,(beta.getDouble(i) + betaMax) / 2.0);
                } else {
                    // entropy too low: lower beta (halve it until a lower bound is known, then bisect)
                    betaMax = beta.getDouble(i);
                    if(Double.isInfinite(betaMin))
                        beta.putScalar(i,beta.getDouble(i) / 2.0);
                    else
                        beta.putScalar(i,(beta.getDouble(i) + betaMin) / 2.0);
                }

                pair = hBeta(row,beta.getDouble(i));
                hDiff = pair.getFirst() - logU;
                tries++;
            }
            // store the probabilities in row i, leaving column i untouched
            p.slice(i).put(range,pair.getSecond());
        }


        // log the mean of sqrt(1 / beta), then replace any NaN in p with 1e-12
        logger.info("Mean value of sigma " + sqrt(beta.rdiv(1)).mean(Integer.MAX_VALUE));
        BooleanIndexing.applyWhere(p,Conditions.isNan(),new Value(1e-12));

        // symmetrise: p + p^T. The diagonal of p was never written in the loop above,
        // so it still holds the zeros it was created with (unless replaced by the NaN fix-up).
        INDArray permute = p.transpose();

        INDArray pOut = p.add(permute);

        // normalise by the total sum; 1e-6 is added to the divisor
        pOut.divi(pOut.sumNumber().doubleValue() + 1e-6);

        // scale by 4; calculate() divides by 4 again later in training
        pOut.muli(4);

        // floor every entry at 1e-12
        BooleanIndexing.applyWhere(pOut,Conditions.lessThan(1e-12),new Value(1e-12));

        return pOut;
    }


    /**
     * Fluent builder for {@link Tsne}.
     *
     * Every setter stores its argument and returns this builder. {@link #build()} copies all
     * values into a new {@code Tsne} and then calls {@code init()} on it.
     *
     * The defaults below are written as {@code double} literals. They were previously
     * {@code float} literals widened to {@code double}, which stored e.g. 0.800000011920929
     * instead of 0.8. Note that several defaults here differ from the field initialisers of
     * {@code Tsne} itself (for example {@code stopLyingIteration}, {@code learningRate},
     * {@code minGain} and {@code useAdaGrad}); the builder's values are the ones a built
     * instance ends up with.
     */
    public static class Builder {
        protected int maxIter = 1000;
        protected double realMin = 1e-12;
        protected double initialMomentum = 0.5;
        protected double finalMomentum = 0.8;
        protected double momentum = 0.5;
        protected int switchMomentumIteration = 100;
        protected boolean normalize = true;
        protected boolean usePca = false;
        protected int stopLyingIteration = 100;
        protected double tolerance = 1e-5;
        protected double learningRate = 0.1;
        protected boolean useAdaGrad = false;
        protected double perplexity = 30;
        protected double minGain = 0.1;


        /**
         * @param minGain lower bound applied to the gains during training
         * @return this builder
         */
        public Builder minGain(double minGain) {
            this.minGain = minGain;
            return this;
        }

        /**
         * @param perplexity perplexity stored on the built instance
         * @return this builder
         */
        public Builder perplexity(double perplexity) {
            this.perplexity = perplexity;
            return this;
        }

        /**
         * @param useAdaGrad value copied to the built instance's {@code useAdaGrad} field
         * @return this builder
         */
        public Builder useAdaGrad(boolean useAdaGrad) {
            this.useAdaGrad = useAdaGrad;
            return this;
        }

        /**
         * @param learningRate factor applied to the gain-weighted gradient
         * @return this builder
         */
        public Builder learningRate(double learningRate) {
            this.learningRate = learningRate;
            return this;
        }


        /**
         * @param tolerance accepted entropy error for the per-row precision search
         * @return this builder
         */
        public Builder tolerance(double tolerance) {
            this.tolerance = tolerance;
            return this;
        }

        /**
         * @param stopLyingIteration iteration index from which P is no longer scaled by 4
         * @return this builder
         */
        public Builder stopLyingIteration(int stopLyingIteration) {
            this.stopLyingIteration = stopLyingIteration;
            return this;
        }

        /**
         * @param usePca whether the input is reduced with PCA before training
         * @return this builder
         */
        public Builder usePca(boolean usePca) {
            this.usePca = usePca;
            return this;
        }

        /**
         * @param normalize whether the input is normalised before training
         * @return this builder
         */
        public Builder normalize(boolean normalize) {
            this.normalize = normalize;
            return this;
        }

        /**
         * @param maxIter number of training iterations
         * @return this builder
         */
        public Builder setMaxIter(int maxIter) {
            this.maxIter = maxIter;
            return this;
        }

        /**
         * @param realMin value copied to the built instance's {@code realMin} field
         * @return this builder
         */
        public Builder setRealMin(double realMin) {
            this.realMin = realMin;
            return this;
        }

        /**
         * @param initialMomentum momentum used before the switch iteration
         * @return this builder
         */
        public Builder setInitialMomentum(double initialMomentum) {
            this.initialMomentum = initialMomentum;
            return this;
        }

        /**
         * @param finalMomentum momentum used from the switch iteration onwards
         * @return this builder
         */
        public Builder setFinalMomentum(double finalMomentum) {
            this.finalMomentum = finalMomentum;
            return this;
        }

        /**
         * @param momentum initial value of the built instance's {@code momentum} field
         * @return this builder
         */
        public Builder setMomentum(double momentum) {
            this.momentum = momentum;
            return this;
        }

        /**
         * @param switchMomentumIteration iteration index at which the momentum switches
         * @return this builder
         */
        public Builder setSwitchMomentumIteration(int switchMomentumIteration) {
            this.switchMomentumIteration = switchMomentumIteration;
            return this;
        }

        /**
         * Creates a {@code Tsne}, copies every builder value into it and calls {@code init()}.
         *
         * @return the configured instance
         */
        public Tsne build() {
            Tsne tsne = new Tsne();
            tsne.finalMomentum = this.finalMomentum;
            tsne.initialMomentum = this.initialMomentum;
            tsne.maxIter = this.maxIter;
            tsne.learningRate = this.learningRate;
            tsne.minGain = this.minGain;
            tsne.momentum = this.momentum;
            tsne.normalize = this.normalize;
            tsne.perplexity = this.perplexity;
            tsne.tolerance = this.tolerance;
            tsne.realMin = this.realMin;
            tsne.stopLyingIteration = this.stopLyingIteration;
            tsne.switchMomentumIteration = this.switchMomentumIteration;
            tsne.usePca = this.usePca;
            tsne.useAdaGrad = this.useAdaGrad;

            tsne.init();

            return tsne;
        }
    }
}

Other Java examples (source code examples)

Here is a short list of links related to this Java Tsne.java source code file:

... this post is sponsored by my books ...

#1 New Release!

FP Best Seller

 

new blog posts

 

Copyright 1998-2021 Alvin Alexander, alvinalexander.com
All Rights Reserved.

A percentage of advertising revenue from
pages under the /java/jwarehouse URI on this website is
paid back to open source projects.