alvinalexander.com | career | drupal | java | mac | mysql | perl | scala | uml | unix  

Java example source code file (FuzzyKMeansClusterer.java)

This example Java source code file (FuzzyKMeansClusterer.java) is included in the alvinalexander.com "Java Source Code Warehouse" project. The intent of this project is to help you "Learn Java by Example"™.

Learn more about this Java project at its project page.

Java - Java tags/keywords

arraylist, centroidcluster, clusterable, default_epsilon, distancemeasure, doublepoint, euclideandistance, fuzzykmeansclusterer, list, mathillegalstateexception, numberistoosmallexception, override, randomgenerator, realmatrix, util

The FuzzyKMeansClusterer.java Java example source code

/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.commons.math3.ml.clustering;

import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.List;

import org.apache.commons.math3.exception.MathIllegalArgumentException;
import org.apache.commons.math3.exception.MathIllegalStateException;
import org.apache.commons.math3.exception.NumberIsTooSmallException;
import org.apache.commons.math3.linear.MatrixUtils;
import org.apache.commons.math3.linear.RealMatrix;
import org.apache.commons.math3.ml.distance.DistanceMeasure;
import org.apache.commons.math3.ml.distance.EuclideanDistance;
import org.apache.commons.math3.random.JDKRandomGenerator;
import org.apache.commons.math3.random.RandomGenerator;
import org.apache.commons.math3.util.FastMath;
import org.apache.commons.math3.util.MathArrays;
import org.apache.commons.math3.util.MathUtils;

/**
 * Fuzzy K-Means clustering algorithm.
 * <p>
 * The Fuzzy K-Means algorithm is a variation of the classical K-Means algorithm, with the
 * major difference that a single data point is not uniquely assigned to a single cluster.
 * Instead, each point i has a set of weights u<sub>ij which indicate the degree of membership
 * to the cluster j.
 * <p>
 * The algorithm then tries to minimize the objective function:
 * <pre>
 * J = ∑<sub>i=1..C∑k=1..N uikmdik2
 * </pre>
 * with d<sub>ik being the distance between data point i and the cluster center k.
 * <p>
 * The algorithm requires two parameters:
 * <ul>
 *   <li>k: the number of clusters
 *   <li>fuzziness: determines the level of cluster fuzziness, larger values lead to fuzzier clusters
 * </ul>
 * Additional, optional parameters:
 * <ul>
 *   <li>maxIterations: the maximum number of iterations
 *   <li>epsilon: the convergence criteria, default is 1e-3
 * </ul>
 * <p>
 * The fuzzy variant of the K-Means algorithm is more robust with regard to the selection
 * of the initial cluster centers.
 *
 * @param <T> type of the points to cluster
 * @since 3.3
 */
public class FuzzyKMeansClusterer<T extends Clusterable> extends Clusterer {

    /** The default value for the convergence criteria. */
    private static final double DEFAULT_EPSILON = 1e-3;

    /** The number of clusters. */
    private final int k;

    /** The maximum number of iterations. */
    private final int maxIterations;

    /** The fuzziness factor. */
    private final double fuzziness;

    /** The convergence criteria. */
    private final double epsilon;

    /** Random generator for choosing initial centers. */
    private final RandomGenerator random;

    /** The membership matrix. */
    private double[][] membershipMatrix;

    /** The list of points used in the last call to {@link #cluster(Collection)}. */
    private List<T> points;

    /** The list of clusters resulting from the last call to {@link #cluster(Collection)}. */
    private List<CentroidCluster clusters;

    /**
     * Creates a new instance of a FuzzyKMeansClusterer.
     * <p>
     * The euclidean distance will be used as default distance measure.
     *
     * @param k the number of clusters to split the data into
     * @param fuzziness the fuzziness factor, must be > 1.0
     * @throws NumberIsTooSmallException if {@code fuzziness <= 1.0}
     */
    public FuzzyKMeansClusterer(final int k, final double fuzziness) throws NumberIsTooSmallException {
        this(k, fuzziness, -1, new EuclideanDistance());
    }

    /**
     * Creates a new instance of a FuzzyKMeansClusterer.
     *
     * @param k the number of clusters to split the data into
     * @param fuzziness the fuzziness factor, must be > 1.0
     * @param maxIterations the maximum number of iterations to run the algorithm for.
     *   If negative, no maximum will be used.
     * @param measure the distance measure to use
     * @throws NumberIsTooSmallException if {@code fuzziness <= 1.0}
     */
    public FuzzyKMeansClusterer(final int k, final double fuzziness,
                                final int maxIterations, final DistanceMeasure measure)
            throws NumberIsTooSmallException {
        this(k, fuzziness, maxIterations, measure, DEFAULT_EPSILON, new JDKRandomGenerator());
    }

    /**
     * Creates a new instance of a FuzzyKMeansClusterer.
     *
     * @param k the number of clusters to split the data into
     * @param fuzziness the fuzziness factor, must be > 1.0
     * @param maxIterations the maximum number of iterations to run the algorithm for.
     *   If negative, no maximum will be used.
     * @param measure the distance measure to use
     * @param epsilon the convergence criteria (default is 1e-3)
     * @param random random generator to use for choosing initial centers
     * @throws NumberIsTooSmallException if {@code fuzziness <= 1.0}
     */
    public FuzzyKMeansClusterer(final int k, final double fuzziness,
                                final int maxIterations, final DistanceMeasure measure,
                                final double epsilon, final RandomGenerator random)
            throws NumberIsTooSmallException {

        super(measure);

        if (fuzziness <= 1.0d) {
            throw new NumberIsTooSmallException(fuzziness, 1.0, false);
        }
        this.k = k;
        this.fuzziness = fuzziness;
        this.maxIterations = maxIterations;
        this.epsilon = epsilon;
        this.random = random;

        this.membershipMatrix = null;
        this.points = null;
        this.clusters = null;
    }

    /**
     * Return the number of clusters this instance will use.
     * @return the number of clusters
     */
    public int getK() {
        return k;
    }

    /**
     * Returns the fuzziness factor used by this instance.
     * @return the fuzziness factor
     */
    public double getFuzziness() {
        return fuzziness;
    }

    /**
     * Returns the maximum number of iterations this instance will use.
     * @return the maximum number of iterations, or -1 if no maximum is set
     */
    public int getMaxIterations() {
        return maxIterations;
    }

    /**
     * Returns the convergence criteria used by this instance.
     * @return the convergence criteria
     */
    public double getEpsilon() {
        return epsilon;
    }

    /**
     * Returns the random generator this instance will use.
     * @return the random generator
     */
    public RandomGenerator getRandomGenerator() {
        return random;
    }

    /**
     * Returns the {@code nxk} membership matrix, where {@code n} is the number
     * of data points and {@code k} the number of clusters.
     * <p>
     * The element U<sub>i,j represents the membership value for data point {@code i}
     * to cluster {@code j}.
     *
     * @return the membership matrix
     * @throws MathIllegalStateException if {@link #cluster(Collection)} has not been called before
     */
    public RealMatrix getMembershipMatrix() {
        if (membershipMatrix == null) {
            throw new MathIllegalStateException();
        }
        return MatrixUtils.createRealMatrix(membershipMatrix);
    }

    /**
     * Returns an unmodifiable list of the data points used in the last
     * call to {@link #cluster(Collection)}.
     * @return the list of data points, or {@code null} if {@link #cluster(Collection)} has
     *   not been called before.
     */
    public List<T> getDataPoints() {
        return points;
    }

    /**
     * Returns the list of clusters resulting from the last call to {@link #cluster(Collection)}.
     * @return the list of clusters, or {@code null} if {@link #cluster(Collection)} has
     *   not been called before.
     */
    public List<CentroidCluster getClusters() {
        return clusters;
    }

    /**
     * Get the value of the objective function.
     * @return the objective function evaluation as double value
     * @throws MathIllegalStateException if {@link #cluster(Collection)} has not been called before
     */
    public double getObjectiveFunctionValue() {
        if (points == null || clusters == null) {
            throw new MathIllegalStateException();
        }

        int i = 0;
        double objFunction = 0.0;
        for (final T point : points) {
            int j = 0;
            for (final CentroidCluster<T> cluster : clusters) {
                final double dist = distance(point, cluster.getCenter());
                objFunction += (dist * dist) * FastMath.pow(membershipMatrix[i][j], fuzziness);
                j++;
            }
            i++;
        }
        return objFunction;
    }

    /**
     * Performs Fuzzy K-Means cluster analysis.
     *
     * @param dataPoints the points to cluster
     * @return the list of clusters
     * @throws MathIllegalArgumentException if the data points are null or the number
     *     of clusters is larger than the number of data points
     */
    @Override
    public List<CentroidCluster cluster(final Collection dataPoints)
            throws MathIllegalArgumentException {

        // sanity checks
        MathUtils.checkNotNull(dataPoints);

        final int size = dataPoints.size();

        // number of clusters has to be smaller or equal the number of data points
        if (size < k) {
            throw new NumberIsTooSmallException(size, k, false);
        }

        // copy the input collection to an unmodifiable list with indexed access
        points = Collections.unmodifiableList(new ArrayList<T>(dataPoints));
        clusters = new ArrayList<CentroidCluster();
        membershipMatrix = new double[size][k];
        final double[][] oldMatrix = new double[size][k];

        // if no points are provided, return an empty list of clusters
        if (size == 0) {
            return clusters;
        }

        initializeMembershipMatrix();

        // there is at least one point
        final int pointDimension = points.get(0).getPoint().length;
        for (int i = 0; i < k; i++) {
            clusters.add(new CentroidCluster<T>(new DoublePoint(new double[pointDimension])));
        }

        int iteration = 0;
        final int max = (maxIterations < 0) ? Integer.MAX_VALUE : maxIterations;
        double difference = 0.0;

        do {
            saveMembershipMatrix(oldMatrix);
            updateClusterCenters();
            updateMembershipMatrix();
            difference = calculateMaxMembershipChange(oldMatrix);
        } while (difference > epsilon && ++iteration < max);

        return clusters;
    }

    /**
     * Update the cluster centers.
     */
    private void updateClusterCenters() {
        int j = 0;
        final List<CentroidCluster newClusters = new ArrayList>(k);
        for (final CentroidCluster<T> cluster : clusters) {
            final Clusterable center = cluster.getCenter();
            int i = 0;
            double[] arr = new double[center.getPoint().length];
            double sum = 0.0;
            for (final T point : points) {
                final double u = FastMath.pow(membershipMatrix[i][j], fuzziness);
                final double[] pointArr = point.getPoint();
                for (int idx = 0; idx < arr.length; idx++) {
                    arr[idx] += u * pointArr[idx];
                }
                sum += u;
                i++;
            }
            MathArrays.scaleInPlace(1.0 / sum, arr);
            newClusters.add(new CentroidCluster<T>(new DoublePoint(arr)));
            j++;
        }
        clusters.clear();
        clusters = newClusters;
    }

    /**
     * Updates the membership matrix and assigns the points to the cluster with
     * the highest membership.
     */
    private void updateMembershipMatrix() {
        for (int i = 0; i < points.size(); i++) {
            final T point = points.get(i);
            double maxMembership = Double.MIN_VALUE;
            int newCluster = -1;
            for (int j = 0; j < clusters.size(); j++) {
                double sum = 0.0;
                final double distA = FastMath.abs(distance(point, clusters.get(j).getCenter()));

                if (distA != 0.0) {
                    for (final CentroidCluster<T> c : clusters) {
                        final double distB = FastMath.abs(distance(point, c.getCenter()));
                        if (distB == 0.0) {
                            sum = Double.POSITIVE_INFINITY;
                            break;
                        }
                        sum += FastMath.pow(distA / distB, 2.0 / (fuzziness - 1.0));
                    }
                }

                double membership;
                if (sum == 0.0) {
                    membership = 1.0;
                } else if (sum == Double.POSITIVE_INFINITY) {
                    membership = 0.0;
                } else {
                    membership = 1.0 / sum;
                }
                membershipMatrix[i][j] = membership;

                if (membershipMatrix[i][j] > maxMembership) {
                    maxMembership = membershipMatrix[i][j];
                    newCluster = j;
                }
            }
            clusters.get(newCluster).addPoint(point);
        }
    }

    /**
     * Initialize the membership matrix with random values.
     */
    private void initializeMembershipMatrix() {
        for (int i = 0; i < points.size(); i++) {
            for (int j = 0; j < k; j++) {
                membershipMatrix[i][j] = random.nextDouble();
            }
            membershipMatrix[i] = MathArrays.normalizeArray(membershipMatrix[i], 1.0);
        }
    }

    /**
     * Calculate the maximum element-by-element change of the membership matrix
     * for the current iteration.
     *
     * @param matrix the membership matrix of the previous iteration
     * @return the maximum membership matrix change
     */
    private double calculateMaxMembershipChange(final double[][] matrix) {
        double maxMembership = 0.0;
        for (int i = 0; i < points.size(); i++) {
            for (int j = 0; j < clusters.size(); j++) {
                double v = FastMath.abs(membershipMatrix[i][j] - matrix[i][j]);
                maxMembership = FastMath.max(v, maxMembership);
            }
        }
        return maxMembership;
    }

    /**
     * Copy the membership matrix into the provided matrix.
     *
     * @param matrix the place to store the membership matrix
     */
    private void saveMembershipMatrix(final double[][] matrix) {
        for (int i = 0; i < points.size(); i++) {
            System.arraycopy(membershipMatrix[i], 0, matrix[i], 0, clusters.size());
        }
    }

}

Other Java examples (source code examples)

Here is a short list of links related to this Java FuzzyKMeansClusterer.java source code file:

... this post is sponsored by my books ...

#1 New Release!

FP Best Seller

 

new blog posts

 

Copyright 1998-2021 Alvin Alexander, alvinalexander.com
All Rights Reserved.

A percentage of advertising revenue from
pages under the /java/jwarehouse URI on this website is
paid back to open source projects.