diff --git a/sklearn/cluster/_spectral.py b/sklearn/cluster/_spectral.py index c77e4494fcc26..8fdd47300b2d9 100644 --- a/sklearn/cluster/_spectral.py +++ b/sklearn/cluster/_spectral.py @@ -15,7 +15,7 @@ from scipy.sparse import csc_matrix from ..base import BaseEstimator, ClusterMixin -from ..utils._param_validation import Interval, StrOptions +from ..utils._param_validation import Interval, StrOptions, validate_params from ..utils import check_random_state, as_float_array from ..metrics.pairwise import pairwise_kernels, KERNEL_PARAMS from ..neighbors import kneighbors_graph, NearestNeighbors @@ -190,6 +190,7 @@ def discretize( return labels +@validate_params({"affinity": ["array-like", "sparse matrix"]}) def spectral_clustering( affinity, *, diff --git a/sklearn/tests/test_public_functions.py b/sklearn/tests/test_public_functions.py index 4014b03607ee3..df2bd6346c9bc 100644 --- a/sklearn/tests/test_public_functions.py +++ b/sklearn/tests/test_public_functions.py @@ -151,6 +151,7 @@ def test_function_param_validation(func_module): PARAM_VALIDATION_CLASS_WRAPPER_LIST = [ ("sklearn.cluster.affinity_propagation", "sklearn.cluster.AffinityPropagation"), + ("sklearn.cluster.spectral_clustering", "sklearn.cluster.SpectralClustering"), ("sklearn.covariance.ledoit_wolf", "sklearn.covariance.LedoitWolf"), ("sklearn.covariance.oas", "sklearn.covariance.OAS"), ("sklearn.decomposition.dict_learning", "sklearn.decomposition.DictionaryLearning"),