717 lines
27 KiB
Python
717 lines
27 KiB
Python
|
"""
|
||
|
================
|
||
|
Metadata Routing
|
||
|
================
|
||
|
|
||
|
.. currentmodule:: sklearn
|
||
|
|
||
|
This document shows how you can use the :ref:`metadata routing mechanism
|
||
|
<metadata_routing>` in scikit-learn to route metadata to the estimators,
|
||
|
scorers, and CV splitters consuming them.
|
||
|
|
||
|
To better understand the following document, we need to introduce two concepts:
|
||
|
routers and consumers. A router is an object which forwards some given data and
|
||
|
metadata to other objects. In most cases, a router is a :term:`meta-estimator`,
|
||
|
i.e. an estimator which takes another estimator as a parameter. A function such
|
||
|
as :func:`sklearn.model_selection.cross_validate` which takes an estimator as a
|
||
|
parameter and forwards data and metadata, is also a router.
|
||
|
|
||
|
A consumer, on the other hand, is an object which accepts and uses some given
|
||
|
metadata. For instance, an estimator taking into account ``sample_weight`` in
|
||
|
its :term:`fit` method is a consumer of ``sample_weight``.
|
||
|
|
||
|
It is possible for an object to be both a router and a consumer. For instance,
|
||
|
a meta-estimator may take into account ``sample_weight`` in certain
|
||
|
calculations, but it may also route it to the underlying estimator.
|
||
|
|
||
|
First a few imports and some random data for the rest of the script.
|
||
|
"""
|
||
|
|
||
|
# %%
|
||
|
|
||
|
import warnings
|
||
|
from pprint import pprint
|
||
|
|
||
|
import numpy as np
|
||
|
|
||
|
from sklearn import set_config
|
||
|
from sklearn.base import (
|
||
|
BaseEstimator,
|
||
|
ClassifierMixin,
|
||
|
MetaEstimatorMixin,
|
||
|
RegressorMixin,
|
||
|
TransformerMixin,
|
||
|
clone,
|
||
|
)
|
||
|
from sklearn.linear_model import LinearRegression
|
||
|
from sklearn.utils import metadata_routing
|
||
|
from sklearn.utils.metadata_routing import (
|
||
|
MetadataRouter,
|
||
|
MethodMapping,
|
||
|
get_routing_for_object,
|
||
|
process_routing,
|
||
|
)
|
||
|
from sklearn.utils.validation import check_is_fitted
|
||
|
|
||
|
n_samples = 100
n_features = 4
rng = np.random.RandomState(42)
# Every draw below advances ``rng``; keep their order to keep the data fixed.
X = rng.rand(n_samples, n_features)
y = rng.randint(low=0, high=2, size=n_samples)
my_groups = rng.randint(low=0, high=10, size=n_samples)
my_weights = rng.rand(n_samples)
my_other_weights = rng.rand(n_samples)
|
||
|
|
||
|
# %%
|
||
|
# Metadata routing is only available if explicitly enabled:
|
||
|
set_config(enable_metadata_routing=True)  # global switch; disabled again at the end of this script
|
||
|
|
||
|
|
||
|
# %%
|
||
|
# This utility function is a dummy to check if a metadata is passed:
|
||
|
def check_metadata(obj, **kwargs):
    """Report, for every keyword argument, whether a value was received.

    Parameters
    ----------
    obj : object
        The object that received the metadata. Only its class name is used,
        to label the printed messages.
    **kwargs : dict
        Metadata name/value pairs. A ``None`` value is reported as missing;
        any other value is reported together with its ``len()``.
    """
    owner = obj.__class__.__name__
    for name, payload in kwargs.items():
        if payload is None:
            print(f"{name} is None in {owner}.")
            continue
        print(f"Received {name} of length = {len(payload)} in {owner}.")
|
||
|
|
||
|
|
||
|
# %%
|
||
|
# A utility function to nicely print the routing information of an object:
|
||
|
def print_routing(obj):
    """Pretty-print the serialized metadata routing of ``obj``.

    ``obj`` must provide ``get_metadata_routing()``. The returned routing
    object is converted with its (private) ``_serialize`` method and then
    displayed with :func:`pprint.pprint`.
    """
    routing = obj.get_metadata_routing()
    pprint(routing._serialize())
|
||
|
|
||
|
|
||
|
# %%
|
||
|
# Consuming Estimator
|
||
|
# -------------------
|
||
|
# Here we demonstrate how an estimator can expose the required API to support
|
||
|
# metadata routing as a consumer. Imagine a simple classifier accepting
|
||
|
# ``sample_weight`` as a metadata on its ``fit`` and ``groups`` in its
|
||
|
# ``predict`` method:
|
||
|
|
||
|
|
||
|
class ExampleClassifier(ClassifierMixin, BaseEstimator):
    """Toy classifier consuming ``sample_weight`` in ``fit`` and ``groups``
    in ``predict``.

    Nothing is learned: ``fit`` only records the two classes and ``predict``
    always answers 1. Both methods report the metadata they receive through
    ``check_metadata``.
    """

    def fit(self, X, y, sample_weight=None):
        check_metadata(self, sample_weight=sample_weight)
        # Classifiers must expose a ``classes_`` attribute once fitted.
        self.classes_ = np.array([0, 1])
        return self

    def predict(self, X, groups=None):
        check_metadata(self, groups=groups)
        # A constant prediction of 1 for every row; not a very smart model!
        n_rows = len(X)
        return np.ones(n_rows)
|
||
|
|
||
|
|
||
|
# %%
|
||
|
# The above estimator now has all it needs to consume metadata. This is
|
||
|
# accomplished by some magic done in :class:`~base.BaseEstimator`. There are
|
||
|
# now three methods exposed by the above class: ``set_fit_request``,
|
||
|
# ``set_predict_request``, and ``get_metadata_routing``. There is also a
|
||
|
# ``set_score_request`` for ``sample_weight`` which is present since
|
||
|
# :class:`~base.ClassifierMixin` implements a ``score`` method accepting
|
||
|
# ``sample_weight``. The same applies to regressors which inherit from
|
||
|
# :class:`~base.RegressorMixin`.
|
||
|
#
|
||
|
# By default, no metadata is requested, which we can see as:
|
||
|
|
||
|
print_routing(ExampleClassifier())  # fresh instance: no request has been set yet
|
||
|
|
||
|
# %%
|
||
|
# The above output means that ``sample_weight`` and ``groups`` are not
|
||
|
# requested by `ExampleClassifier`, and if a router is given those metadata, it
|
||
|
# should raise an error, since the user has not explicitly set whether they are
|
||
|
# required or not. The same is true for ``sample_weight`` in the ``score``
|
||
|
# method, which is inherited from :class:`~base.ClassifierMixin`. In order to
|
||
|
# explicitly set request values for those metadata, we can use these methods:
|
||
|
|
||
|
est = ExampleClassifier()
# Each ``set_*_request`` call updates ``est`` in place (and returns it).
est.set_fit_request(sample_weight=False)
est.set_predict_request(groups=True)
est.set_score_request(sample_weight=False)
print_routing(est)
|
||
|
|
||
|
# %%
|
||
|
# .. note ::
|
||
|
# Please note that as long as the above estimator is not used in a
|
||
|
# meta-estimator, the user does not need to set any requests for the
|
||
|
# metadata and the set values are ignored, since a consumer does not
|
||
|
# validate or route given metadata. A simple usage of the above estimator
|
||
|
# would work as expected.
|
||
|
|
||
|
est = ExampleClassifier()
# Used directly (not through a router), the estimator accepts the metadata
# without any request being set.
est.fit(X, y, sample_weight=my_weights)
est.predict(X[:3, :], groups=my_groups)
|
||
|
|
||
|
# %%
|
||
|
# Routing Meta-Estimator
|
||
|
# ----------------------
|
||
|
# Now, we show how to design a meta-estimator to be a router. As a simplified
|
||
|
# example, here is a meta-estimator, which doesn't do much other than routing
|
||
|
# the metadata.
|
||
|
|
||
|
|
||
|
class MetaClassifier(MetaEstimatorMixin, ClassifierMixin, BaseEstimator):
    """Meta-estimator whose only job is to route metadata to its child.

    Parameters
    ----------
    estimator : estimator object
        The classifier that is cloned and fitted in ``fit`` and whose
        ``predict`` is used in ``predict``.
    """

    def __init__(self, estimator):
        self.estimator = estimator

    def get_metadata_routing(self):
        # This method defines the routing for this meta-estimator: a
        # `MetadataRouter` is created and each method of this object (caller)
        # is mapped to the same-named method of the child (callee).
        # More explanations follow below.
        mapping = (
            MethodMapping()
            .add(caller="fit", callee="fit")
            .add(caller="predict", callee="predict")
            .add(caller="score", callee="score")
        )
        return MetadataRouter(owner=self.__class__.__name__).add(
            estimator=self.estimator,
            method_mapping=mapping,
        )

    def fit(self, X, y, **fit_params):
        # `get_routing_for_object` returns a copy of the `MetadataRouter`
        # built by `get_metadata_routing` above, which it calls internally.
        router = get_routing_for_object(self)
        # Meta-estimators are responsible for validating the given metadata.
        # `method` names this object's own method, i.e. `fit` here.
        router.validate_metadata(params=fit_params, method="fit")
        # `route_params` maps the given metadata onto what the child requested,
        # based on the routing defined above. The resulting `Bunch` has a key
        # per consuming object, holding a key per consuming method, which in
        # turn holds the metadata routed to it.
        routed = router.route_params(params=fit_params, caller="fit")

        # Fit a clone of the child and expose its classes on the
        # meta-estimator.
        child = clone(self.estimator)
        self.estimator_ = child.fit(X, y, **routed.estimator.fit)
        self.classes_ = self.estimator_.classes_
        return self

    def predict(self, X, **predict_params):
        check_is_fitted(self)
        # Same three steps as in `fit`: copy the router, validate the given
        # metadata, then build the input of the child's `predict`.
        router = get_routing_for_object(self)
        router.validate_metadata(params=predict_params, method="predict")
        routed = router.route_params(params=predict_params, caller="predict")
        return self.estimator_.predict(X, **routed.estimator.predict)
|
||
|
|
||
|
|
||
|
# %%
|
||
|
# Let's break down different parts of the above code.
|
||
|
#
|
||
|
# First, :func:`~utils.metadata_routing.get_routing_for_object` takes our
|
||
|
# meta-estimator (``self``) and returns a
|
||
|
# :class:`~utils.metadata_routing.MetadataRouter` or a
|
||
|
# :class:`~utils.metadata_routing.MetadataRequest` if the object is a consumer,
|
||
|
# based on the output of the estimator's ``get_metadata_routing`` method.
|
||
|
#
|
||
|
# Then in each method, we use the ``route_params`` method to construct a
|
||
|
# dictionary of the form ``{"object_name": {"method_name": {"metadata":
|
||
|
# value}}}`` to pass to the underlying estimator's method. The ``object_name``
|
||
|
# (``estimator`` in the above ``routed_params.estimator.fit`` example) is the
|
||
|
# same as the one added in the ``get_metadata_routing``. ``validate_metadata``
|
||
|
# makes sure all given metadata are requested to avoid silent bugs.
|
||
|
#
|
||
|
# Next, we illustrate the different behaviors and notably the type of errors
|
||
|
# raised.
|
||
|
|
||
|
meta_est = MetaClassifier(
    estimator=ExampleClassifier().set_fit_request(sample_weight=True)
)
# `sample_weight` is requested by the child's `fit`, so it is routed there.
meta_est.fit(X, y, sample_weight=my_weights)
|
||
|
|
||
|
# %%
|
||
|
# Note that the above example is calling our utility function
|
||
|
# `check_metadata()` via the `ExampleClassifier`. It checks that
|
||
|
# ``sample_weight`` is correctly passed to it. If it is not, like in the
|
||
|
# following example, it would print that ``sample_weight`` is ``None``:
|
||
|
|
||
|
meta_est.fit(X, y)  # nothing to route: the child reports `sample_weight` as None
|
||
|
|
||
|
# %%
|
||
|
# If we pass an unknown metadata, an error is raised:
|
||
|
try:
    # `test` is not requested by any object in the routing.
    meta_est.fit(X, y, test=my_weights)
except TypeError as e:
    print(e)
|
||
|
|
||
|
# %%
|
||
|
# And if we pass a metadata which is not explicitly requested:
|
||
|
try:
    # No request was set for `groups` on the child's `predict`.
    meta_est.fit(X, y, sample_weight=my_weights).predict(X, groups=my_groups)
except ValueError as e:
    print(e)
|
||
|
|
||
|
# %%
|
||
|
# Also, if we explicitly set it as not requested, but it is provided:
|
||
|
meta_est = MetaClassifier(
    estimator=ExampleClassifier()
    .set_fit_request(sample_weight=True)
    .set_predict_request(groups=False)
)
try:
    # `groups` is explicitly not requested for `predict`, yet it is passed.
    meta_est.fit(X, y, sample_weight=my_weights).predict(X[:3, :], groups=my_groups)
except TypeError as e:
    print(e)
|
||
|
|
||
|
# %%
|
||
|
# Another concept to introduce is **aliased metadata**. This is when an
|
||
|
# estimator requests a metadata with a different variable name than the default
|
||
|
# variable name. For instance, in a setting where there are two estimators in a
|
||
|
# pipeline, one could request ``sample_weight1`` and the other
|
||
|
# ``sample_weight2``. Note that this doesn't change what the estimator expects,
|
||
|
# it only tells the meta-estimator how to map the provided metadata to what is
|
||
|
# required. Here's an example, where we pass ``aliased_sample_weight`` to the
|
||
|
# meta-estimator, but the meta-estimator understands that
|
||
|
# ``aliased_sample_weight`` is an alias for ``sample_weight``, and passes it as
|
||
|
# ``sample_weight`` to the underlying estimator:
|
||
|
meta_est = MetaClassifier(
    # the child receives `aliased_sample_weight` under its own name `sample_weight`
    estimator=ExampleClassifier().set_fit_request(sample_weight="aliased_sample_weight")
)
meta_est.fit(X, y, aliased_sample_weight=my_weights)
|
||
|
|
||
|
# %%
|
||
|
# Passing ``sample_weight`` here will fail since it is requested with an
|
||
|
# alias and ``sample_weight`` with that name is not requested:
|
||
|
try:
    # only the alias `aliased_sample_weight` is requested, not `sample_weight`
    meta_est.fit(X, y, sample_weight=my_weights)
except TypeError as e:
    print(e)
|
||
|
|
||
|
# %%
|
||
|
# This leads us to the ``get_metadata_routing``. The way routing works in
|
||
|
# scikit-learn is that consumers request what they need, and routers pass that
|
||
|
# along. Additionally, a router exposes what it requires itself so that it can
|
||
|
# be used inside another router, e.g. a pipeline inside a grid search object.
|
||
|
# The output of the ``get_metadata_routing`` which is a dictionary
|
||
|
# representation of a :class:`~utils.metadata_routing.MetadataRouter`, includes
|
||
|
# the complete tree of requested metadata by all nested objects and their
|
||
|
# corresponding method routings, i.e. which method of a sub-estimator is used
|
||
|
# in which method of a meta-estimator:
|
||
|
|
||
|
print_routing(meta_est)  # routing of the aliased MetaClassifier built above
|
||
|
|
||
|
# %%
|
||
|
# As you can see, the only metadata requested for method ``fit`` is
|
||
|
# ``"sample_weight"`` with ``"aliased_sample_weight"`` as the alias. The
|
||
|
# :class:`~utils.metadata_routing.MetadataRouter` class enables us to easily create
|
||
|
# the routing object which would create the output we need for our
|
||
|
# ``get_metadata_routing``.
|
||
|
#
|
||
|
# In order to understand how aliases work in meta-estimators, imagine our
|
||
|
# meta-estimator inside another one:
|
||
|
|
||
|
meta_meta_est = MetaClassifier(estimator=meta_est).fit(
    # forwarded through `meta_est` down to the innermost `ExampleClassifier`
    X, y, aliased_sample_weight=my_weights
)
|
||
|
|
||
|
# %%
|
||
|
# In the above example, this is how the ``fit`` method of `meta_meta_est`
|
||
|
# ends up calling the ``fit`` methods of the nested sub-estimators::
|
||
|
#
|
||
|
# # user feeds `my_weights` as `aliased_sample_weight` into `meta_meta_est`:
|
||
|
# meta_meta_est.fit(X, y, aliased_sample_weight=my_weights):
|
||
|
# ...
|
||
|
#
|
||
|
# # the first sub-estimator (`meta_est`) expects `aliased_sample_weight`
|
||
|
# self.estimator_.fit(X, y, aliased_sample_weight=aliased_sample_weight):
|
||
|
# ...
|
||
|
#
|
||
|
# # the second sub-estimator (`est`) expects `sample_weight`
|
||
|
# self.estimator_.fit(X, y, sample_weight=aliased_sample_weight):
|
||
|
# ...
|
||
|
|
||
|
# %%
|
||
|
# Consuming and routing Meta-Estimator
|
||
|
# ------------------------------------
|
||
|
# For a slightly more complex example, consider a meta-estimator that routes
|
||
|
# metadata to an underlying estimator as before, but it also uses some metadata
|
||
|
# in its own methods. This meta-estimator is a consumer and a router at the
|
||
|
# same time. Implementing one is very similar to what we had before, but with a
|
||
|
# few tweaks.
|
||
|
|
||
|
|
||
|
class RouterConsumerClassifier(MetaEstimatorMixin, ClassifierMixin, BaseEstimator):
    """Meta-estimator that both consumes and routes metadata.

    ``sample_weight`` is used by this object's own ``fit`` (reported through
    ``check_metadata``) and is also forwarded to the sub-estimator when the
    sub-estimator requests it. All other metadata are only routed.

    Parameters
    ----------
    estimator : estimator object
        The classifier that is cloned and fitted in ``fit`` and whose
        ``predict`` is used in ``predict``.
    """

    def __init__(self, estimator):
        self.estimator = estimator

    def get_metadata_routing(self):
        router = (
            MetadataRouter(owner=self.__class__.__name__)
            # defining metadata routing request values for usage in the meta-estimator
            .add_self_request(self)
            # defining metadata routing request values for usage in the sub-estimator
            .add(
                estimator=self.estimator,
                method_mapping=MethodMapping()
                .add(caller="fit", callee="fit")
                .add(caller="predict", callee="predict")
                .add(caller="score", callee="score"),
            )
        )
        return router

    def fit(self, X, y, sample_weight=None, **fit_params):
        """Fit a clone of ``estimator``, consuming ``sample_weight`` here too.

        Since ``sample_weight`` is used and consumed here, it is defined as an
        explicit argument in the method's signature. All other metadata,
        which are only routed, are passed as ``**fit_params``.
        ``sample_weight`` defaults to ``None`` so the meta-estimator can be
        fitted without weights, like the other consumers in this example.

        Returns
        -------
        self : RouterConsumerClassifier
            The fitted meta-estimator.

        Raises
        ------
        ValueError
            If ``estimator`` is ``None``.
        """
        if self.estimator is None:
            raise ValueError("estimator cannot be None!")

        check_metadata(self, sample_weight=sample_weight)

        # Put `sample_weight` back among the routed parameters so that the
        # sub-estimator also receives it when it requested it.
        if sample_weight is not None:
            fit_params["sample_weight"] = sample_weight

        request_router = get_routing_for_object(self)
        request_router.validate_metadata(params=fit_params, method="fit")
        routed_params = request_router.route_params(params=fit_params, caller="fit")
        self.estimator_ = clone(self.estimator).fit(X, y, **routed_params.estimator.fit)
        self.classes_ = self.estimator_.classes_
        return self

    def predict(self, X, **predict_params):
        check_is_fitted(self)
        # As in `fit`, we get a copy of the object's MetadataRouter,
        request_router = get_routing_for_object(self)
        # we validate the given metadata,
        request_router.validate_metadata(params=predict_params, method="predict")
        # and then prepare the input to the underlying ``predict`` method.
        routed_params = request_router.route_params(
            params=predict_params, caller="predict"
        )
        return self.estimator_.predict(X, **routed_params.estimator.predict)
|
||
|
|
||
|
|
||
|
# %%
|
||
|
# The key difference between the above meta-estimator and our previous
|
||
|
# meta-estimator is accepting ``sample_weight`` explicitly in ``fit`` and
|
||
|
# including it in ``fit_params``. Since ``sample_weight`` is an explicit
|
||
|
# argument, we can be sure that ``set_fit_request(sample_weight=...)`` is
|
||
|
# present for this method. The meta-estimator is both a consumer, as well as a
|
||
|
# router of ``sample_weight``.
|
||
|
#
|
||
|
# In ``get_metadata_routing``, we add ``self`` to the routing using
|
||
|
# ``add_self_request`` to indicate this estimator is consuming
|
||
|
# ``sample_weight`` as well as being a router; which also adds a
|
||
|
# ``$self_request`` key to the routing info as illustrated below. Now let's
|
||
|
# look at some examples:
|
||
|
|
||
|
# %%
|
||
|
# - No metadata requested
|
||
|
meta_est = RouterConsumerClassifier(estimator=ExampleClassifier())  # no request set anywhere
print_routing(meta_est)
|
||
|
|
||
|
|
||
|
# %%
|
||
|
# - ``sample_weight`` requested by sub-estimator
|
||
|
meta_est = RouterConsumerClassifier(
    estimator=ExampleClassifier().set_fit_request(sample_weight=True)  # request on the child only
)
print_routing(meta_est)
|
||
|
|
||
|
# %%
|
||
|
# - ``sample_weight`` requested by meta-estimator
|
||
|
meta_est = RouterConsumerClassifier(estimator=ExampleClassifier()).set_fit_request(
    sample_weight=True  # request on the meta-estimator itself
)
print_routing(meta_est)
|
||
|
|
||
|
# %%
|
||
|
# Note the difference in the requested metadata representations above.
|
||
|
#
|
||
|
# - We can also alias the metadata to pass different values to the fit methods
|
||
|
# of the meta- and the sub-estimator:
|
||
|
|
||
|
meta_est = RouterConsumerClassifier(
    # the child reads `clf_sample_weight` as its `sample_weight`
    estimator=ExampleClassifier().set_fit_request(sample_weight="clf_sample_weight"),
).set_fit_request(sample_weight="meta_clf_sample_weight")
# the meta-estimator's own request uses the alias `meta_clf_sample_weight`
print_routing(meta_est)
|
||
|
|
||
|
# %%
|
||
|
# However, ``fit`` of the meta-estimator only needs the alias for the
|
||
|
# sub-estimator and addresses its own sample weight as `sample_weight`, since
|
||
|
# it doesn't validate and route its own required metadata:
|
||
|
meta_est.fit(X, y, sample_weight=my_weights, clf_sample_weight=my_other_weights)
# `sample_weight` binds to the meta-estimator's explicit `fit` argument, while
# `clf_sample_weight` is routed to the child.
|
||
|
|
||
|
# %%
|
||
|
# - Alias only on the sub-estimator:
|
||
|
#
|
||
|
# This is useful when we don't want the meta-estimator to use the metadata, but
|
||
|
# the sub-estimator should.
|
||
|
meta_est = RouterConsumerClassifier(
    estimator=ExampleClassifier().set_fit_request(sample_weight="aliased_sample_weight")
)
print_routing(meta_est)  # only the child carries an (aliased) request
|
||
|
# %%
|
||
|
# The meta-estimator cannot use `aliased_sample_weight`, because it expects
|
||
|
# it passed as `sample_weight`. This would apply even if
|
||
|
# `set_fit_request(sample_weight=True)` was set on it.
|
||
|
|
||
|
# %%
|
||
|
# Simple Pipeline
|
||
|
# ---------------
|
||
|
# A slightly more complicated use-case is a meta-estimator resembling a
|
||
|
# :class:`~pipeline.Pipeline`. Here is a meta-estimator, which accepts a
|
||
|
# transformer and a classifier. When calling its `fit` method, it applies the
|
||
|
# transformer's `fit` and `transform` before running the classifier on the
|
||
|
# transformed data. Upon `predict`, it applies the transformer's `transform`
|
||
|
# before predicting with the classifier's `predict` method on the transformed
|
||
|
# new data.
|
||
|
|
||
|
|
||
|
class SimplePipeline(ClassifierMixin, BaseEstimator):
    """Two-step pipeline: a transformer followed by a classifier.

    Parameters
    ----------
    transformer : transformer object
        Cloned and fitted in ``fit``; its ``transform`` is applied to the data
        in both ``fit`` and ``predict``.
    classifier : estimator object
        Cloned and fitted on the transformed data in ``fit``; its ``predict``
        is used in ``predict``.
    """

    def __init__(self, transformer, classifier):
        self.transformer = transformer
        self.classifier = classifier

    def get_metadata_routing(self):
        router = (
            MetadataRouter(owner=self.__class__.__name__)
            # We add the routing for the transformer.
            .add(
                transformer=self.transformer,
                method_mapping=MethodMapping()
                # The metadata is routed such that it retraces how
                # `SimplePipeline` internally calls the transformer's `fit` and
                # `transform` methods in its own methods (`fit` and `predict`).
                .add(caller="fit", callee="fit")
                .add(caller="fit", callee="transform")
                .add(caller="predict", callee="transform"),
            )
            # We add the routing for the classifier.
            .add(
                classifier=self.classifier,
                method_mapping=MethodMapping()
                .add(caller="fit", callee="fit")
                .add(caller="predict", callee="predict"),
            )
        )
        return router

    def fit(self, X, y, **fit_params):
        # `process_routing` validates `fit_params` and splits them per child
        # and per child method.
        routed_params = process_routing(self, "fit", **fit_params)

        self.transformer_ = clone(self.transformer).fit(
            X, y, **routed_params.transformer.fit
        )
        X_transformed = self.transformer_.transform(
            X, **routed_params.transformer.transform
        )

        self.classifier_ = clone(self.classifier).fit(
            X_transformed, y, **routed_params.classifier.fit
        )
        return self

    def predict(self, X, **predict_params):
        # Fail early if `fit` has not been called, as `MetaClassifier` and
        # `RouterConsumerClassifier` do, instead of failing later on the
        # missing `transformer_` attribute.
        check_is_fitted(self)
        routed_params = process_routing(self, "predict", **predict_params)

        X_transformed = self.transformer_.transform(
            X, **routed_params.transformer.transform
        )
        return self.classifier_.predict(
            X_transformed, **routed_params.classifier.predict
        )
|
||
|
|
||
|
|
||
|
# %%
|
||
|
# Note the usage of :class:`~utils.metadata_routing.MethodMapping` to
|
||
|
# declare which methods of the child estimator (callee) are used in which
|
||
|
# methods of the meta-estimator (caller). As you can see, `SimplePipeline` uses
|
||
|
# the transformer's ``transform`` and ``fit`` methods in ``fit``, and its
|
||
|
# ``transform`` method in ``predict``, and that's what you see implemented in
|
||
|
# the routing structure of the pipeline class.
|
||
|
#
|
||
|
# Another difference in the above example with the previous ones is the usage
|
||
|
# of :func:`~utils.metadata_routing.process_routing`, which processes the input
|
||
|
# parameters, does the required validation, and returns the `routed_params`
|
||
|
# which we had created in previous examples. This reduces the boilerplate code
|
||
|
# a developer needs to write in each meta-estimator's method. Developers are
|
||
|
# strongly recommended to use this function unless there is a good reason
|
||
|
# against it.
|
||
|
#
|
||
|
# In order to test the above pipeline, let's add an example transformer.
|
||
|
|
||
|
|
||
|
class ExampleTransformer(TransformerMixin, BaseEstimator):
    """Identity transformer consuming ``sample_weight`` in ``fit`` and
    ``groups`` in ``transform``.

    The input data is returned unchanged; the methods only report the
    metadata they receive through ``check_metadata``.
    """

    def fit(self, X, y, sample_weight=None):
        check_metadata(self, sample_weight=sample_weight)
        return self

    def transform(self, X, groups=None):
        check_metadata(self, groups=groups)
        return X

    def fit_transform(self, X, y, sample_weight=None, groups=None):
        # Hand each metadata to the step that consumes it.
        fitted = self.fit(X, y, sample_weight=sample_weight)
        return fitted.transform(X, groups=groups)
|
||
|
|
||
|
|
||
|
# %%
|
||
|
# Note that in the above example, we have implemented ``fit_transform`` which
|
||
|
# calls ``fit`` and ``transform`` with the appropriate metadata. This is only
|
||
|
# required if ``transform`` accepts metadata, since the default ``fit_transform``
|
||
|
# implementation in :class:`~base.TransformerMixin` doesn't pass metadata to
|
||
|
# ``transform``.
|
||
|
#
|
||
|
# Now we can test our pipeline, and see if metadata is correctly passed around.
|
||
|
# This example uses our `SimplePipeline`, our `ExampleTransformer`, and our
|
||
|
# `RouterConsumerClassifier` which uses our `ExampleClassifier`.
|
||
|
|
||
|
pipe = SimplePipeline(
    transformer=ExampleTransformer()
    # we set transformer's fit to receive sample_weight
    .set_fit_request(sample_weight=True)
    # we set transformer's transform to receive groups
    .set_transform_request(groups=True),
    classifier=RouterConsumerClassifier(
        estimator=ExampleClassifier()
        # we want this sub-estimator to receive sample_weight in fit
        .set_fit_request(sample_weight=True)
        # but not groups in predict
        .set_predict_request(groups=False),
    )
    # and we want the meta-estimator to receive sample_weight as well
    .set_fit_request(sample_weight=True),
)
# `groups` reaches the transformer's `transform` in both `fit` and `predict`;
# the classifier's sub-estimator explicitly declines it in `predict`.
pipe.fit(X, y, sample_weight=my_weights, groups=my_groups).predict(
    X[:3], groups=my_groups
)
|
||
|
|
||
|
# %%
|
||
|
# Deprecation / Default Value Change
|
||
|
# ----------------------------------
|
||
|
# In this section we show how one should handle the case where a router becomes
|
||
|
# also a consumer, especially when it consumes the same metadata as its
|
||
|
# sub-estimator, or a consumer starts consuming a metadata which it did not consume in
|
||
|
# an older release. In this case, a warning should be raised for a while, to
|
||
|
# let users know the behavior is changed from previous versions.
|
||
|
|
||
|
|
||
|
class MetaRegressor(MetaEstimatorMixin, RegressorMixin, BaseEstimator):
    """Meta-regressor that only routes ``fit`` metadata to its sub-estimator.

    Parameters
    ----------
    estimator : estimator object
        The regressor that is cloned and fitted in ``fit``.
    """

    def __init__(self, estimator):
        self.estimator = estimator

    def fit(self, X, y, **fit_params):
        """Fit a clone of ``estimator`` with the routed metadata.

        Returns
        -------
        self : MetaRegressor
            The fitted meta-estimator. Returning ``self`` follows the
            scikit-learn ``fit`` convention and allows chaining.
        """
        # `process_routing` validates `fit_params` and maps them onto the
        # metadata requested by the sub-estimator.
        routed_params = process_routing(self, "fit", **fit_params)
        self.estimator_ = clone(self.estimator).fit(X, y, **routed_params.estimator.fit)
        return self

    def get_metadata_routing(self):
        # This object consumes nothing itself; its `fit` is only mapped to
        # the sub-estimator's `fit`.
        router = MetadataRouter(owner=self.__class__.__name__).add(
            estimator=self.estimator,
            method_mapping=MethodMapping().add(caller="fit", callee="fit"),
        )
        return router
|
||
|
|
||
|
|
||
|
# %%
|
||
|
# As explained above, this is a valid usage if `my_weights` aren't supposed
|
||
|
# to be passed as `sample_weight` to `MetaRegressor`:
|
||
|
|
||
|
reg = MetaRegressor(estimator=LinearRegression().set_fit_request(sample_weight=True))
# `MetaRegressor` does not consume `sample_weight` itself; it only routes it
# to `LinearRegression`.
reg.fit(X, y, sample_weight=my_weights)
|
||
|
|
||
|
|
||
|
# %%
|
||
|
# Now imagine we further develop ``MetaRegressor`` and it now also *consumes*
|
||
|
# ``sample_weight``:
|
||
|
|
||
|
|
||
|
class WeightedMetaRegressor(MetaEstimatorMixin, RegressorMixin, BaseEstimator):
    """Meta-regressor that consumes ``sample_weight`` and also routes it.

    Parameters
    ----------
    estimator : estimator object
        The regressor that is cloned and fitted in ``fit``.
    """

    # show warning to remind user to explicitly set the value with
    # `.set_{method}_request(sample_weight={boolean})`
    __metadata_request__fit = {"sample_weight": metadata_routing.WARN}

    def __init__(self, estimator):
        self.estimator = estimator

    def fit(self, X, y, sample_weight=None, **fit_params):
        """Fit a clone of ``estimator``, consuming ``sample_weight`` here too.

        Returns
        -------
        self : WeightedMetaRegressor
            The fitted meta-estimator. Returning ``self`` follows the
            scikit-learn ``fit`` convention and allows chaining.
        """
        # `sample_weight` is an explicit argument, so it has to be handed back
        # to `process_routing` for validation and routing to the child.
        routed_params = process_routing(
            self, "fit", sample_weight=sample_weight, **fit_params
        )
        check_metadata(self, sample_weight=sample_weight)
        self.estimator_ = clone(self.estimator).fit(X, y, **routed_params.estimator.fit)
        return self

    def get_metadata_routing(self):
        router = (
            MetadataRouter(owner=self.__class__.__name__)
            # this object's own request (built from `__metadata_request__fit`
            # and the `fit` signature)
            .add_self_request(self)
            .add(
                estimator=self.estimator,
                method_mapping=MethodMapping().add(caller="fit", callee="fit"),
            )
        )
        return router
|
||
|
|
||
|
|
||
|
# %%
|
||
|
# The above implementation is almost the same as ``MetaRegressor``, and
|
||
|
# because of the default request value defined in ``__metadata_request__fit``
|
||
|
# there is a warning raised when fitted.
|
||
|
|
||
|
with warnings.catch_warnings(record=True) as record:
    # The meta-estimator's own `sample_weight` request is left at its WARN
    # default, so passing `sample_weight` is expected to warn.
    WeightedMetaRegressor(
        estimator=LinearRegression().set_fit_request(sample_weight=False)
    ).fit(X, y, sample_weight=my_weights)
for w in record:
    print(w.message)
|
||
|
|
||
|
|
||
|
# %%
|
||
|
# When an estimator consumes a metadata which it didn't consume before, the
|
||
|
# following pattern can be used to warn the users about it.
|
||
|
|
||
|
|
||
|
class ExampleRegressor(RegressorMixin, BaseEstimator):
    """Toy regressor that has started consuming ``sample_weight`` in ``fit``.

    Its default request for ``sample_weight`` is ``WARN``, so routers are
    expected to warn rather than fail when that metadata is passed without
    an explicit request.
    """

    # Default request values for `fit`; WARN flags the newly consumed metadata.
    __metadata_request__fit = {"sample_weight": metadata_routing.WARN}

    def fit(self, X, y, sample_weight=None):
        check_metadata(self, sample_weight=sample_weight)
        return self

    def predict(self, X):
        # Always predict 0 for every row.
        n_rows = len(X)
        return np.zeros(n_rows)
|
||
|
|
||
|
|
||
|
with warnings.catch_warnings(record=True) as record:
    # `ExampleRegressor` leaves `sample_weight` at its WARN default request,
    # so routing it through `MetaRegressor` is expected to warn.
    MetaRegressor(estimator=ExampleRegressor()).fit(X, y, sample_weight=my_weights)
for w in record:
    print(w.message)
|
||
|
|
||
|
# %%
|
||
|
# At the end we disable the configuration flag for metadata routing:
|
||
|
|
||
|
set_config(enable_metadata_routing=False)  # switch routing back off after the demo
|
||
|
|
||
|
# %%
|
||
|
# Third Party Development and scikit-learn Dependency
|
||
|
# ---------------------------------------------------
|
||
|
#
|
||
|
# As seen above, information is communicated between classes using
|
||
|
# :class:`~utils.metadata_routing.MetadataRequest` and
|
||
|
# :class:`~utils.metadata_routing.MetadataRouter`. It is strongly discouraged,
|
||
|
# but possible to vendor the tools related to metadata-routing if you strictly
|
||
|
# want to have a scikit-learn compatible estimator, without depending on the
|
||
|
# scikit-learn package. If all of the following conditions are met, you do NOT
|
||
|
# need to modify your code at all:
|
||
|
#
|
||
|
# - your estimator inherits from :class:`~base.BaseEstimator`
|
||
|
# - the parameters consumed by your estimator's methods, e.g. ``fit``, are
|
||
|
# explicitly defined in the method's signature, as opposed to being
|
||
|
# ``*args`` or ``**kwargs``.
|
||
|
# - your estimator does not route any metadata to the underlying objects, i.e.
|
||
|
# it's not a *router*.
|