From 5208a8018489a5968086b502bac97ead5e49f46e Mon Sep 17 00:00:00 2001 From: Jordao Bragantini Date: Wed, 5 Jun 2024 14:18:16 -0700 Subject: [PATCH] added callable function option to bicluster method --- sklearn/cluster/_bicluster.py | 10 +++++++--- sklearn/cluster/tests/test_bicluster.py | 2 +- 2 files changed, 8 insertions(+), 4 deletions(-) diff --git a/sklearn/cluster/_bicluster.py b/sklearn/cluster/_bicluster.py index b22f6a369fcc1..d3af7ea2d4778 100644 --- a/sklearn/cluster/_bicluster.py +++ b/sklearn/cluster/_bicluster.py @@ -381,12 +381,13 @@ class SpectralBiclustering(BaseSpectral): The number of row and column clusters in the checkerboard structure. - method : {'bistochastic', 'scale', 'log'}, default='bistochastic' + method : {'bistochastic', 'scale', 'log'} or callable, default='bistochastic' Method of normalizing and converting singular vectors into biclusters. May be one of 'scale', 'bistochastic', or 'log'. The authors recommend using 'log'. If the data is sparse, however, log normalization will not work, which is why the default is 'bistochastic'. + Callable must take a 2D array and return a 2D array. .. warning:: if `method='log'`, the data must not be sparse. @@ -491,7 +492,7 @@ class SpectralBiclustering(BaseSpectral): _parameter_constraints: dict = { **BaseSpectral._parameter_constraints, "n_clusters": [Interval(Integral, 1, None, closed="left"), tuple], - "method": [StrOptions({"bistochastic", "scale", "log"})], + "method": [StrOptions({"bistochastic", "scale", "log"}), callable], "n_components": [Interval(Integral, 1, None, closed="left")], "n_best": [Interval(Integral, 1, None, closed="left")], } @@ -558,7 +559,10 @@ def _check_parameters(self, n_samples): def _fit(self, X): n_sv = self.n_components - if self.method == "bistochastic": + if callable(self.method): + normalized_data = self.method(X) + n_sv += 1 + elif self.method == "bistochastic": normalized_data = _bistochastic_normalize(X) n_sv += 1 elif self.method == "scale": diff --git a/sklearn/cluster/tests/test_bicluster.py b/sklearn/cluster/tests/test_bicluster.py index ebc845a7bf262..357a9cdbeb6f1 100644 --- a/sklearn/cluster/tests/test_bicluster.py +++ b/sklearn/cluster/tests/test_bicluster.py @@ -98,7 +98,7 @@ def test_spectral_biclustering(global_random_seed, csr_container): ) non_default_params = { - "method": ["scale", "log"], + "method": ["scale", "log", lambda x: x / x.sum(axis=0)], "svd_method": ["arpack"], "n_svd_vecs": [20], "mini_batch": [True],