Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Appearance settings

Commit bc3a19d

Browse files
MAINT Parameters validation for sklearn.datasets.load_files (#26203)
Co-authored-by: Jérémie du Boisberranger <34657725+jeremiedbb@users.noreply.github.com>
1 parent 18a4576 commit bc3a19d
Copy full SHA for bc3a19d

File tree

3 files changed

+18
-3
lines changed
Filter options

3 files changed

+18
-3
lines changed

‎sklearn/datasets/_base.py

Copy file name to clipboard | Expand all lines: sklearn/datasets/_base.py
+14 -1 | Lines changed: 14 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -22,7 +22,7 @@
2222
from ..utils import check_random_state
2323
from ..utils import check_pandas_support
2424
from ..utils.fixes import _open_binary, _open_text, _read_text, _contents
25-
from ..utils._param_validation import validate_params, Interval
25+
from ..utils._param_validation import validate_params, Interval, StrOptions
2626

2727
import numpy as np
2828

@@ -104,6 +104,19 @@ def _convert_data_dataframe(
104104
return combined_df, X, y
105105

106106

107+
@validate_params(
108+
{
109+
"container_path": [str, os.PathLike],
110+
"description": [str, None],
111+
"categories": [list, None],
112+
"load_content": ["boolean"],
113+
"shuffle": ["boolean"],
114+
"encoding": [str, None],
115+
"decode_error": [StrOptions({"strict", "ignore", "replace"})],
116+
"random_state": ["random_state"],
117+
"allowed_extensions": [list, None],
118+
}
119+
)
107120
def load_files(
108121
container_path,
109122
*,

‎sklearn/datasets/tests/test_base.py

Copy file name to clipboard | Expand all lines: sklearn/datasets/tests/test_base.py
+3 -2 | Lines changed: 3 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -98,10 +98,11 @@ def test_default_load_files(test_category_dir_1, test_category_dir_2, load_files
9898
def test_load_files_w_categories_desc_and_encoding(
9999
test_category_dir_1, test_category_dir_2, load_files_root
100100
):
101-
category = os.path.abspath(test_category_dir_1).split("/").pop()
101+
category = os.path.abspath(test_category_dir_1).split(os.sep).pop()
102102
res = load_files(
103-
load_files_root, description="test", categories=category, encoding="utf-8"
103+
load_files_root, description="test", categories=[category], encoding="utf-8"
104104
)
105+
105106
assert len(res.filenames) == 1
106107
assert len(res.target_names) == 1
107108
assert res.DESCR == "test"

‎sklearn/tests/test_public_functions.py

Copy file name to clipboard | Expand all lines: sklearn/tests/test_public_functions.py
+1 | Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -134,6 +134,7 @@ def _check_function_param_validation(
134134
"sklearn.datasets.load_breast_cancer",
135135
"sklearn.datasets.load_diabetes",
136136
"sklearn.datasets.load_digits",
137+
"sklearn.datasets.load_files",
137138
"sklearn.datasets.load_iris",
138139
"sklearn.datasets.load_linnerud",
139140
"sklearn.datasets.load_svmlight_file",

0 commit comments

Comments
0 (0)
Morty Proxy: This is a proxified and sanitized view of the page; visit the original site.