Commit 05c0e08

cozek, Kaushik Amar Das, and glemaitre authored
ENH BaseLabelPropagation to accept sparse matrices (scikit-learn#19664)
Co-authored-by: Kaushik Amar Das <kaushik.amar.das@accenture.com>
Co-authored-by: Guillaume Lemaitre <g.lemaitre58@gmail.com>
1 parent baefe83 commit 05c0e08

3 files changed: 39 additions and 6 deletions

‎doc/whats_new/v1.3.rst

7 lines changed: 7 additions & 0 deletions
@@ -214,6 +214,13 @@ Changelog
   during `transform` with no prior call to `fit` or `fit_transform`.
   :pr:`25190` by :user:`Vincent Maladière <Vincent-Maladiere>`.

+:mod:`sklearn.semi_supervised`
+..............................
+
+- |Enhancement| :meth:`LabelSpreading.fit` and :meth:`LabelPropagation.fit` now
+  accept sparse matrices.
+  :pr:`19664` by :user:`Kaushik Amar Das <cozek>`.
+
 Code and Documentation Contributors
 -----------------------------------

‎sklearn/semi_supervised/_label_propagation.py

13 lines changed: 9 additions & 4 deletions
@@ -241,7 +241,7 @@ def fit(self, X, y):

         Parameters
         ----------
-        X : array-like of shape (n_samples, n_features)
+        X : {array-like, sparse matrix} of shape (n_samples, n_features)
             Training data, where `n_samples` is the number of samples
             and `n_features` is the number of features.
@@ -256,7 +256,12 @@ def fit(self, X, y):
             Returns the instance itself.
         """
         self._validate_params()
-        X, y = self._validate_data(X, y)
+        X, y = self._validate_data(
+            X,
+            y,
+            accept_sparse=["csr", "csc"],
+            reset=True,
+        )
         self.X_ = X
         check_classification_targets(y)
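
To see what `accept_sparse=["csr", "csc"]` means in practice, the public helper `sklearn.utils.check_array` can be called with the same argument. This is a hedged sketch: it assumes the private `_validate_data` forwards `accept_sparse` to the same validation logic, and the variable names are illustrative, not from the commit.

```python
from scipy import sparse
from sklearn.utils import check_array

X_coo = sparse.coo_matrix([[1.0, 0.0], [0.0, 2.0], [1.0, 3.0]])

# A format outside the accepted list is converted to the first listed format.
converted_format = check_array(X_coo, accept_sparse=["csr", "csc"]).format

# A format that is already in the accepted list is kept as-is.
kept_format = check_array(X_coo.tocsc(), accept_sparse=["csr", "csc"]).format

# With accept_sparse=False, sparse input is rejected with a TypeError.
try:
    check_array(X_coo, accept_sparse=False)
    sparse_rejected = False
except TypeError:
    sparse_rejected = True

print(converted_format, kept_format, sparse_rejected)
```

Listing both `"csr"` and `"csc"` therefore lets either format pass through without a conversion, while any other sparse format is converted to CSR.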

@@ -365,7 +370,7 @@ class LabelPropagation(BaseLabelPropagation):

     Attributes
     ----------
-    X_ : ndarray of shape (n_samples, n_features)
+    X_ : {array-like, sparse matrix} of shape (n_samples, n_features)
         Input array.

     classes_ : ndarray of shape (n_classes,)
@@ -463,7 +468,7 @@ def fit(self, X, y):

         Parameters
         ----------
-        X : array-like of shape (n_samples, n_features)
+        X : {array-like, sparse matrix} of shape (n_samples, n_features)
             Training data, where `n_samples` is the number of samples
             and `n_features` is the number of features.

‎sklearn/semi_supervised/tests/test_label_propagation.py

25 lines changed: 23 additions & 2 deletions
@@ -15,6 +15,9 @@
     assert_allclose,
     assert_array_equal,
 )
+from sklearn.utils._testing import _convert_container
+
+CONSTRUCTOR_TYPES = ("array", "sparse_csr", "sparse_csc")

 ESTIMATORS = [
     (label_propagation.LabelPropagation, {"kernel": "rbf"}),
@@ -122,9 +125,27 @@ def test_label_propagation_closed_form(global_dtype):
     assert_allclose(expected, clf.label_distributions_, atol=1e-4)


-def test_convergence_speed():
+@pytest.mark.parametrize("accepted_sparse_type", ["sparse_csr", "sparse_csc"])
+@pytest.mark.parametrize("index_dtype", [np.int32, np.int64])
+@pytest.mark.parametrize("dtype", [np.float32, np.float64])
+@pytest.mark.parametrize("Estimator, parameters", ESTIMATORS)
+def test_sparse_input_types(
+    accepted_sparse_type, index_dtype, dtype, Estimator, parameters
+):
+    # This is a non-regression test for #17085
+    X = _convert_container([[1.0, 0.0], [0.0, 2.0], [1.0, 3.0]], accepted_sparse_type)
+    X.data = X.data.astype(dtype, copy=False)
+    X.indices = X.indices.astype(index_dtype, copy=False)
+    X.indptr = X.indptr.astype(index_dtype, copy=False)
+    labels = [0, 1, -1]
+    clf = Estimator(**parameters).fit(X, labels)
+    assert_array_equal(clf.predict([[0.5, 2.5]]), np.array([1]))
+
+
+@pytest.mark.parametrize("constructor_type", CONSTRUCTOR_TYPES)
+def test_convergence_speed(constructor_type):
     # This is a non-regression test for #5774
-    X = np.array([[1.0, 0.0], [0.0, 1.0], [1.0, 2.5]])
+    X = _convert_container([[1.0, 0.0], [0.0, 1.0], [1.0, 2.5]], constructor_type)
     y = np.array([0, 1, -1])
     mdl = label_propagation.LabelSpreading(kernel="rbf", max_iter=5000)
     mdl.fit(X, y)
