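The three link-based algorithms covered earlier all work by merging (agglomerative clustering). k-means, by contrast, is a partitional algorithm that splits the data, and it is said to be the most widely used clustering algorithm in practice.
4.4.1 A first look at the k-means algorithm
Main program: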
    package com.hankcs;

    import iweb2.ch4.clustering.partitional.KMeansAlgorithm;
    import iweb2.ch4.data.SFData;
    import iweb2.ch4.data.SFDataset;
    import iweb2.ch4.model.DataPoint;

    public class ch4_3_Partitional {
        public static void main(String[] args) throws Exception {
            // Load the sample dataset and extract its data points
            SFDataset ds = SFData.createDataset();
            DataPoint[] dps = ds.getData();

            // Partition the points into k = 8 clusters and print the result
            KMeansAlgorithm kMeans = new KMeansAlgorithm(8, dps);
            kMeans.cluster();
            kMeans.print();
        }
    }
Output (the initial centroids are picked at random, so the clusters can differ from run to run):
    From file: C:/iWeb2/data/ch04/clusteringSF.dat
    Using attribute names: [Age, IncomeRange, Education, Skills, Social, isPaid]
    Loaded 20 data points.
    Clusters:
    {Charlie, Carl, Bill}
    {Aurora, Alexandra}
    {Catherine, Bob}
    {Elena, Eric, Frank, Dmitry, Constantine}
    {John}
    {Athena, Babis, Albert}
    {Lukas}
    {Jack, Maria, George}
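The k in k-means is the number of clusters that the elements are partitioned into. The core idea of the algorithm is built around the "centroid" (the mean) of the elements in each cluster.
4.4.2 How k-means works internally
The "means" in k-means refers to these centroids. The algorithm starts by picking k data points at random as the initial centroids: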
    /**
     * Randomly builds the initial set of centroids.
     *
     * @param k    number of centroids to pick
     * @param data the data points to pick from
     * @return k centroids, each copied from a distinct randomly chosen data point
     */
    public static DataPoint[] pickInitialCentroids(int k, DataPoint[] data) {
        Random randGen = new Random();
        DataPoint[] centroids = new DataPoint[k];

        // Calculate random mean values for each cluster based on the data
        /**
         * TODO: 4.2 -- Selecting the means for seeding
         *
         * In large datasets, the selection of the initial centroids can be
         * important from a computational (time) complexity perspective.
         *
         * In general, how can we improve the seeding of the initial mean values?
         * For example, consider the following heuristic:
         *
         * 1. pick randomly one node
         * 2. calculate the distance between that node and O(10*k) other nodes
         * 3. sort the list of nodes according to their distance from the first node
         * 4. pick every 10th node in the sequence
         * 5. calculate the mean distance between each one of these nodes and the original node
         *
         * This algorithmic choice is as ad hoc as they come, however, it does have
         * some key principles embedded in it. What are these principles?
         * How can you generalize this algorithm?
         *
         * Discuss advantages/disadvantages of the initial seeding with your friends.
         */
        Set<Integer> previouslyUsedIds = new HashSet<Integer>();
        for (int i = 0; i < k; i++) {
            // Pick a point index that we haven't used yet.
            // Set.add returns false when the index is already present, so retry.
            int centroidId;
            do {
                centroidId = randGen.nextInt(data.length);
            } while (!previouslyUsedIds.add(centroidId));

            // Create DataPoint that will represent the cluster's centroid.
            String label = "Mean-" + i + "(" + data[centroidId].getLabel() + ")";
            double[] values = data[centroidId].getNumericAttrValues();
            String[] attrNames = data[centroidId].getAttributeNames();
            centroids[i] = new DataPoint(label, Attributes.createAttributes(attrNames, values));
        }
        return centroids;
    }
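Note that the retry loop above only terminates if k is no larger than the number of data points. As a minimal sketch of an alternative (not the library's code), the distinct indices could be drawn with a partial Fisher-Yates shuffle, which needs no retries and can reject an impossible k up front. The class name `SeedSketch` and the method `pickDistinctIndices` are hypothetical names introduced only for this illustration.

```java
import java.util.Arrays;
import java.util.Random;

public class SeedSketch {

    /**
     * Hypothetical helper: returns k distinct indices in [0, n),
     * chosen uniformly at random with a partial Fisher-Yates shuffle.
     */
    public static int[] pickDistinctIndices(int k, int n, Random rand) {
        if (k < 0 || k > n) {
            throw new IllegalArgumentException("k must be between 0 and n");
        }
        int[] indices = new int[n];
        for (int i = 0; i < n; i++) {
            indices[i] = i;
        }
        // After step i, positions 0..i hold a random sample without repeats.
        for (int i = 0; i < k; i++) {
            int j = i + rand.nextInt(n - i); // j is uniform in [i, n)
            int tmp = indices[i];
            indices[i] = indices[j];
            indices[j] = tmp;
        }
        return Arrays.copyOf(indices, k);
    }

    public static void main(String[] args) {
        // Pick 8 seeds out of 20 points, as in the example program.
        int[] seeds = pickDistinctIndices(8, 20, new Random());
        System.out.println(Arrays.toString(seeds));
    }
}
```

Each returned index could then be used the same way `centroidId` is used in `pickInitialCentroids`.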
Next, for every data point, the algorithm finds the centroid closest to it:
    /**
     * This method calculates the closest centroid for a given data point.
     *
     * @param centroids the current centroids
     * @param x is the <CODE>DataPoint</CODE> for which we seek the closest centroid
     * @return the index (from the centroids array) of the closest centroid
     */
    private int findClosestCentroid(DataPoint[] centroids, DataPoint x) {
        double minDistance = Double.POSITIVE_INFINITY;
        int closestCentroid = -1;
        for (int i = 0, n = centroids.length; i < n; i++) {
            double d = distance(centroids[i], x);
            // if d == minDistance then keep the current selection
            if (d < minDistance) {
                minDistance = d;
                closestCentroid = i;
            }
        }
        return closestCentroid;
    }
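The `distance` method called above is not shown here. As a rough sketch, assuming it is a Euclidean distance over the numeric attribute values (an assumption, the library may use a different metric), the same lookup on plain `double[]` vectors might look like this. `DistanceSketch`, `euclidean` and `closest` are hypothetical names.

```java
public class DistanceSketch {

    /** Euclidean distance between two vectors of equal length (assumed metric). */
    public static double euclidean(double[] a, double[] b) {
        double sum = 0.0;
        for (int i = 0; i < a.length; i++) {
            double diff = a[i] - b[i];
            sum += diff * diff;
        }
        return Math.sqrt(sum);
    }

    /** Index of the centroid nearest to x; on a tie the lower index wins. */
    public static int closest(double[][] centroids, double[] x) {
        double minDistance = Double.POSITIVE_INFINITY;
        int closestCentroid = -1;
        for (int i = 0; i < centroids.length; i++) {
            double d = euclidean(centroids[i], x);
            if (d < minDistance) { // strict '<' keeps the earlier centroid on ties
                minDistance = d;
                closestCentroid = i;
            }
        }
        return closestCentroid;
    }

    public static void main(String[] args) {
        double[][] centroids = {{0, 0}, {10, 10}};
        System.out.println(euclidean(new double[] {0, 0}, new double[] {3, 4})); // prints 5.0
        System.out.println(closest(centroids, new double[] {3, 4}));             // prints 0
    }
}
```

Because the comparison is a strict `<`, a point that is equally far from two centroids stays with the one that comes first in the array, which matches the comment in `findClosestCentroid`.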
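Each point is then added to the cluster that its nearest centroid belongs to. Once every point has been assigned, the centroid of each cluster is recomputed from its members, which usually moves it. The assignment step is then repeated with the new centroids, and the loop ends when no centroid changes any more (once no point switches to a different cluster, the centroids stay where they are).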
    public void cluster() {
        boolean centroidsChanged = true;
        while (centroidsChanged) {
            // Create a set of points for each cluster
            List<Set<DataPoint>> clusters = new ArrayList<Set<DataPoint>>(k);
            for (int i = 0; i < k; i++) {
                clusters.add(new HashSet<DataPoint>());
            }

            // Assign points to each set based on minimum distance from the centroids
            for (DataPoint p : allDataPoints) {
                int i = findClosestCentroid(allCentroids, p);
                clusters.get(i).add(p);
            }

            // Build the clusters from the assigned point sets
            for (int i = 0; i < k; i++) {
                allClusters[i] = new Cluster(clusters.get(i));
            }

            // Calculate new cluster centroids, and
            // check if any of the centroids has changed
            centroidsChanged = false;
            for (int i = 0; i < allClusters.length; i++) {
                if (clusters.get(i).size() > 0) {
                    double[] newCentroidValues = findCentroid(allClusters[i]);
                    double[] oldCentroidValues = allCentroids[i].getNumericAttrValues();
                    if (!Arrays.equals(oldCentroidValues, newCentroidValues)) {
                        allCentroids[i] = new DataPoint(allCentroids[i].getLabel(), newCentroidValues);
                        centroidsChanged = true;
                    }
                } else {
                    // keep mean unchanged if cluster has no elements.
                }
            }
        }
    }
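To make the loop easier to follow without the library types, here is a self-contained sketch of the same assign-then-update cycle on plain `double[]` vectors. It assumes that `findCentroid` (not shown above) returns the per-attribute arithmetic mean of a cluster's members and that the distance is Euclidean. Both are assumptions, and `KMeansSketch` with its methods is a hypothetical name used only for this illustration.

```java
import java.util.Arrays;

public class KMeansSketch {

    /** Squared Euclidean distance; gives the same nearest centroid as the plain distance. */
    static double squaredDistance(double[] a, double[] b) {
        double sum = 0.0;
        for (int i = 0; i < a.length; i++) {
            double diff = a[i] - b[i];
            sum += diff * diff;
        }
        return sum;
    }

    /** Index of the nearest centroid; the lower index wins on ties. */
    static int closest(double[][] centroids, double[] x) {
        int best = -1;
        double min = Double.POSITIVE_INFINITY;
        for (int i = 0; i < centroids.length; i++) {
            double d = squaredDistance(centroids[i], x);
            if (d < min) {
                min = d;
                best = i;
            }
        }
        return best;
    }

    /**
     * Repeats assignment and centroid update until no centroid changes.
     * The centroids array is updated in place; the return value holds
     * the cluster index of each point.
     */
    public static int[] cluster(double[][] points, double[][] centroids) {
        int k = centroids.length;
        int dim = points[0].length;
        int[] assignment = new int[points.length];
        boolean centroidsChanged = true;
        while (centroidsChanged) {
            // Assignment step: every point joins its nearest centroid.
            double[][] sums = new double[k][dim];
            int[] counts = new int[k];
            for (int p = 0; p < points.length; p++) {
                int c = closest(centroids, points[p]);
                assignment[p] = c;
                counts[c]++;
                for (int d = 0; d < dim; d++) {
                    sums[c][d] += points[p][d];
                }
            }
            // Update step: each non-empty cluster's centroid becomes the mean of its members.
            centroidsChanged = false;
            for (int c = 0; c < k; c++) {
                if (counts[c] == 0) {
                    continue; // keep the old centroid if the cluster is empty
                }
                double[] mean = new double[dim];
                for (int d = 0; d < dim; d++) {
                    mean[d] = sums[c][d] / counts[c];
                }
                if (!Arrays.equals(centroids[c], mean)) {
                    centroids[c] = mean;
                    centroidsChanged = true;
                }
            }
        }
        return assignment;
    }

    public static void main(String[] args) {
        double[][] points = {{1}, {2}, {10}, {11}};
        double[][] centroids = {{1}, {2}}; // deliberately poor initial seeds
        int[] assignment = cluster(points, centroids);
        System.out.println(Arrays.toString(assignment));    // prints [0, 0, 1, 1]
        System.out.println(Arrays.deepToString(centroids)); // prints [[1.5], [10.5]]
    }
}
```

In this small run the first pass puts 2, 10 and 11 into the second cluster, the recomputed centroid then pulls 2 back into the first cluster, and the third pass changes nothing, so the loop stops.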