From 0a15321f98a10cbd792a002bbb31b59b4ea10528 Mon Sep 17 00:00:00 2001 From: Shivachauhan17 Date: Tue, 28 Feb 2023 11:41:53 +0530 Subject: [PATCH 1/3] add parameter validation to dump_svmlight_file --- sklearn/datasets/_svmlight_format_io.py | 11 ++++++++++- sklearn/tests/test_public_functions.py | 1 + 2 files changed, 11 insertions(+), 1 deletion(-) diff --git a/sklearn/datasets/_svmlight_format_io.py b/sklearn/datasets/_svmlight_format_io.py index 2a141e1732ff7..efdbc416922bd 100644 --- a/sklearn/datasets/_svmlight_format_io.py +++ b/sklearn/datasets/_svmlight_format_io.py @@ -25,6 +25,7 @@ from .. import __version__ from ..utils import check_array, IS_PYPY +from ..utils._param_validation import validate_params,StrOptions if not IS_PYPY: from ._svmlight_format_fast import ( @@ -403,7 +404,15 @@ def _dump_svmlight(X, y, f, multilabel, one_based, comment, query_id): y_is_sp, ) - +@validate_params({ + "X":["array-like","sparse matrix"], + "y":["array-like","sparse matrix"], + "f":[str,StrOptions({"file"})], + "zero_based":[bool,True], + "comment":[str,None], + "query_id":["array-like",None], + "multilabel":[bool,False], +}) def dump_svmlight_file( X, y, diff --git a/sklearn/tests/test_public_functions.py b/sklearn/tests/test_public_functions.py index dae1fdb2e6164..d2e60f37f2e57 100644 --- a/sklearn/tests/test_public_functions.py +++ b/sklearn/tests/test_public_functions.py @@ -107,6 +107,7 @@ def _check_function_param_validation( "sklearn.datasets.make_classification", "sklearn.datasets.make_friedman1", "sklearn.datasets.make_sparse_coded_signal", + "sklearn.datasets.dump_svmlight_file", "sklearn.decomposition.sparse_encode", "sklearn.feature_extraction.grid_to_graph", "sklearn.feature_extraction.img_to_graph", From 2873f244fab26004e0271882ed9e88409f11a77e Mon Sep 17 00:00:00 2001 From: Shivachauhan17 Date: Wed, 1 Mar 2023 16:27:18 +0530 Subject: [PATCH 2/3] add parameter validation to fetch_lfw_people --- sklearn/datasets/_lfw.py | 15 +++++++++++++-- sklearn/tests/test_public_functions.py | 1 + 2 files changed, 14 insertions(+), 2 deletions(-) diff --git a/sklearn/datasets/_lfw.py b/sklearn/datasets/_lfw.py index 7252c050bef3c..d7dc9017ee32a 100644 --- a/sklearn/datasets/_lfw.py +++ b/sklearn/datasets/_lfw.py @@ -10,7 +10,7 @@ from os import listdir, makedirs, remove from os.path import join, exists, isdir - +from ..utils._param_validation import validate_params import logging import numpy as np @@ -230,7 +230,18 @@ def _fetch_lfw_people( faces, target = faces[indices], target[indices] return faces, target, target_names - +@validate_params( + { + "data_home":[str,None], + "funneled":["boolean"], + "resize":[float,None], + "min_faces_per_person":[int,None], + "color":["boolean"], + "slice":["tuple of slice",(slice(0, 250), slice(0, 250))], + "download_if_missing":["boolean"], + "return_X_y":["boolean"], + } +) def fetch_lfw_people( *, data_home=None, diff --git a/sklearn/tests/test_public_functions.py b/sklearn/tests/test_public_functions.py index d2e60f37f2e57..021fff575dd20 100644 --- a/sklearn/tests/test_public_functions.py +++ b/sklearn/tests/test_public_functions.py @@ -108,6 +108,7 @@ def _check_function_param_validation( "sklearn.datasets.make_friedman1", "sklearn.datasets.make_sparse_coded_signal", "sklearn.datasets.dump_svmlight_file", + "sklearn.datasets.fetch_lfw_people", "sklearn.decomposition.sparse_encode", "sklearn.feature_extraction.grid_to_graph", "sklearn.feature_extraction.img_to_graph", From 2253c440f9dcbe19b4b0309046a1734360cddab3 Mon Sep 17 00:00:00 2001 From: Shiva chauhan <103742975+Shivachauhan17@users.noreply.github.com> Date: Wed, 1 Mar 2023 16:38:34 +0530 Subject: [PATCH 3/3] lint --- sklearn/datasets/_lfw.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sklearn/datasets/_lfw.py b/sklearn/datasets/_lfw.py index d7dc9017ee32a..b336b935bef41 100644 --- a/sklearn/datasets/_lfw.py +++ b/sklearn/datasets/_lfw.py @@ -10,7 +10,7 @@ from os import listdir, makedirs, remove from os.path import join, exists, isdir -from ..utils._param_validation import validate_params +from ..utils._param_validation import validate_params,HasMethods import logging import numpy as np