"""
================================================================
Comparing different hierarchical linkage methods on toy datasets
================================================================

This example shows characteristics of different linkage
methods for hierarchical clustering on datasets that are
"interesting" but still in 2D.

The main observations to make are:

- single linkage is fast, and can perform well on
  non-globular data, but it performs poorly in the
  presence of noise.
- average and complete linkage perform well on
  cleanly separated globular clusters, but have mixed
  results otherwise.
- Ward is the most effective method for noisy data.

The time taken to fit each model is shown in the lower-right
corner of its panel.

While these examples give some intuition about the
algorithms, this intuition might not apply to very
high-dimensional data.

"""

import time
import warnings
from itertools import cycle, islice

import matplotlib.pyplot as plt
import numpy as np

from sklearn import cluster, datasets
from sklearn.preprocessing import StandardScaler
# %%
# Generate datasets. The sample size is large enough to show how the
# algorithms scale, but small enough to keep running times short.

n_samples = 1500
# One shared seed so every generated dataset is reproducible.
random_state = 170

noisy_circles = datasets.make_circles(
    n_samples=n_samples, factor=0.5, noise=0.05, random_state=random_state
)
noisy_moons = datasets.make_moons(
    n_samples=n_samples, noise=0.05, random_state=random_state
)
blobs = datasets.make_blobs(n_samples=n_samples, random_state=random_state)

# Uniform noise with no cluster structure (and therefore no labels).
rng = np.random.RandomState(random_state)
no_structure = rng.rand(n_samples, 2), None

# Anisotropically distributed data: a linear transformation of ``blobs``.
# The original code regenerated the very same sample with an identical
# ``make_blobs`` call and seed; reusing ``blobs`` gives the same data.
X, y = blobs
transformation = [[0.6, -0.6], [-0.4, 0.8]]
X_aniso = np.dot(X, transformation)
aniso = (X_aniso, y)

# Blobs with varied variances.
varied = datasets.make_blobs(
    n_samples=n_samples, cluster_std=[1.0, 2.5, 0.5], random_state=random_state
)
# %%
# Run the clustering and plot

# Set up the figure: one row per dataset, one column per linkage method.
plt.figure(figsize=(9 * 1.3 + 2, 14.5))
plt.subplots_adjust(
    left=0.02, right=0.98, bottom=0.001, top=0.96, wspace=0.05, hspace=0.01
)

plot_num = 1

# Cluster parameters shared by every dataset; the per-dataset dicts below
# override them.
default_base = {"n_clusters": 3}

# Named ``dataset_configs`` (not ``datasets``) so the ``sklearn.datasets``
# module imported at the top of the file is not shadowed.
dataset_configs = [
    (noisy_circles, {"n_clusters": 2}),
    (noisy_moons, {"n_clusters": 2}),
    (varied, {}),
    (aniso, {}),
    (blobs, {}),
    (no_structure, {}),
]

# Linkage criteria compared in each row, in left-to-right column order.
linkages = ("single", "average", "complete", "ward")

# Colors assigned to cluster labels 0, 1, 2, ...; cycled if there are more
# clusters than colors.
palette = [
    "#377eb8",
    "#ff7f00",
    "#4daf4a",
    "#f781bf",
    "#a65628",
    "#984ea3",
    "#999999",
    "#e41a1c",
    "#dede00",
]

for i_dataset, (dataset, algo_params) in enumerate(dataset_configs):
    # Update parameters with dataset-specific values.
    params = {**default_base, **algo_params}

    # Ground-truth labels are not used; only the points are clustered.
    X, _ = dataset

    # Normalize dataset for easier parameter selection.
    X = StandardScaler().fit_transform(X)

    # One AgglomerativeClustering estimator per linkage criterion,
    # titled e.g. "Single Linkage", "Ward Linkage".
    clustering_algorithms = tuple(
        (
            f"{linkage.capitalize()} Linkage",
            cluster.AgglomerativeClustering(
                n_clusters=params["n_clusters"], linkage=linkage
            ),
        )
        for linkage in linkages
    )

    for name, algorithm in clustering_algorithms:
        # perf_counter is monotonic and high-resolution, unlike time.time,
        # so it is the right clock for measuring elapsed time.
        t0 = time.perf_counter()

        # Silence the "connectivity matrix has N > 1 components" UserWarning.
        # The estimators above are built without a ``connectivity`` argument,
        # so this filter is presumably a no-op here; it is kept in case
        # connectivity constraints are added.
        with warnings.catch_warnings():
            warnings.filterwarnings(
                "ignore",
                message="the number of connected components of the "
                + "connectivity matrix is [0-9]{1,2}"
                + " > 1. Completing it to avoid stopping the tree early.",
                category=UserWarning,
            )
            algorithm.fit(X)

        t1 = time.perf_counter()
        if hasattr(algorithm, "labels_"):
            y_pred = algorithm.labels_.astype(int)
        else:
            y_pred = algorithm.predict(X)

        plt.subplot(len(dataset_configs), len(clustering_algorithms), plot_num)
        if i_dataset == 0:
            plt.title(name, size=18)

        # One color per label value 0..max(label).
        colors = np.array(list(islice(cycle(palette), int(y_pred.max()) + 1)))
        plt.scatter(X[:, 0], X[:, 1], s=10, color=colors[y_pred])

        plt.xlim(-2.5, 2.5)
        plt.ylim(-2.5, 2.5)
        plt.xticks(())
        plt.yticks(())
        # Fit time in the lower-right corner, leading zero dropped (".05s").
        plt.text(
            0.99,
            0.01,
            f"{t1 - t0:.2f}s".lstrip("0"),
            transform=plt.gca().transAxes,
            size=15,
            horizontalalignment="right",
        )
        plot_num += 1

plt.show()