K-means 算法是一种经典的无监督学习算法,用于将数据自动分为 \(k\) 个簇,这里的 \(k\) 需要提前给定。

K-means 算法假设簇是凸的、大小相近,此时处理效果最好,但是并不能处理复杂形状的簇(如半月形),对维度高的稀疏数据(如文本)不太适用。

算法步骤

设数据集为 \(X = \{x_1, x_2, \dots, x_n\}\)\(x_i \in \mathbb{R}^d\),希望聚为 \(k\) 类,算法流程如下:

  • 初始化:随机选择 \(k\) 个数据点作为初始的簇中心。
  • 分配:将每个数据点分配到最近的簇中心,以形成 \(k\) 个簇。设当前中心为 \(\mu_1, \dots, \mu_k\),对于样本 \(x_i\),所属的簇为 \[ x_i \in C_s \,\, \text{where} \,\, s = \arg\min_{j=1,\dots,k} \| x_i - \mu_j \|^2 \]
  • 更新:重新计算每个簇的中心,记 \(C_s\) 是第 \(s\) 个簇内的数据点集合,那么中心 \(\mu_s\)\[ \mu_s = \frac{1}{|C_s|} \sum_{x_i \in C_s} x_i \]
  • 重复进行分配和更新,直到所有的簇中心不再变化或达到最大迭代次数。

Python 实现

下面是 K-means 算法的动画效果实现

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation


def generate_custom_data():
rng = np.random.default_rng(42)

theta = rng.uniform(0, 2 * np.pi, 100)
r = rng.normal(1, 0.3, 100)
x1 = np.stack([r * np.cos(theta), r * np.sin(theta)], axis=1) # 圆形 + 噪声
x2 = rng.normal(loc=[3, 3], scale=[1.2, 0.6], size=(60, 2)) # 拉伸椭圆 + 偏移
x3 = rng.normal(loc=[-2, 4], scale=0.6, size=(70, 2)) # 密集圆 + 偏移
x4 = rng.normal(loc=[3, 0], scale=0.3, size=(40, 2)) # 密集圆 + 偏移

return np.vstack([x1, x2, x3, x4])


X = generate_custom_data()


class KMeansAnimator:
def __init__(self, X, k, max_iters=100, tol=1e-4):
self.X = X
self.k = k
self.max_iters = max_iters
self.tol = tol
self.history = []

def fit(self):
n_samples, _ = self.X.shape
rng = np.random.default_rng()
centroids = self.X[rng.choice(n_samples, self.k, replace=False)]

for _ in range(self.max_iters):
distances = np.linalg.norm(self.X[:, np.newaxis] - centroids, axis=2)
labels = np.argmin(distances, axis=1)
self.history.append((labels.copy(), centroids.copy()))

new_centroids = np.array(
[
(
self.X[labels == j].mean(axis=0)
if len(self.X[labels == j]) > 0
else centroids[j]
)
for j in range(self.k)
]
)

if np.linalg.norm(new_centroids - centroids) < self.tol:
self.history.append((labels.copy(), new_centroids.copy()))
break
centroids = new_centroids

return self.history[-1]


k = 4 # 聚类个数
animator = KMeansAnimator(X, k)
labels, centroids = animator.fit()

fig, ax = plt.subplots()
scat = ax.scatter([], [], c=[], cmap="viridis", s=30)
cent_scat = ax.scatter([], [], c="red", s=100, marker="X")
ax.set_xlim(X[:, 0].min() - 1, X[:, 0].max() + 1)
ax.set_ylim(X[:, 1].min() - 1, X[:, 1].max() + 1)
ax.set_title("K-means Clustering - Animation")


def update(frame):
labels, centroids = animator.history[frame]
scat.set_offsets(X)
scat.set_array(labels.astype(float))
cent_scat.set_offsets(centroids)
ax.set_title(f"K-means Iteration {frame + 1}")
return scat, cent_scat


ani = FuncAnimation(
fig, update, frames=len(animator.history), interval=800, repeat=False
)
ani.save("kmeans_animation.gif")

运行结果如下