diff --git a/sklearn/tests/test_min_dependencies_readme.py b/sklearn/tests/test_min_dependencies_readme.py index 8b2b548c5bf42..a0692d333feef 100644 --- a/sklearn/tests/test_min_dependencies_readme.py +++ b/sklearn/tests/test_min_dependencies_readme.py @@ -1,4 +1,4 @@ -"""Tests for the minimum dependencies in the README.rst file.""" +"""Tests for the minimum dependencies in README.rst and pyproject.toml""" import os @@ -50,3 +50,39 @@ def test_min_dependencies_readme(): min_version = parse_version(dependent_packages[package][0]) assert version == min_version, f"{package} has a mismatched version" + + +def test_min_dependencies_pyproject_toml(): + """Check versions in pyproject.toml is consistent with _min_dependencies.""" + # tomllib is available in Python 3.11 + tomllib = pytest.importorskip("tomllib") + + root_directory = Path(sklearn.__path__[0]).parent + pyproject_toml_path = root_directory / "pyproject.toml" + + if not pyproject_toml_path.exists(): + # Skip the test if the pyproject.toml file is not available. + # For instance, when installing scikit-learn from wheels + pytest.skip("pyproject.toml is not available.") + + with pyproject_toml_path.open("rb") as f: + pyproject_toml = tomllib.load(f) + + build_requirements = pyproject_toml["build-system"]["requires"] + + pyproject_build_min_versions = {} + for requirement in build_requirements: + if ">=" in requirement: + package, version = requirement.split(">=") + package = package.lower() + pyproject_build_min_versions[package] = version + + # Only scipy and cython are listed in pyproject.toml + # NumPy is more complex using oldest-supported-numpy. + assert set(["scipy", "cython"]) == set(pyproject_build_min_versions) + + for package, version in pyproject_build_min_versions.items(): + version = parse_version(version) + expected_min_version = parse_version(dependent_packages[package][0]) + + assert version == expected_min_version, f"{package} has a mismatched version"