Make automatic validation for all scikit-learn public functions #24862

Closed
@glemaitre

Description


PR #22722 introduced a decorator to validate the parameters of functions. We now need to use it for all functions where it is applicable.

Please open one PR per function. The title of the PR must mention which function it's dealing with. We recommend using the following pattern for titles:

MAINT Parameters validation for <function>

where <function> is a placeholder to be replaced with the function you chose.

The description of the PR must begin with Towards #24862 so that this issue and the PR are mutually crossed-linked.

Steps

  1. Choose a public function that is documented in https://scikit-learn.org/dev/modules/classes.html. Check in the source code whether the function contains some manual parameter validation (i.e. you should see some if conditions that raise errors). If there is no validation in the function, you can report it in this issue, where we will decide whether or not to skip the function.

  2. To validate the function, decorate it with sklearn.utils._param_validation.validate_params. Do not rely only on the docstring of the function to define the constraints: although it can help, it is important to rely primarily on the implementation to find the valid values, because the docstring might not be completely accurate. The decorator takes a Python dictionary as input, where each key is a parameter name and each value is the list of associated constraints. An example for kmeans_plusplus is given below:

    from numbers import Integral

    from sklearn.utils._param_validation import Interval, validate_params


    @validate_params(
        {
            "X": ["array-like", "sparse matrix"],
            "n_clusters": [Interval(Integral, 1, None, closed="left")],
            "x_squared_norms": ["array-like", None],
            "random_state": ["random_state"],
            "n_local_trials": [Interval(Integral, 1, None, closed="left"), None],
        }
    )
    def kmeans_plusplus(
        X, n_clusters, *, x_squared_norms=None, random_state=None, n_local_trials=None
    ):
        ...  # function body unchanged

    You can also get more details regarding the constraints by looking at the validation already implemented for the different estimators (cf. the _parameter_constraints attribute).

  3. All existing simple param validation can now be removed. ("Simple" means validation that depends neither on the input data nor on the value of another parameter, for instance.)

  4. Tests that check error messages from simple param validation can also be removed. Be careful: we need to keep the tests that check more complex param validation!

  5. Finally, add the function to the list of the common param validation test:

    PARAM_VALIDATION_FUNCTION_LIST = [
        "sklearn.cluster.kmeans_plusplus",
    ]

    and make sure the test passes: pytest -vl sklearn/tests/test_public_functions.py
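To make steps 2 and 3 more concrete, here is a minimal, standard-library-only sketch of the general idea: a decorator receives a dictionary of constraints, checks each argument before the call, and the function body keeps only the checks that depend on the data. The names validate_params_sketch, _satisfies, positive_int and pick_centers are hypothetical and exist only for this illustration. This is not scikit-learn's implementation, and its constraint objects (Interval, "array-like", ...) are replaced here by plain types and predicates.

```python
import functools
import inspect
from numbers import Integral


def _satisfies(value, constraint):
    """Return True if `value` meets one constraint (None, a type, or a predicate)."""
    if constraint is None:
        return value is None
    if isinstance(constraint, type):
        return isinstance(value, constraint)
    return bool(constraint(value))


def validate_params_sketch(constraints):
    """Hypothetical, simplified stand-in for a parameter-validation decorator.

    `constraints` maps a parameter name to a list of alternatives; a value is
    accepted as soon as it satisfies one of them.
    """

    def decorator(func):
        sig = inspect.signature(func)

        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            bound = sig.bind(*args, **kwargs)
            bound.apply_defaults()
            for name, value in bound.arguments.items():
                allowed = constraints.get(name)
                if allowed is None:
                    continue  # no constraint declared for this parameter
                if not any(_satisfies(value, c) for c in allowed):
                    raise ValueError(
                        f"The {name!r} parameter of {func.__name__} got "
                        f"{value!r}, which matches none of its constraints."
                    )
            return func(*args, **kwargs)

        return wrapper

    return decorator


def positive_int(value):
    """Predicate standing in for an 'integer >= 1' interval constraint."""
    return (
        isinstance(value, Integral)
        and not isinstance(value, bool)
        and value >= 1
    )


@validate_params_sketch(
    {
        "X": [list],
        "n_clusters": [positive_int],
        "n_local_trials": [positive_int, None],
    }
)
def pick_centers(X, n_clusters, *, n_local_trials=None):
    # Data-dependent check: it compares a parameter with the input data,
    # so it is not "simple" validation and stays in the function body.
    if n_clusters > len(X):
        raise ValueError(
            f"n_samples={len(X)} should be >= n_clusters={n_clusters}."
        )
    return X[:n_clusters]
```

With this sketch, pick_centers([[0], [1], [2]], 0) is rejected by the decorator before the body runs, while pick_centers([[0], [1], [2]], 5) passes the decorator and is rejected by the data-dependent check that was kept.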

Functions already updated:

See "details" in section 1

Be aware that you can see an up-to-date list at the following link: https://github.com/scikit-learn/scikit-learn/blob/main/sklearn/tests/test_public_functions.py#L132
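As a rough illustration of how a list of dotted names such as PARAM_VALIDATION_FUNCTION_LIST can drive one common test, the sketch below resolves each string to a function object with importlib. The list FUNCTION_LIST and the helper resolve are hypothetical, and standard-library functions are used so that the snippet runs without scikit-learn; the actual test in sklearn/tests/test_public_functions.py may be organised differently.

```python
from importlib import import_module

# Hypothetical list for illustration; the real one holds names such as
# "sklearn.cluster.kmeans_plusplus".
FUNCTION_LIST = [
    "math.sqrt",
    "os.path.join",
]


def resolve(dotted_name):
    """Import the module part of `dotted_name` and return the named attribute."""
    module_name, func_name = dotted_name.rsplit(".", 1)
    module = import_module(module_name)
    return getattr(module, func_name)


# A common test could loop (or be parametrized) over the list like this.
for name in FUNCTION_LIST:
    func = resolve(name)
    assert callable(func), f"{name} did not resolve to a callable"
```

Keeping the entries as strings means that adding a function to the common test is a one-line change, which fits the "one PR per function" workflow described above.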


    Labels

    Meta-issue (General issue associated to an identified list of tasks), Sprint, Validation (related to input validation), good first issue (Easy with clear instructions to resolve)
