
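The three link-based algorithms covered earlier are all agglomerative ("merging") methods. k-means, by contrast, is a partitional ("splitting") algorithm, and it is reportedly the most widely used clustering algorithm in practice.
4.4.1 A first look at the k-means algorithm
Main program: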
package com.hankcs;

import iweb2.ch4.clustering.partitional.KMeansAlgorithm;
import iweb2.ch4.data.SFData;
import iweb2.ch4.data.SFDataset;
import iweb2.ch4.model.DataPoint;

public class ch4_3_Partitional
{
    public static void main(String[] args) throws Exception
    {
        SFDataset ds = SFData.createDataset();
        DataPoint[] dps = ds.getData();
        KMeansAlgorithm kMeans = new KMeansAlgorithm(8, dps);
        kMeans.cluster();
        kMeans.print();
    }
}
Output:
From file: C:/iWeb2/data/ch04/clusteringSF.dat
Using attribute names: [Age, IncomeRange, Education, Skills, Social, isPaid]
Loaded 20 data points.
Clusters:
{Charlie,
Carl,
Bill}
{Aurora,
Alexandra}
{Catherine,
Bob}
{Elena,
Eric,
Frank,
Dmitry,
Constantine}
{John}
{Athena,
Babis,
Albert}
{Lukas}
{Jack,
Maria,
George}
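The k in k-means is the number of clusters into which all elements are partitioned. The core idea of the algorithm revolves around the centroid (the mean, or "center of gravity") of the elements in each cluster.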
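To make the notion of a centroid concrete, here is a minimal sketch that computes the component-wise mean of a few points. It works on plain double arrays rather than the DataPoint class from iweb2, and the class and method names (CentroidSketch, centroid) are hypothetical, chosen only for this illustration.

```java
import java.util.Arrays;

// Illustrative sketch only: CentroidSketch and centroid() are hypothetical
// names and are not part of the iweb2 library.
class CentroidSketch {

    /** Returns the component-wise mean of the given points (all of equal dimension). */
    static double[] centroid(double[][] points) {
        if (points.length == 0) {
            throw new IllegalArgumentException("need at least one point");
        }
        int dim = points[0].length;
        double[] mean = new double[dim];
        // Sum every coordinate across all points ...
        for (double[] p : points) {
            for (int d = 0; d < dim; d++) {
                mean[d] += p[d];
            }
        }
        // ... then divide by the number of points.
        for (int d = 0; d < dim; d++) {
            mean[d] /= points.length;
        }
        return mean;
    }

    public static void main(String[] args) {
        double[][] pts = { {1.0, 2.0}, {3.0, 4.0}, {5.0, 6.0} };
        System.out.println(Arrays.toString(centroid(pts))); // prints [3.0, 4.0]
    }
}
```

The centroid need not coincide with any actual data point; it is simply the average position of the cluster's members.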
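4.4.2 How k-means works internally
The "means" in k-means refers to the cluster means, that is, the centroids. The algorithm first picks k points at random to serve as the initial centroids: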
/**
 * Builds the initial set of centroids by picking data points at random.
 * @param k the number of centroids to pick
 * @param data the data points to pick from
 * @return an array of k centroids
 */
public static DataPoint[] pickInitialCentroids(int k, DataPoint[] data)
{
    Random randGen = new Random();
    DataPoint[] centroids = new DataPoint[k];
    // Calculate random mean values for each cluster based on the data
    /**
     * TODO: 4.2 -- Selecting the means for seeding
     *
     * In large datasets, the selection of the initial centroids can be
     * important from a computational (time) complexity perspective.
     *
     * In general, how can we improve the seeding of the initial mean values?
     * For example, consider the following heuristic:
     *
     * 1. pick randomly one node
     * 2. calculate the distance between that node and O (10*k) other nodes
     * 3. sort the list of nodes according to their distance from the first node
     * 4. pick every 10th node in the sequence
     * 5. calculate the mean distance between each one of these nodes and the original node
     *
     * This algorithmic choice is as ad hoc as they come, however, it does have
     * some key principles embedded in it. What are these principles?
     * How can you generalize this algorithm?
     *
     * Discuss advantages/disadvantages of the initial seeding with your friends.
     *
     */
    Set<Integer> previouslyUsedIds = new HashSet<Integer>();
    for (int i = 0; i < k; i++)
    {
        // pick a point index that we haven't used yet
        int centroidId;
        do
        {
            centroidId = randGen.nextInt(data.length);
        }
        while (!previouslyUsedIds.add(centroidId));
        // Create DataPoint that will represent the cluster's centroid.
        String label = "Mean-" + i + "(" + data[centroidId].getLabel() + ")";
        double[] values = data[centroidId].getNumericAttrValues();
        String[] attrNames = data[centroidId].getAttributeNames();
        centroids[i] = new DataPoint(label,
                Attributes.createAttributes(attrNames, values));
    }
    return centroids;
}
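One detail of the listing above is worth noting: the do/while loop keeps drawing until it finds an unused index, so it never terminates if k is larger than the number of data points. A possible alternative, sketched below, shuffles the index range once and takes the first k entries, which also lets us reject an impossible k up front. This is not the book's code; SeedingSketch and pickDistinctIndices are hypothetical names, and the sketch returns indices only rather than building DataPoint centroids.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Random;

// Illustrative sketch only: a shuffle-based way to pick k distinct indices.
// SeedingSketch and pickDistinctIndices are hypothetical names, not iweb2 API.
class SeedingSketch {

    /** Picks k distinct indices from the range [0, n) in random order. */
    static int[] pickDistinctIndices(int k, int n, Random rng) {
        if (k < 0 || k > n) {
            // The rejection loop in the original listing would spin forever here.
            throw new IllegalArgumentException("k must be between 0 and n");
        }
        List<Integer> indices = new ArrayList<>(n);
        for (int i = 0; i < n; i++) {
            indices.add(i);
        }
        // A uniform shuffle makes the first k entries a uniform random sample.
        Collections.shuffle(indices, rng);
        int[] picked = new int[k];
        for (int i = 0; i < k; i++) {
            picked[i] = indices.get(i);
        }
        return picked;
    }

    public static void main(String[] args) {
        // Passing a seeded Random makes a run reproducible, which helps when
        // comparing clustering results across experiments.
        int[] seeds = pickDistinctIndices(8, 20, new Random(42));
        System.out.println(Arrays.toString(seeds));
    }
}
```

The shuffle costs O(n) regardless of k, so for very large datasets with small k the original rejection loop may still be the cheaper choice.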
Next, for every data point, the algorithm finds the centroid closest to it:
/**
 * This method calculates the closest centroid for a given data point
 *
 * @param centroids
 * @param x is the <CODE>DataPoint</CODE> for which we seek the closest centroid
 * @return the index (from the centroids array) of the closest centroid
 */
private int findClosestCentroid(DataPoint[] centroids, DataPoint x)
{
    double minDistance = Double.POSITIVE_INFINITY;
    int closestCentroid = -1;
    for (int i = 0, n = centroids.length; i < n; i++)
    {
        double d = distance(centroids[i], x);
        // if the d == minDistance then keep current selection
        if (d < minDistance)
        {
            minDistance = d;
            closestCentroid = i;
        }
    }
    return closestCentroid;
}
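The distance method called above is not shown in this excerpt. As a rough sketch of how the search behaves, the example below assumes plain Euclidean distance over double arrays; the actual iweb2 implementation may use a different metric. NearestCentroidSketch, euclidean, and findClosest are hypothetical names.

```java
// Illustrative sketch only: nearest-centroid search under an assumed
// Euclidean metric. None of these names come from iweb2.
class NearestCentroidSketch {

    /** Euclidean distance between two vectors of equal length. */
    static double euclidean(double[] a, double[] b) {
        double sum = 0.0;
        for (int d = 0; d < a.length; d++) {
            double diff = a[d] - b[d];
            sum += diff * diff;
        }
        return Math.sqrt(sum);
    }

    /** Returns the index of the centroid closest to x; ties go to the lowest index. */
    static int findClosest(double[][] centroids, double[] x) {
        double minDistance = Double.POSITIVE_INFINITY;
        int closest = -1;
        for (int i = 0; i < centroids.length; i++) {
            double d = euclidean(centroids[i], x);
            // Strict '<' keeps the earlier centroid when two are equally close.
            if (d < minDistance) {
                minDistance = d;
                closest = i;
            }
        }
        return closest;
    }

    public static void main(String[] args) {
        double[][] centroids = { {0.0, 0.0}, {10.0, 10.0} };
        System.out.println(findClosest(centroids, new double[] {1.0, 1.0})); // prints 0
        System.out.println(findClosest(centroids, new double[] {9.0, 8.0})); // prints 1
    }
}
```

Because the comparison is strict, a point exactly halfway between two centroids stays with the one that appears first in the array, which matches the tie-breaking comment in the listing above.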
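Each point is then added to the cluster of its closest centroid. Once all points have been assigned, the centroid of every cluster is recomputed from its members, which generally moves it. These two steps are repeated until no centroid changes any more, which happens once no point switches to a different cluster.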
public void cluster()
{
    boolean centroidsChanged = true;
    while (centroidsChanged)
    {
        // Create a set of points for each cluster
        List<Set<DataPoint>> clusters = new ArrayList<Set<DataPoint>>(k);
        for (int i = 0; i < k; i++)
        {
            clusters.add(new HashSet<DataPoint>());
        }
        // Assign points to each set based on minimum distance from the centroids
        for (DataPoint p : allDataPoints)
        {
            int i = findClosestCentroid(allCentroids, p);
            clusters.get(i).add(p);
        }
        for (int i = 0; i < k; i++)
        {
            // Create the cluster from its assigned points
            allClusters[i] = new Cluster(clusters.get(i));
        }
        // Calculate new cluster centroids, and
        // check if any of the centroids has changed
        centroidsChanged = false;
        for (int i = 0; i < allClusters.length; i++)
        {
            if (clusters.get(i).size() > 0)
            {
                double[] newCentroidValues = findCentroid(allClusters[i]);
                double[] oldCentroidValues = allCentroids[i].getNumericAttrValues();
                if (!Arrays.equals(oldCentroidValues, newCentroidValues))
                {
                    allCentroids[i] = new DataPoint(allCentroids[i].getLabel(), newCentroidValues);
                    centroidsChanged = true;
                }
            }
            else
            {
                // keep mean unchanged if cluster has no elements.
            }
        }
    }
}
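To see the whole loop run end to end without the iweb2 classes, here is a compact, self-contained sketch of the same assign-then-update cycle on plain double arrays. It assumes squared Euclidean distance and takes the initial centroids as an argument so that the result is deterministic. The iteration cap is an addition that is not in the listing above, and KMeansSketch and its methods are hypothetical names.

```java
import java.util.Arrays;

// Illustrative sketch only: the assign/update loop on double arrays.
// KMeansSketch, cluster, and closest are hypothetical names, not iweb2 API.
class KMeansSketch {

    static final int MAX_ITERATIONS = 100; // safety cap, not in the original listing

    /**
     * Runs the loop until no centroid changes. Rows of 'centroids' are
     * replaced in place. Returns the cluster index assigned to each point.
     */
    static int[] cluster(double[][] points, double[][] centroids) {
        int dim = points[0].length;
        int[] assignment = new int[points.length];
        boolean changed = true;
        for (int iter = 0; changed && iter < MAX_ITERATIONS; iter++) {
            // Assignment step: attach every point to its closest centroid.
            for (int p = 0; p < points.length; p++) {
                assignment[p] = closest(centroids, points[p]);
            }
            // Update step: recompute each centroid as the mean of its members.
            changed = false;
            for (int c = 0; c < centroids.length; c++) {
                double[] mean = new double[dim];
                int count = 0;
                for (int p = 0; p < points.length; p++) {
                    if (assignment[p] != c) continue;
                    count++;
                    for (int d = 0; d < dim; d++) mean[d] += points[p][d];
                }
                if (count == 0) continue; // empty cluster keeps its old centroid
                for (int d = 0; d < dim; d++) mean[d] /= count;
                if (!Arrays.equals(mean, centroids[c])) {
                    centroids[c] = mean;
                    changed = true;
                }
            }
        }
        return assignment;
    }

    /** Index of the centroid with the smallest squared Euclidean distance to x. */
    static int closest(double[][] centroids, double[] x) {
        int best = -1;
        double bestDist = Double.POSITIVE_INFINITY;
        for (int c = 0; c < centroids.length; c++) {
            double dist = 0.0;
            for (int d = 0; d < x.length; d++) {
                double diff = centroids[c][d] - x[d];
                dist += diff * diff;
            }
            if (dist < bestDist) {
                bestDist = dist;
                best = c;
            }
        }
        return best;
    }

    public static void main(String[] args) {
        double[][] points = { {1, 1}, {1, 2}, {2, 1}, {8, 8}, {9, 8}, {8, 9} };
        double[][] centroids = { {1, 1}, {8, 8} };
        System.out.println(Arrays.toString(cluster(points, centroids))); // prints [0, 0, 0, 1, 1, 1]
    }
}
```

On this toy data the loop stops after the second pass: the first pass moves both centroids to the mean of their three points, and the second pass produces the same assignment and therefore the same centroids.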