Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Appearance settings

Commit e5dac1d

Browse filesBrowse files
MAINT Parameter validation for metrics.cluster.adjusted_mutual_info_score (#25898)
Co-authored-by: Jérémie du Boisberranger <34657725+jeremiedbb@users.noreply.github.com>
1 parent 940303c commit e5dac1d
Copy full SHA for e5dac1d

File tree

Expand file treeCollapse file tree

2 files changed

+12
-5
lines changed
Filter options
Expand file treeCollapse file tree

2 files changed

+12
-5
lines changed

‎sklearn/metrics/cluster/_supervised.py

Copy file name to clipboardExpand all lines: sklearn/metrics/cluster/_supervised.py
+11-5Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@
2727
from ...utils.multiclass import type_of_target
2828
from ...utils.validation import check_array, check_consistent_length
2929
from ...utils._param_validation import validate_params
30-
from ...utils._param_validation import Interval
30+
from ...utils._param_validation import Interval, StrOptions
3131

3232

3333
def check_clusterings(labels_true, labels_pred):
@@ -847,6 +847,13 @@ def mutual_info_score(labels_true, labels_pred, *, contingency=None):
847847
return np.clip(mi.sum(), 0.0, None)
848848

849849

850+
@validate_params(
851+
{
852+
"labels_true": ["array-like"],
853+
"labels_pred": ["array-like"],
854+
"average_method": [StrOptions({"arithmetic", "max", "min", "geometric"})],
855+
}
856+
)
850857
def adjusted_mutual_info_score(
851858
labels_true, labels_pred, *, average_method="arithmetic"
852859
):
@@ -876,17 +883,16 @@ def adjusted_mutual_info_score(
876883
877884
Parameters
878885
----------
879-
labels_true : int array, shape = [n_samples]
886+
labels_true : int array-like of shape (n_samples,)
880887
A clustering of the data into disjoint subsets, called :math:`U` in
881888
the above formula.
882889
883890
labels_pred : int array-like of shape (n_samples,)
884891
A clustering of the data into disjoint subsets, called :math:`V` in
885892
the above formula.
886893
887-
average_method : str, default='arithmetic'
888-
How to compute the normalizer in the denominator. Possible options
889-
are 'min', 'geometric', 'arithmetic', and 'max'.
894+
average_method : {'min', 'geometric', 'arithmetic', 'max'}, default='arithmetic'
895+
How to compute the normalizer in the denominator.
890896
891897
.. versionadded:: 0.20
892898

‎sklearn/tests/test_public_functions.py

Copy file name to clipboardExpand all lines: sklearn/tests/test_public_functions.py
+1Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -153,6 +153,7 @@ def _check_function_param_validation(
153153
"sklearn.metrics.brier_score_loss",
154154
"sklearn.metrics.class_likelihood_ratios",
155155
"sklearn.metrics.classification_report",
156+
"sklearn.metrics.cluster.adjusted_mutual_info_score",
156157
"sklearn.metrics.cluster.contingency_matrix",
157158
"sklearn.metrics.cohen_kappa_score",
158159
"sklearn.metrics.confusion_matrix",

0 commit comments

Comments
0 (0)
Morty Proxy This is a proxified and sanitized view of the page, visit original site.