130 lines
4.1 KiB
Python
130 lines
4.1 KiB
Python
|
"""
|
||
|
Comparison between grid search and successive halving
|
||
|
=====================================================
|
||
|
|
||
|
This example compares the parameter search performed by
|
||
|
:class:`~sklearn.model_selection.HalvingGridSearchCV` and
|
||
|
:class:`~sklearn.model_selection.GridSearchCV`.
|
||
|
|
||
|
"""
|
||
|
|
||
|
from time import time
|
||
|
|
||
|
import matplotlib.pyplot as plt
|
||
|
import numpy as np
|
||
|
import pandas as pd
|
||
|
|
||
|
from sklearn import datasets
|
||
|
from sklearn.experimental import enable_halving_search_cv # noqa
|
||
|
from sklearn.model_selection import GridSearchCV, HalvingGridSearchCV
|
||
|
from sklearn.svm import SVC
|
||
|
|
||
|
# %%
|
||
|
# We first define the parameter space for an :class:`~sklearn.svm.SVC`
|
||
|
# estimator, and compute the time required to train a
|
||
|
# :class:`~sklearn.model_selection.HalvingGridSearchCV` instance, as well as a
|
||
|
# :class:`~sklearn.model_selection.GridSearchCV` instance.
|
||
|
|
||
|
rng = np.random.RandomState(0)
|
||
|
X, y = datasets.make_classification(n_samples=1000, random_state=rng)
|
||
|
|
||
|
gammas = [1e-1, 1e-2, 1e-3, 1e-4, 1e-5, 1e-6, 1e-7]
|
||
|
Cs = [1, 10, 100, 1e3, 1e4, 1e5]
|
||
|
param_grid = {"gamma": gammas, "C": Cs}
|
||
|
|
||
|
clf = SVC(random_state=rng)
|
||
|
|
||
|
tic = time()
|
||
|
gsh = HalvingGridSearchCV(
|
||
|
estimator=clf, param_grid=param_grid, factor=2, random_state=rng
|
||
|
)
|
||
|
gsh.fit(X, y)
|
||
|
gsh_time = time() - tic
|
||
|
|
||
|
tic = time()
|
||
|
gs = GridSearchCV(estimator=clf, param_grid=param_grid)
|
||
|
gs.fit(X, y)
|
||
|
gs_time = time() - tic
|
||
|
|
||
|
# %%
|
||
|
# We now plot heatmaps for both search estimators.
|
||
|
|
||
|
|
||
|
def make_heatmap(ax, gs, is_sh=False, make_cbar=False):
|
||
|
"""Helper to make a heatmap."""
|
||
|
results = pd.DataFrame(gs.cv_results_)
|
||
|
results[["param_C", "param_gamma"]] = results[["param_C", "param_gamma"]].astype(
|
||
|
np.float64
|
||
|
)
|
||
|
if is_sh:
|
||
|
# SH dataframe: get mean_test_score values for the highest iter
|
||
|
scores_matrix = results.sort_values("iter").pivot_table(
|
||
|
index="param_gamma",
|
||
|
columns="param_C",
|
||
|
values="mean_test_score",
|
||
|
aggfunc="last",
|
||
|
)
|
||
|
else:
|
||
|
scores_matrix = results.pivot(
|
||
|
index="param_gamma", columns="param_C", values="mean_test_score"
|
||
|
)
|
||
|
|
||
|
im = ax.imshow(scores_matrix)
|
||
|
|
||
|
ax.set_xticks(np.arange(len(Cs)))
|
||
|
ax.set_xticklabels(["{:.0E}".format(x) for x in Cs])
|
||
|
ax.set_xlabel("C", fontsize=15)
|
||
|
|
||
|
ax.set_yticks(np.arange(len(gammas)))
|
||
|
ax.set_yticklabels(["{:.0E}".format(x) for x in gammas])
|
||
|
ax.set_ylabel("gamma", fontsize=15)
|
||
|
|
||
|
# Rotate the tick labels and set their alignment.
|
||
|
plt.setp(ax.get_xticklabels(), rotation=45, ha="right", rotation_mode="anchor")
|
||
|
|
||
|
if is_sh:
|
||
|
iterations = results.pivot_table(
|
||
|
index="param_gamma", columns="param_C", values="iter", aggfunc="max"
|
||
|
).values
|
||
|
for i in range(len(gammas)):
|
||
|
for j in range(len(Cs)):
|
||
|
ax.text(
|
||
|
j,
|
||
|
i,
|
||
|
iterations[i, j],
|
||
|
ha="center",
|
||
|
va="center",
|
||
|
color="w",
|
||
|
fontsize=20,
|
||
|
)
|
||
|
|
||
|
if make_cbar:
|
||
|
fig.subplots_adjust(right=0.8)
|
||
|
cbar_ax = fig.add_axes([0.85, 0.15, 0.05, 0.7])
|
||
|
fig.colorbar(im, cax=cbar_ax)
|
||
|
cbar_ax.set_ylabel("mean_test_score", rotation=-90, va="bottom", fontsize=15)
|
||
|
|
||
|
|
||
|
fig, axes = plt.subplots(ncols=2, sharey=True)
|
||
|
ax1, ax2 = axes
|
||
|
|
||
|
make_heatmap(ax1, gsh, is_sh=True)
|
||
|
make_heatmap(ax2, gs, make_cbar=True)
|
||
|
|
||
|
ax1.set_title("Successive Halving\ntime = {:.3f}s".format(gsh_time), fontsize=15)
|
||
|
ax2.set_title("GridSearch\ntime = {:.3f}s".format(gs_time), fontsize=15)
|
||
|
|
||
|
plt.show()
|
||
|
|
||
|
# %%
|
||
|
# The heatmaps show the mean test score of the parameter combinations for an
|
||
|
# :class:`~sklearn.svm.SVC` instance. The
|
||
|
# :class:`~sklearn.model_selection.HalvingGridSearchCV` also shows the
|
||
|
# iteration at which the combinations where last used. The combinations marked
|
||
|
# as ``0`` were only evaluated at the first iteration, while the ones with
|
||
|
# ``5`` are the parameter combinations that are considered the best ones.
|
||
|
#
|
||
|
# We can see that the :class:`~sklearn.model_selection.HalvingGridSearchCV`
|
||
|
# class is able to find parameter combinations that are just as accurate as
|
||
|
# :class:`~sklearn.model_selection.GridSearchCV`, in much less time.
|