alvinalexander.com | career | drupal | java | mac | mysql | perl | scala | uml | unix  

Commons Math example source code file (KMeansPlusPlusClusterer.java)

This example Commons Math source code file (KMeansPlusPlusClusterer.java) is included in the DevDaily.com "Java Source Code Warehouse" project. The intent of this project is to help you "Learn Java by Example"™.

Java - Commons Math tags/keywords

arraylist, cluster, cluster, clusterable, collection, collection, kmeansplusplusclusterer, list, list, random, t, t, util

The Commons Math KMeansPlusPlusClusterer.java source code

/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.commons.math.stat.clustering;

import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
import java.util.Random;

/**
 * Clustering algorithm based on David Arthur and Sergei Vassilvitski k-means++ algorithm.
 * @param <T> type of the points to cluster
 * @see <a href="http://en.wikipedia.org/wiki/K-means%2B%2B">K-means++ (wikipedia)
 * @version $Revision: 811685 $ $Date: 2009-09-05 13:36:48 -0400 (Sat, 05 Sep 2009) $
 * @since 2.0
 */
public class KMeansPlusPlusClusterer<T extends Clusterable {

    /** Random generator for choosing initial centers. */
    private final Random random;

    /** Build a clusterer.
     * @param random random generator to use for choosing initial centers
     */
    public KMeansPlusPlusClusterer(final Random random) {
        this.random = random;
    }

    /**
     * Runs the K-means++ clustering algorithm.
     *
     * @param points the points to cluster
     * @param k the number of clusters to split the data into
     * @param maxIterations the maximum number of iterations to run the algorithm
     *     for.  If negative, no maximum will be used
     * @return a list of clusters containing the points
     */
    public List<Cluster cluster(final Collection points,
                                    final int k, final int maxIterations) {
        // create the initial clusters
        List<Cluster clusters = chooseInitialCenters(points, k, random);
        assignPointsToClusters(clusters, points);

        // iterate through updating the centers until we're done
        final int max = (maxIterations < 0) ? Integer.MAX_VALUE : maxIterations;
        for (int count = 0; count < max; count++) {
            boolean clusteringChanged = false;
            List<Cluster newClusters = new ArrayList>();
            for (final Cluster<T> cluster : clusters) {
                final T newCenter = cluster.getCenter().centroidOf(cluster.getPoints());
                if (!newCenter.equals(cluster.getCenter())) {
                    clusteringChanged = true;
                }
                newClusters.add(new Cluster<T>(newCenter));
            }
            if (!clusteringChanged) {
                return clusters;
            }
            assignPointsToClusters(newClusters, points);
            clusters = newClusters;
        }
        return clusters;
    }

    /**
     * Adds the given points to the closest {@link Cluster}.
     *
     * @param <T> type of the points to cluster
     * @param clusters the {@link Cluster}s to add the points to
     * @param points the points to add to the given {@link Cluster}s
     */
    private static <T extends Clusterable void
        assignPointsToClusters(final Collection<Cluster clusters, final Collection points) {
        for (final T p : points) {
            Cluster<T> cluster = getNearestCluster(clusters, p);
            cluster.addPoint(p);
        }
    }

    /**
     * Use K-means++ to choose the initial centers.
     *
     * @param <T> type of the points to cluster
     * @param points the points to choose the initial centers from
     * @param k the number of centers to choose
     * @param random random generator to use
     * @return the initial centers
     */
    private static <T extends Clusterable List>
        chooseInitialCenters(final Collection<T> points, final int k, final Random random) {

        final List<T> pointSet = new ArrayList(points);
        final List<Cluster resultSet = new ArrayList>();

        // Choose one center uniformly at random from among the data points.
        final T firstPoint = pointSet.remove(random.nextInt(pointSet.size()));
        resultSet.add(new Cluster<T>(firstPoint));

        final double[] dx2 = new double[pointSet.size()];
        while (resultSet.size() < k) {
            // For each data point x, compute D(x), the distance between x and
            // the nearest center that has already been chosen.
            int sum = 0;
            for (int i = 0; i < pointSet.size(); i++) {
                final T p = pointSet.get(i);
                final Cluster<T> nearest = getNearestCluster(resultSet, p);
                final double d = p.distanceFrom(nearest.getCenter());
                sum += d * d;
                dx2[i] = sum;
            }

            // Add one new data point as a center. Each point x is chosen with
            // probability proportional to D(x)2
            final double r = random.nextDouble() * sum;
            for (int i = 0 ; i < dx2.length; i++) {
                if (dx2[i] >= r) {
                    final T p = pointSet.remove(i);
                    resultSet.add(new Cluster<T>(p));
                    break;
                }
            }
        }

        return resultSet;

    }

    /**
     * Returns the nearest {@link Cluster} to the given point
     *
     * @param <T> type of the points to cluster
     * @param clusters the {@link Cluster}s to search
     * @param point the point to find the nearest {@link Cluster} for
     * @return the nearest {@link Cluster} to the given point
     */
    private static <T extends Clusterable Cluster
        getNearestCluster(final Collection<Cluster clusters, final T point) {
        double minDistance = Double.MAX_VALUE;
        Cluster<T> minCluster = null;
        for (final Cluster<T> c : clusters) {
            final double distance = point.distanceFrom(c.getCenter());
            if (distance < minDistance) {
                minDistance = distance;
                minCluster = c;
            }
        }
        return minCluster;
    }

}

Other Commons Math examples (source code examples)

Here is a short list of links related to this Commons Math KMeansPlusPlusClusterer.java source code file:

... this post is sponsored by my books ...

#1 New Release!

FP Best Seller

 

new blog posts

 

Copyright 1998-2021 Alvin Alexander, alvinalexander.com
All Rights Reserved.

A percentage of advertising revenue from
pages under the /java/jwarehouse URI on this website is
paid back to open source projects.