Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Appearance settings

Commit d077f82

Browse filesBrowse files
ENH: Display parameters in HTML representation (scikit-learn#30763)
Co-authored-by: Guillaume Lemaitre <guillaume@probabl.ai>
1 parent 19a6e61 commit d077f82
Copy full SHA for d077f82

File tree

21 files changed

+595
-181
lines changed
Filter options

21 files changed

+595
-181
lines changed
+4Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
- :class:`base.BaseEstimator` now has a parameter table added to the
2+
estimators HTML representation that can be visualized with jupyter.
3+
By :user:`Guillaume Lemaitre <glemaitre>` and
4+
:user:`Dea María Léon <DeaMariaLeon>`

‎sklearn/base.py

Copy file name to clipboardExpand all lines: sklearn/base.py
+65-32Lines changed: 65 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,12 @@
1616
from . import __version__
1717
from ._config import config_context, get_config
1818
from .exceptions import InconsistentVersionWarning
19-
from .utils._estimator_html_repr import _HTMLDocumentationLinkMixin, estimator_html_repr
2019
from .utils._metadata_requests import _MetadataRequester, _routing_enabled
20+
from .utils._missing import is_scalar_nan
2121
from .utils._param_validation import validate_parameter_constraints
22+
from .utils._repr_html.base import ReprHTMLMixin, _HTMLDocumentationLinkMixin
23+
from .utils._repr_html.estimator import estimator_html_repr
24+
from .utils._repr_html.params import ParamsDict
2225
from .utils._set_output import _SetOutputMixin
2326
from .utils._tags import (
2427
ClassifierTags,
@@ -150,7 +153,7 @@ def _clone_parametrized(estimator, *, safe=True):
150153
return new_object
151154

152155

153-
class BaseEstimator(_HTMLDocumentationLinkMixin, _MetadataRequester):
156+
class BaseEstimator(ReprHTMLMixin, _HTMLDocumentationLinkMixin, _MetadataRequester):
154157
"""Base class for all estimators in scikit-learn.
155158
156159
Inheriting from this class provides default implementations of:
@@ -194,6 +197,8 @@ class BaseEstimator(_HTMLDocumentationLinkMixin, _MetadataRequester):
194197
array([3, 3, 3])
195198
"""
196199

200+
_html_repr = estimator_html_repr
201+
197202
@classmethod
198203
def _get_param_names(cls):
199204
"""Get parameter names for the estimator"""
@@ -249,6 +254,64 @@ def get_params(self, deep=True):
249254
out[key] = value
250255
return out
251256

257+
def _get_params_html(self, deep=True):
258+
"""
259+
Get parameters for this estimator with a specific HTML representation.
260+
261+
Parameters
262+
----------
263+
deep : bool, default=True
264+
If True, will return the parameters for this estimator and
265+
contained subobjects that are estimators.
266+
267+
Returns
268+
-------
269+
params : ParamsDict
270+
Parameter names mapped to their values. We return a `ParamsDict`
271+
dictionary, which renders a specific HTML representation in table
272+
form.
273+
"""
274+
out = self.get_params(deep=deep)
275+
276+
init_func = getattr(self.__init__, "deprecated_original", self.__init__)
277+
init_default_params = inspect.signature(init_func).parameters
278+
init_default_params = {
279+
name: param.default for name, param in init_default_params.items()
280+
}
281+
282+
def is_non_default(param_name, param_value):
283+
"""Finds the parameters that have been set by the user."""
284+
if param_name not in init_default_params:
285+
# happens if k is part of a **kwargs
286+
return True
287+
if init_default_params[param_name] == inspect._empty:
288+
# k has no default value
289+
return True
290+
# avoid calling repr on nested estimators
291+
if isinstance(param_value, BaseEstimator) and type(param_value) is not type(
292+
init_default_params[param_name]
293+
):
294+
return True
295+
296+
if param_value != init_default_params[param_name] and not (
297+
is_scalar_nan(init_default_params[param_name])
298+
and is_scalar_nan(param_value)
299+
):
300+
return True
301+
return False
302+
303+
# reorder the parameters from `self.get_params` using the `__init__`
304+
# signature
305+
remaining_params = [name for name in out if name not in init_default_params]
306+
ordered_out = {name: out[name] for name in init_default_params if name in out}
307+
ordered_out.update({name: out[name] for name in remaining_params})
308+
309+
non_default_ls = tuple(
310+
[name for name, value in ordered_out.items() if is_non_default(name, value)]
311+
)
312+
313+
return ParamsDict(ordered_out, non_default=non_default_ls)
314+
252315
def set_params(self, **params):
253316
"""Set the parameters of this estimator.
254317
@@ -409,36 +472,6 @@ class attribute, which is a dictionary `param_name: list of constraints`. See
409472
caller_name=self.__class__.__name__,
410473
)
411474

412-
@property
413-
def _repr_html_(self):
414-
"""HTML representation of estimator.
415-
416-
This is redundant with the logic of `_repr_mimebundle_`. The latter
417-
should be favored in the long term, `_repr_html_` is only
418-
implemented for consumers who do not interpret `_repr_mimbundle_`.
419-
"""
420-
if get_config()["display"] != "diagram":
421-
raise AttributeError(
422-
"_repr_html_ is only defined when the "
423-
"'display' configuration option is set to "
424-
"'diagram'"
425-
)
426-
return self._repr_html_inner
427-
428-
def _repr_html_inner(self):
429-
"""This function is returned by the @property `_repr_html_` to make
430-
`hasattr(estimator, "_repr_html_") return `True` or `False` depending
431-
on `get_config()["display"]`.
432-
"""
433-
return estimator_html_repr(self)
434-
435-
def _repr_mimebundle_(self, **kwargs):
436-
"""Mime bundle used by jupyter kernels to display estimator"""
437-
output = {"text/plain": repr(self)}
438-
if get_config()["display"] == "diagram":
439-
output["text/html"] = estimator_html_repr(self)
440-
return output
441-
442475

443476
class ClassifierMixin:
444477
"""Mixin class for all classifiers in scikit-learn.

‎sklearn/compose/_column_transformer.py

Copy file name to clipboardExpand all lines: sklearn/compose/_column_transformer.py
+1-1Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,10 +20,10 @@
2020
from ..pipeline import _fit_transform_one, _name_estimators, _transform_one
2121
from ..preprocessing import FunctionTransformer
2222
from ..utils import Bunch
23-
from ..utils._estimator_html_repr import _VisualBlock
2423
from ..utils._indexing import _determine_key_type, _get_column_indices, _safe_indexing
2524
from ..utils._metadata_requests import METHODS
2625
from ..utils._param_validation import HasMethods, Hidden, Interval, StrOptions
26+
from ..utils._repr_html.estimator import _VisualBlock
2727
from ..utils._set_output import (
2828
_get_container_adapter,
2929
_get_output_config,

‎sklearn/ensemble/_stacking.py

Copy file name to clipboardExpand all lines: sklearn/ensemble/_stacking.py
+1-1Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,8 +24,8 @@
2424
from ..model_selection import check_cv, cross_val_predict
2525
from ..preprocessing import LabelEncoder
2626
from ..utils import Bunch
27-
from ..utils._estimator_html_repr import _VisualBlock
2827
from ..utils._param_validation import HasMethods, StrOptions
28+
from ..utils._repr_html.estimator import _VisualBlock
2929
from ..utils.metadata_routing import (
3030
MetadataRouter,
3131
MethodMapping,

‎sklearn/ensemble/_voting.py

Copy file name to clipboardExpand all lines: sklearn/ensemble/_voting.py
+1-1Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,8 +24,8 @@
2424
from ..exceptions import NotFittedError
2525
from ..preprocessing import LabelEncoder
2626
from ..utils import Bunch
27-
from ..utils._estimator_html_repr import _VisualBlock
2827
from ..utils._param_validation import StrOptions
28+
from ..utils._repr_html.estimator import _VisualBlock
2929
from ..utils.metadata_routing import (
3030
MetadataRouter,
3131
MethodMapping,

‎sklearn/model_selection/_search.py

Copy file name to clipboardExpand all lines: sklearn/model_selection/_search.py
+1-1Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,8 +31,8 @@
3131
get_scorer_names,
3232
)
3333
from ..utils import Bunch, check_random_state
34-
from ..utils._estimator_html_repr import _VisualBlock
3534
from ..utils._param_validation import HasMethods, Interval, StrOptions
35+
from ..utils._repr_html.estimator import _VisualBlock
3636
from ..utils._tags import get_tags
3737
from ..utils.metadata_routing import (
3838
MetadataRouter,

‎sklearn/model_selection/tests/test_search.py

Copy file name to clipboardExpand all lines: sklearn/model_selection/tests/test_search.py
+4-4Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2662,21 +2662,21 @@ def test_search_html_repr():
26622662
search_cv = GridSearchCV(pipeline, param_grid=param_grid, refit=False)
26632663
with config_context(display="diagram"):
26642664
repr_html = search_cv._repr_html_()
2665-
assert "<pre>DummyClassifier()</pre>" in repr_html
2665+
assert "<div>DummyClassifier</div>" in repr_html
26662666

26672667
# Fitted with `refit=False` shows the original pipeline
26682668
search_cv.fit(X, y)
26692669
with config_context(display="diagram"):
26702670
repr_html = search_cv._repr_html_()
2671-
assert "<pre>DummyClassifier()</pre>" in repr_html
2671+
assert "<div>DummyClassifier</div>" in repr_html
26722672

26732673
# Fitted with `refit=True` shows the best estimator
26742674
search_cv = GridSearchCV(pipeline, param_grid=param_grid, refit=True)
26752675
search_cv.fit(X, y)
26762676
with config_context(display="diagram"):
26772677
repr_html = search_cv._repr_html_()
2678-
assert "<pre>DummyClassifier()</pre>" not in repr_html
2679-
assert "<pre>LogisticRegression()</pre>" in repr_html
2678+
assert "<div>DummyClassifier</div>" not in repr_html
2679+
assert "<div>LogisticRegression</div>" in repr_html
26802680

26812681

26822682
# Metadata Routing Tests

‎sklearn/pipeline.py

Copy file name to clipboardExpand all lines: sklearn/pipeline.py
+1-1Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,9 @@
1616
from .exceptions import NotFittedError
1717
from .preprocessing import FunctionTransformer
1818
from .utils import Bunch
19-
from .utils._estimator_html_repr import _VisualBlock
2019
from .utils._metadata_requests import METHODS
2120
from .utils._param_validation import HasMethods, Hidden
21+
from .utils._repr_html.estimator import _VisualBlock
2222
from .utils._set_output import (
2323
_get_container_adapter,
2424
_safe_set_output,

‎sklearn/preprocessing/_function_transformer.py

Copy file name to clipboardExpand all lines: sklearn/preprocessing/_function_transformer.py
+1-1Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,8 @@
77
import numpy as np
88

99
from ..base import BaseEstimator, TransformerMixin, _fit_context
10-
from ..utils._estimator_html_repr import _VisualBlock
1110
from ..utils._param_validation import StrOptions
11+
from ..utils._repr_html.estimator import _VisualBlock
1212
from ..utils._set_output import (
1313
_get_adapter_from_container,
1414
_get_output_config,

‎sklearn/tests/test_base.py

Copy file name to clipboardExpand all lines: sklearn/tests/test_base.py
+8Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -992,3 +992,11 @@ def predict(self, X, prop=None):
992992
with warnings.catch_warnings(record=True) as record:
993993
CustomOutlierDetector().set_predict_request(prop=True).fit_predict([[1]], [1])
994994
assert len(record) == 0
995+
996+
997+
def test_get_params_html():
998+
"""Check the behaviour of the `_get_params_html` method."""
999+
est = MyEstimator(empty="test")
1000+
1001+
assert est._get_params_html() == {"l1": 0, "empty": "test"}
1002+
assert est._get_params_html().non_default == ("empty",)

‎sklearn/utils/__init__.py

Copy file name to clipboardExpand all lines: sklearn/utils/__init__.py
+2-1Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77
from . import metadata_routing
88
from ._bunch import Bunch
99
from ._chunking import gen_batches, gen_even_slices
10-
from ._estimator_html_repr import estimator_html_repr
1110

1211
# Make _safe_indexing importable from here for backward compat as this particular
1312
# helper is considered semi-private and typically very useful for third-party
@@ -20,6 +19,8 @@
2019
shuffle,
2120
)
2221
from ._mask import safe_mask
22+
from ._repr_html.base import _HTMLDocumentationLinkMixin # noqa: F401
23+
from ._repr_html.estimator import estimator_html_repr
2324
from ._tags import (
2425
ClassifierTags,
2526
InputTags,

‎sklearn/utils/_repr_html/__init__.py

Copy file name to clipboard
+2Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
# Authors: The scikit-learn developers
2+
# SPDX-License-Identifier: BSD-3-Clause

0 commit comments

Comments
0 (0)
Morty Proxy This is a proxified and sanitized view of the page, visit original site.