Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Appearance settings

Commit e3d1f9a

Browse filesBrowse files
author
Théophile Baranger
authored
MAINT Parameters validation for datasets.make_multilabel_classification (#25920)
1 parent 01f8d34 commit e3d1f9a
Copy full SHA for e3d1f9a

File tree

3 files changed

+15
-26
lines changed
Filter options

3 files changed

+15
-26
lines changed

‎sklearn/datasets/_samples_generator.py

Copy file name to clipboardExpand all lines: sklearn/datasets/_samples_generator.py
+14-14Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -309,6 +309,20 @@ def make_classification(
309309
return X, y
310310

311311

312+
@validate_params(
313+
{
314+
"n_samples": [Interval(Integral, 1, None, closed="left")],
315+
"n_features": [Interval(Integral, 1, None, closed="left")],
316+
"n_classes": [Interval(Integral, 1, None, closed="left")],
317+
"n_labels": [Interval(Integral, 0, None, closed="left")],
318+
"length": [Interval(Integral, 1, None, closed="left")],
319+
"allow_unlabeled": ["boolean"],
320+
"sparse": ["boolean"],
321+
"return_indicator": [StrOptions({"dense", "sparse"}), "boolean"],
322+
"return_distributions": ["boolean"],
323+
"random_state": ["random_state"],
324+
}
325+
)
312326
def make_multilabel_classification(
313327
n_samples=100,
314328
n_features=20,
@@ -398,18 +412,6 @@ def make_multilabel_classification(
398412
The probability of each feature being drawn given each class.
399413
Only returned if ``return_distributions=True``.
400414
"""
401-
if n_classes < 1:
402-
raise ValueError(
403-
"'n_classes' should be an integer greater than 0. Got {} instead.".format(
404-
n_classes
405-
)
406-
)
407-
if length < 1:
408-
raise ValueError(
409-
"'length' should be an integer greater than 0. Got {} instead.".format(
410-
length
411-
)
412-
)
413415

414416
generator = check_random_state(random_state)
415417
p_c = generator.uniform(size=n_classes)
@@ -469,8 +471,6 @@ def sample_example():
469471
if return_indicator in (True, "sparse", "dense"):
470472
lb = MultiLabelBinarizer(sparse_output=(return_indicator == "sparse"))
471473
Y = lb.fit([range(n_classes)]).transform(Y)
472-
elif return_indicator is not False:
473-
raise ValueError("return_indicator must be either 'sparse', 'dense' or False.")
474474
if return_distributions:
475475
return X, Y, p_c, p_w_c
476476
return X, Y

‎sklearn/datasets/tests/test_samples_generator.py

Copy file name to clipboardExpand all lines: sklearn/datasets/tests/test_samples_generator.py
-12Lines changed: 0 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -283,18 +283,6 @@ def test_make_multilabel_classification_return_indicator_sparse():
283283
assert sp.issparse(Y)
284284

285285

286-
@pytest.mark.parametrize(
287-
"params, err_msg",
288-
[
289-
({"n_classes": 0}, "'n_classes' should be an integer"),
290-
({"length": 0}, "'length' should be an integer"),
291-
],
292-
)
293-
def test_make_multilabel_classification_valid_arguments(params, err_msg):
294-
with pytest.raises(ValueError, match=err_msg):
295-
make_multilabel_classification(**params)
296-
297-
298286
def test_make_hastie_10_2():
299287
X, y = make_hastie_10_2(n_samples=100, random_state=0)
300288
assert X.shape == (100, 10), "X shape mismatch"

‎sklearn/tests/test_public_functions.py

Copy file name to clipboardExpand all lines: sklearn/tests/test_public_functions.py
+1Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -133,6 +133,7 @@ def _check_function_param_validation(
133133
"sklearn.datasets.make_classification",
134134
"sklearn.datasets.make_friedman1",
135135
"sklearn.datasets.make_low_rank_matrix",
136+
"sklearn.datasets.make_multilabel_classification",
136137
"sklearn.datasets.make_regression",
137138
"sklearn.datasets.make_sparse_coded_signal",
138139
"sklearn.decomposition.sparse_encode",

0 commit comments

Comments
0 (0)
Morty Proxy This is a proxified and sanitized view of the page, visit original site.