Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Appearance settings

Commit aa1e69a

Browse filesBrowse files
authored
API Removes the use of fit_ and partial_fit_ in Birch (#19297)
* API Removes the use of fit_ and partial_fit_ in Birch * DOC Adds whats new * ENH Adjust names * CLN Uses a verbose name
1 parent b943324 commit aa1e69a
Copy full SHA for aa1e69a

File tree

3 files changed

+42
-6
lines changed
Filter options

3 files changed

+42
-6
lines changed

‎doc/whats_new/v1.0.rst

Copy file name to clipboardExpand all lines: doc/whats_new/v1.0.rst
+3Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,9 @@ Changelog
5656
in multicore settings. :pr:`19052` by
5757
:user:`Yusuke Nagasaka <YusukeNagasaka>`.
5858

59+
- |API| :class:`cluster.Birch` attributes, `fit_` and `partial_fit_`, are
60+
deprecated and will be removed in 1.2. :pr:`19297` by `Thomas Fan`_.
61+
5962
- |Fix| Fixes incorrect multiple data-conversion warnings when clustering
6063
boolean data. :pr:`19046` by :user:`Surya Prakash <jdsurya>`.
6164

‎sklearn/cluster/_birch.py

Copy file name to clipboardExpand all lines: sklearn/cluster/_birch.py
+27-6Lines changed: 27 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
from ..metrics.pairwise import euclidean_distances
1414
from ..base import TransformerMixin, ClusterMixin, BaseEstimator
1515
from ..utils.extmath import row_norms
16+
from ..utils import deprecated
1617
from ..utils.validation import check_is_fitted, _deprecate_positional_args
1718
from ..exceptions import ConvergenceWarning
1819
from . import AgglomerativeClustering
@@ -440,6 +441,24 @@ def __init__(self, *, threshold=0.5, branching_factor=50, n_clusters=3,
440441
self.compute_labels = compute_labels
441442
self.copy = copy
442443

444+
# TODO: Remove in 1.2
445+
# mypy error: Decorated property not supported
446+
@deprecated( # type: ignore
447+
"fit_ is deprecated in 1.0 and will be removed in 1.2"
448+
)
449+
@property
450+
def fit_(self):
451+
return self._deprecated_fit
452+
453+
# TODO: Remove in 1.2
454+
# mypy error: Decorated property not supported
455+
@deprecated( # type: ignore
456+
"partial_fit_ is deprecated in 1.0 and will be removed in 1.2"
457+
)
458+
@property
459+
def partial_fit_(self):
460+
return self._deprecated_partial_fit
461+
443462
def fit(self, X, y=None):
444463
"""
445464
Build a CF Tree for the input data.
@@ -457,12 +476,13 @@ def fit(self, X, y=None):
457476
self
458477
Fitted estimator.
459478
"""
460-
self.fit_, self.partial_fit_ = True, False
461-
return self._fit(X)
479+
# TODO: Remove deprected flags in 1.2
480+
self._deprecated_fit, self._deprecated_partial_fit = True, False
481+
return self._fit(X, partial=False)
462482

463-
def _fit(self, X):
483+
def _fit(self, X, partial):
464484
has_root = getattr(self, 'root_', None)
465-
first_call = self.fit_ or (self.partial_fit_ and not has_root)
485+
first_call = not (partial and has_root)
466486

467487
X = self._validate_data(X, accept_sparse='csr', copy=self.copy,
468488
reset=first_call)
@@ -552,13 +572,14 @@ def partial_fit(self, X=None, y=None):
552572
self
553573
Fitted estimator.
554574
"""
555-
self.partial_fit_, self.fit_ = True, False
575+
# TODO: Remove deprected flags in 1.2
576+
self._deprecated_partial_fit, self._deprecated_fit = True, False
556577
if X is None:
557578
# Perform just the final global clustering step.
558579
self._global_clustering()
559580
return self
560581
else:
561-
return self._fit(X)
582+
return self._fit(X, partial=True)
562583

563584
def _check_fit(self, X):
564585
check_is_fitted(self)

‎sklearn/cluster/tests/test_birch.py

Copy file name to clipboardExpand all lines: sklearn/cluster/tests/test_birch.py
+12Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -179,3 +179,15 @@ def test_birch_n_clusters_long_int():
179179
X, _ = make_blobs(random_state=0)
180180
n_clusters = np.int64(5)
181181
Birch(n_clusters=n_clusters).fit(X)
182+
183+
184+
# TODO: Remove in 1.2
185+
@pytest.mark.parametrize("attribute", ["fit_", "partial_fit_"])
186+
def test_birch_fit_attributes_deprecated(attribute):
187+
"""Test that fit_ and partial_fit_ attributes are deprecated."""
188+
msg = f"{attribute} is deprecated in 1.0 and will be removed in 1.2"
189+
X, y = make_blobs(n_samples=10)
190+
brc = Birch().fit(X, y)
191+
192+
with pytest.warns(FutureWarning, match=msg):
193+
getattr(brc, attribute)

0 commit comments

Comments
0 (0)
Morty Proxy This is a proxified and sanitized view of the page, visit original site.