alvinalexander.com | career | drupal | java | mac | mysql | perl | scala | uml | unix  

Java example source code file (SpTree.java)

This example Java source code file (SpTree.java) is included in the alvinalexander.com "Java Source Code Warehouse" project. The intent of this project is to help you "Learn Java by Example"™.

Learn more about this Java project at its project page.

Java - Java tags/keywords

atomicdouble, cell, hashset, illegalargumentexception, indarray, inserted, node_ratio, rowp, set, sptree, string, util

The SpTree.java Java example source code

/*
 *
 *  * Copyright 2015 Skymind,Inc.
 *  *
 *  *    Licensed under the Apache License, Version 2.0 (the "License");
 *  *    you may not use this file except in compliance with the License.
 *  *    You may obtain a copy of the License at
 *  *
 *  *        http://www.apache.org/licenses/LICENSE-2.0
 *  *
 *  *    Unless required by applicable law or agreed to in writing, software
 *  *    distributed under the License is distributed on an "AS IS" BASIS,
 *  *    WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 *  *    See the License for the specific language governing permissions and
 *  *    limitations under the License.
 *
 */

package org.deeplearning4j.clustering.sptree;

import com.google.common.util.concurrent.AtomicDouble;
import org.apache.commons.math3.util.FastMath;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.io.Serializable;
import java.util.HashSet;
import java.util.Set;


/**
 * Space-partitioning tree over the rows of a data matrix, used to evaluate
 * Barnes-Hut approximated t-SNE forces.
 *
 * Every node covers an axis-aligned {@link Cell} described by a centre
 * ("corner") and a half-width per dimension, tracks the number of points
 * below it ({@code cumSize}) together with their running centre of mass,
 * and, once subdivided, owns {@code 2^D} children (one per orthant).
 *
 * @author Adam Gibson
 */
public class SpTree implements Serializable {
    public final static int NODE_RATIO = 8000;
    private static final Logger log = LoggerFactory.getLogger(SpTree.class);

    // Number of columns (dimensions) of the data matrix
    private int D;
    // The data matrix; each row is one point
    private INDArray data;
    // Number of rows (points) of the data matrix
    private int N;
    // Scratch vector reused by the force computations
    private INDArray buf;
    // Number of point indices currently stored directly in this node
    private int size;
    // Number of points inserted into this node and all of its descendants
    private int cumSize;
    private Cell boundary;
    private INDArray centerOfMass;
    private SpTree parent;
    // Row indices of the points stored directly in this node
    private int[] index;
    private int nodeCapacity;
    private int numChildren = 2;
    private boolean isLeaf = true;
    private Set<INDArray> indices;
    private SpTree[] children;
    private String similarityFunction = "euclidean";


    /**
     * Creates an empty node. No points are inserted by this constructor.
     *
     * @param parent the parent node, or null for a root
     * @param data the data matrix (one point per row)
     * @param corner the centre of the cell covered by this node
     * @param width the half-width of the cell along each dimension
     * @param indices the shared set that records the points stored in the tree
     * @param similarityFunction the name of the similarity function
     */
    public SpTree(SpTree parent,INDArray data,INDArray corner,INDArray width,Set<INDArray> indices,String similarityFunction) {
        init(parent, data, corner, width,indices,similarityFunction);
    }


    /**
     * Builds a root node whose cell is centred on the column means of the
     * data and wide enough to contain every row, then inserts all rows.
     *
     * @param data the data matrix (one point per row)
     * @param indices the set that records the points stored in the tree;
     *                rows are only inserted when this set is empty
     * @param similarityFunction the name of the similarity function
     */
    public SpTree(INDArray data,Set<INDArray> indices,String similarityFunction) {
        INDArray meanY = data.mean(0);
        INDArray minY = data.min(0);
        INDArray maxY = data.max(0);
        INDArray width = Nd4j.create(meanY.shape());
        for(int i = 0; i < width.length(); i++) {
            double mean = meanY.getDouble(i);
            double halfExtent = FastMath.max(maxY.getDouble(i) - mean, mean - minY.getDouble(i));
            // Pad the half-width itself (not just one candidate of the max)
            // so that the extreme points never sit exactly on the cell edge.
            width.putScalar(i, halfExtent + Nd4j.EPS_THRESHOLD);
        }

        init(null,data,meanY,width,indices,similarityFunction);
        fill(N);
    }


    /**
     * Creates an empty node using the "euclidean" similarity function.
     *
     * @param parent the parent node, or null for a root
     * @param data the data matrix (one point per row)
     * @param corner the centre of the cell covered by this node
     * @param width the half-width of the cell along each dimension
     * @param indices the shared set that records the points stored in the tree
     */
    public SpTree(SpTree parent,INDArray data,INDArray corner,INDArray width,Set<INDArray> indices) {
        this(parent, data, corner, width,indices,"euclidean");
    }


    /**
     * Builds a root node over the data using the "euclidean" similarity function.
     *
     * @param data the data matrix (one point per row)
     * @param indices the set that records the points stored in the tree
     */
    public SpTree(INDArray data,Set<INDArray> indices) {
        this(data,indices,"euclidean");
    }


    /**
     * Builds a root node over the data with a fresh, empty index set.
     *
     * @param data the data matrix (one point per row)
     */
    public SpTree(INDArray data) {
        this(data, new HashSet<INDArray>());
    }

    /**
     * Resets this node to an empty leaf covering the given cell.
     */
    private void init(SpTree parent,INDArray data,INDArray corner,INDArray width,Set<INDArray> indices,String similarityFunction) {
        this.parent = parent;
        this.data = data;
        this.indices = indices;
        this.similarityFunction = similarityFunction;
        D = data.columns();
        N = data.rows();

        // N % NODE_RATIO is 0 whenever N is a multiple of NODE_RATIO. A node
        // with zero capacity can never store a point, so insert() would keep
        // subdividing until the stack overflows; always allow at least one.
        nodeCapacity = Math.max(1, N % NODE_RATIO);
        index = new int[nodeCapacity];

        // One child per orthant: 2^D
        int childCount = 2;
        for(int d = 1; d < D; d++) {
            childCount *= 2;
        }
        numChildren = childCount;
        children = new SpTree[numChildren];

        isLeaf = true;
        size = 0;
        cumSize = 0;

        boundary = new Cell(D);
        // Copies are stored because callers reuse the corner/width buffers
        boundary.setCorner(corner.dup());
        boundary.setWidth(width.dup());
        centerOfMass = Nd4j.create(D);
        buf = Nd4j.create(D);
    }


    /**
     * Inserts the data row with the given index into this subtree.
     *
     * @param index the row of the data matrix to insert
     * @return false if the point lies outside this node's cell, true if it
     *         was stored (or is a duplicate of a point already stored here)
     */
    private boolean insert(int index) {
        INDArray point = data.slice(index);
        if(!boundary.contains(point)) {
            return false;
        }

        // Online update of the centre of mass with the new point
        cumSize++;
        double keepWeight = (double) (cumSize - 1) / (double) cumSize;
        double newWeight = 1.0 / (double) cumSize;
        centerOfMass.muli(keepWeight);
        centerOfMass.addi(point.mul(newWeight));

        // If there is space in this node and it is a leaf, add the point here
        if(isLeaf() && size < nodeCapacity) {
            this.index[size] = index;
            indices.add(point);
            size++;
            return true;
        }

        // Do not store duplicates of points already held by this node
        for(int i = 0; i < size; i++) {
            INDArray stored = data.slice(this.index[i]);
            if(stored.equals(point)) {
                return true;
            }
        }

        // Otherwise the node is full: split it and push the point down
        if(isLeaf()) {
            subDivide();
        }

        // Find out where the point can be inserted
        for(int i = 0; i < numChildren; i++) {
            if(children[i].insert(index)) {
                return true;
            }
        }

        throw new IllegalStateException("Shouldn't reach this state");
    }


    /**
     * Subdivides this node into {@code 2^D} children (one per orthant of
     * its cell) and moves the points stored here into those children.
     */
    public void subDivide() {
        INDArray newCorner = Nd4j.create(D);
        INDArray newWidth = Nd4j.create(D);

        // Every child has half of the parent's width along each dimension
        for(int d = 0; d < D; d++) {
            newWidth.putScalar(d, .5 * boundary.width(d));
        }

        for(int i = 0; i < numChildren; i++) {
            // Bit d of i selects the lower or upper half along dimension d
            int div = 1;
            for(int d = 0; d < D; d++) {
                double offset = .5 * boundary.width(d);
                if((i / div) % 2 == 1) {
                    newCorner.putScalar(d, boundary.corner(d) - offset);
                }
                else {
                    newCorner.putScalar(d, boundary.corner(d) + offset);
                }
                div *= 2;
            }

            // Children keep the parent's similarity function instead of
            // silently falling back to the default one.
            children[i] = new SpTree(this, data, newCorner, newWidth, indices, similarityFunction);
        }

        // Move existing points to correct children
        for(int i = 0; i < size; i++) {
            for(int j = 0; j < numChildren; j++) {
                if(children[j].insert(index[i])) {
                    break;
                }
            }

            index[i] = -1;
        }

        // Empty parent node
        size = 0;
        isLeaf = false;
    }


    /**
     * Compute non edge forces using barnes hut
     *
     * @param pointIndex the row of the data matrix to compute the force for
     * @param theta the accuracy threshold: a node is used as a summary when
     *              (largest cell width / distance to its centre of mass) is below theta
     * @param negativeForce the vector the force is accumulated into (modified in place)
     * @param sumQ accumulator for the sum of the Q values (modified in place)
     */
    public void computeNonEdgeForces(int pointIndex, double theta, INDArray negativeForce, AtomicDouble sumQ) {
        // Make sure that we spend no time on empty nodes or self-interactions
        if(cumSize == 0 || (isLeaf() && size == 1 && index[0] == pointIndex)) {
            return;
        }

        // Compute squared distance between point and center-of-mass
        buf.assign(data.slice(pointIndex)).subi(centerOfMass);
        double distSq = Nd4j.getBlasWrapper().dot(buf, buf);

        // Check whether we can use this node as a "summary"
        double maxWidth = boundary.width().max(Integer.MAX_VALUE).getDouble(0);
        if(isLeaf() || maxWidth / FastMath.sqrt(distSq) < theta) {

            // Compute and add t-SNE force between point and current node
            double q = 1.0 / (1.0 + distSq);
            double mult = cumSize * q;
            sumQ.addAndGet(mult);
            mult *= q;
            negativeForce.addi(buf.mul(mult));

        }
        else {

            // Recursively apply Barnes-Hut to children
            for(int i = 0; i < numChildren; i++) {
                children[i].computeNonEdgeForces(pointIndex, theta, negativeForce, sumQ);
            }

        }
    }


    /**
     * Compute edge forces using barnes hut.
     *
     * The three P arrays are read as a sparse row layout: the edges of
     * point n are the entries {@code rowP[n] .. rowP[n + 1] - 1} of colP
     * (neighbour index) and valP (edge weight).
     *
     * @param rowP a vector of row offsets; entries 0..N are read
     * @param colP the neighbour index of every edge
     * @param valP the weight of every edge
     * @param N the number of elements
     * @param posF the positive force; row n is updated in place
     * @throws IllegalArgumentException if rowP is not a vector
     */
    public void computeEdgeForces(INDArray rowP, INDArray colP, INDArray valP, int N, INDArray posF) {
        if(!rowP.isVector()) {
            throw new IllegalArgumentException("RowP must be a vector");
        }

        // Loop over all edges in the graph
        for(int n = 0; n < N; n++) {
            for(int i = rowP.getInt(n); i < rowP.getInt(n + 1); i++) {

                // Compute pairwise distance and Q-value
                buf.assign(data.slice(n)).subi(data.slice(colP.getInt(i)));

                // Student-t kernel: 1 / (1 + squared distance), the same form
                // used by computeNonEdgeForces. The "1 +" also keeps the
                // division finite when the two points coincide.
                double denominator = 1.0 + Nd4j.getBlasWrapper().dot(buf, buf);
                double weight = valP.getDouble(i) / denominator;

                // Sum positive force
                posF.slice(n).addi(buf.mul(weight));

            }
        }
    }


    /**
     * @return true if this node has not been subdivided
     */
    public boolean isLeaf() {
        return isLeaf;
    }

    /**
     * Verifies the structure of the tree (does bounds checking on each node)
     * @return true if the structure of the tree
     * is correct.
     */
    public boolean isCorrect() {
        for(int n = 0; n < size; n++) {
            INDArray point = data.slice(index[n]);
            if(!boundary.contains(point)) {
                return false;
            }
        }

        if(!isLeaf()) {
            for(int i = 0; i < numChildren; i++) {
                if(!children[i].isCorrect()) {
                    return false;
                }
            }
        }

        return true;
    }

    /**
     * The depth of the node
     * @return the depth of the node: 1 for a leaf, otherwise 1 plus the
     *         depth of its deepest child
     */
    public int depth() {
        if(isLeaf()) {
            return 1;
        }

        int maxChildDepth = 0;
        for(int i = 0; i < numChildren; i++) {
            // Inspect every child, not only the first one
            maxChildDepth = Math.max(maxChildDepth, children[i].depth());
        }

        return 1 + maxChildDepth;
    }

    /**
     * Inserts the first n rows of the data matrix. Only acts on a root node
     * whose index set is still empty; otherwise a warning is logged.
     */
    private void fill(int n) {
        if(indices.isEmpty() && parent == null) {
            for(int i = 0; i < n; i++) {
                log.trace("Inserted {}", i);
                insert(i);
            }
        }
        else {
            log.warn("Called fill already");
        }
    }


    public SpTree[] getChildren() {
        return children;
    }

    public int getD() {
        return D;
    }

    public INDArray getCenterOfMass() {
        return centerOfMass;
    }

    public Cell getBoundary() {
        return boundary;
    }

    public int[] getIndex() {
        return index;
    }

    public int getCumSize() {
        return cumSize;
    }

    public void setCumSize(int cumSize) {
        this.cumSize = cumSize;
    }

    public int getNumChildren() {
        return numChildren;
    }

    public void setNumChildren(int numChildren) {
        this.numChildren = numChildren;
    }

}

Other Java examples (source code examples)

Here is a short list of links related to this Java SpTree.java source code file:

... this post is sponsored by my books ...

#1 New Release!

FP Best Seller

 

new blog posts

 

Copyright 1998-2021 Alvin Alexander, alvinalexander.com
All Rights Reserved.

A percentage of advertising revenue from
pages under the /java/jwarehouse URI on this website is
paid back to open source projects.