DEP PassiveAggressiveClassifier and PassiveAggressiveRegressor #29097

Open: wants to merge 15 commits into base: main
@@ -0,0 +1,4 @@
- `PassiveAggressiveClassifier` and `PassiveAggressiveRegressor` are deprecated
and will be removed in 1.8. Equivalent estimators are available with `SGDClassifier`
and `SGDRegressor`.
By :user:`Christian Lorentzen <lorentzenchr>`.
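For illustration, a minimal migration sketch following the replacement recipe
documented in this PR; note that `learning_rate="pa1"`/`"pa2"` and assigning
`C` as a bare attribute are private APIs introduced here and may change
without deprecation::

    # Migration sketch only; relies on the private "pa1" learning rate.
    import numpy as np
    from sklearn.linear_model import SGDClassifier

    X = np.array([[-1.0, -1.0], [-1.0, 0.0], [1.0, 1.0], [1.0, 0.0]])
    y = np.array([0, 0, 1, 1])

    # Replacement for PassiveAggressiveClassifier(C=1.0, loss="hinge")
    clf = SGDClassifier(
        penalty=None,         # passive-aggressive updates are unregularized
        alpha=1.0,
        eta0=1.0,
        learning_rate="pa1",  # use "pa2" together with loss="squared_hinge"
        loss="hinge",
    )
    clf.C = 1.0  # PA aggressiveness parameter, set via a private attribute
    clf.fit(X, y)
    print(clf.predict([[0.5, 0.5]]))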
sklearn/feature_selection/tests/test_from_model.py (5 changes: 2 additions & 3 deletions)
@@ -20,7 +20,6 @@
LassoCV,
LinearRegression,
LogisticRegression,
PassiveAggressiveClassifier,
SGDClassifier,
)
from sklearn.pipeline import make_pipeline
@@ -394,8 +393,8 @@ def test_2d_coef():


def test_partial_fit():
est = PassiveAggressiveClassifier(
random_state=0, shuffle=False, max_iter=5, tol=None
est = SGDClassifier(
random_state=0, shuffle=False, max_iter=5, tol=None, learning_rate="pa1"
)
transformer = SelectFromModel(estimator=est)
transformer.partial_fit(data, y, classes=np.unique(y))
sklearn/linear_model/_passive_aggressive.py (45 changes: 44 additions & 1 deletion)
@@ -4,13 +4,34 @@
from numbers import Real

from ..base import _fit_context
from ..utils import deprecated
from ..utils._param_validation import Interval, StrOptions
from ._stochastic_gradient import DEFAULT_EPSILON, BaseSGDClassifier, BaseSGDRegressor


# TODO(1.8): Remove
@deprecated(  # type: ignore
    "it was deprecated in version 1.6 and will be removed in 1.8. "
    "Use `SGDClassifier` instead."
)
class PassiveAggressiveClassifier(BaseSGDClassifier):
"""Passive Aggressive Classifier.

.. deprecated:: 1.6
The whole class `PassiveAggressiveClassifier` was deprecated in version 1.6
and will be removed in 1.8. Instead use::

clf = SGDClassifier(
penalty=None,
alpha=1.0,
eta0=1.0,
learning_rate="pa1", # "pa1" and "pa2" are private
loss="hinge",
)
clf.C = 1.0 # Note that this uses a private API.

With `loss="squared_hinge"`, one would set `learning_rate="pa2"`.

Read more in the :ref:`User Guide <passive_aggressive>`.

Parameters
@@ -311,9 +332,31 @@ def fit(self, X, y, coef_init=None, intercept_init=None):
)


# TODO(1.8): Remove
@deprecated(  # type: ignore
    "it was deprecated in version 1.6 and will be removed in 1.8. "
    "Use `SGDRegressor` instead."
)
class PassiveAggressiveRegressor(BaseSGDRegressor):
"""Passive Aggressive Regressor.

.. deprecated:: 1.6
The whole class `PassiveAggressiveRegressor` was deprecated in version 1.6
and will be removed in 1.8. Instead use::

reg = SGDRegressor(
penalty=None,
alpha=1.0,
eta0=1.0,
l1_ratio=0,
learning_rate="pa1", # "pa1" and "pa2" are private
loss="epsilon_insensitive",
)
reg.C = 1.0 # Note that this uses a private API.

With `loss="squared_epsilon_insensitive"`, one would set `learning_rate="pa2"`.

Read more in the :ref:`User Guide <passive_aggressive>`.

Parameters
@@ -482,6 +525,7 @@ def __init__(
average=False,
):
super().__init__(
loss=loss,
penalty=None,
l1_ratio=0,
epsilon=epsilon,
@@ -499,7 +543,6 @@
average=average,
)
self.C = C
self.loss = loss

@_fit_context(prefer_skip_nested_validation=True)
def partial_fit(self, X, y):
sklearn/linear_model/_sgd_fast.pyx.tp (12 changes: 9 additions & 3 deletions)
@@ -327,7 +327,7 @@ def _plain_sgd{{name_suffix}}(
alpha : float
The regularization parameter.
C : float
Maximum step size for passive aggressive.
Maximum step size for passive aggressive. See [1].
l1_ratio : float
The Elastic Net mixing parameter, with 0 <= l1_ratio <= 1.
l1_ratio=0 corresponds to L2 penalty, l1_ratio=1 to L1.
@@ -365,8 +365,8 @@
(2) optimal, eta = 1.0/(alpha * t).
(3) inverse scaling, eta = eta0 / pow(t, power_t)
(4) adaptive decrease
(5) Passive Aggressive-I, eta = min(alpha, loss/norm(x))
(6) Passive Aggressive-II, eta = 1.0 / (norm(x) + 0.5*alpha)
(5) Passive Aggressive-I, eta = min(alpha, loss/norm(x)), see [1]
(6) Passive Aggressive-II, eta = 1.0 / (norm(x) + 0.5*alpha), see [1]
eta0 : double
The initial learning rate.
power_t : double
@@ -396,6 +396,12 @@
Values are valid only if average > 0.
n_iter_ : int
The actual number of iterations (epochs).

References
----------
.. [1] `Online Passive-Aggressive Algorithms
   <https://jmlr.org/papers/volume7/crammer06a/crammer06a.pdf>`_
   K. Crammer, O. Dekel, J. Keshet, S. Shalev-Shwartz, Y. Singer - JMLR (2006)
"""

# get the data information into easy vars
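As context for the step-size formulas and the new reference above, here is a
minimal NumPy sketch of the PA-I and PA-II step sizes as given in Crammer et
al. (2006); it is a conceptual restatement, not the Cython implementation::

    import numpy as np

    def pa_step(x, loss, C, variant="pa1"):
        """Step size tau for one sample, following Crammer et al. (2006)."""
        sqnorm = np.dot(x, x)  # squared norm of the sample
        if variant == "pa1":
            return min(C, loss / sqnorm)   # PA-I: step clipped at C
        return loss / (sqnorm + 0.5 / C)   # PA-II: slack penalized quadratically

    x = np.array([1.0, 2.0])  # squared norm = 5
    print(pa_step(x, loss=0.7, C=1.0))                 # PA-I: 0.14
    print(pa_step(x, loss=0.7, C=1.0, variant="pa2"))  # PA-II: ~0.127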
sklearn/linear_model/_stochastic_gradient.py (12 changes: 6 additions & 6 deletions)
@@ -888,7 +888,7 @@ def partial_fit(self, X, y, classes=None, sample_weight=None):
X,
y,
alpha=self.alpha,
C=1.0,
C=self.C,
loss=self.loss,
learning_rate=self.learning_rate,
max_iter=1,
@@ -933,7 +933,7 @@ def fit(self, X, y, coef_init=None, intercept_init=None, sample_weight=None):
X,
y,
alpha=self.alpha,
C=1.0,
C=self.C,
loss=self.loss,
learning_rate=self.learning_rate,
coef_init=coef_init,
@@ -1549,7 +1549,7 @@ def partial_fit(self, X, y, sample_weight=None):
X,
y,
self.alpha,
C=1.0,
C=self.C,
loss=self.loss,
learning_rate=self.learning_rate,
max_iter=1,
@@ -1653,7 +1653,7 @@ def fit(self, X, y, coef_init=None, intercept_init=None, sample_weight=None):
X,
y,
alpha=self.alpha,
C=1.0,
C=self.C,
loss=self.loss,
learning_rate=self.learning_rate,
coef_init=coef_init,
@@ -2478,7 +2478,7 @@ def partial_fit(self, X, y=None, sample_weight=None):
return self._partial_fit(
X,
alpha,
C=1.0,
C=self.C,
loss=self.loss,
learning_rate=self.learning_rate,
max_iter=1,
@@ -2587,7 +2587,7 @@ def fit(self, X, y=None, coef_init=None, offset_init=None, sample_weight=None):
self._fit(
X,
alpha=alpha,
C=1.0,
C=self.C,
loss=self.loss,
learning_rate=self.learning_rate,
coef_init=coef_init,
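The hunks above all make the same change: the hard-coded `C=1.0` is replaced
by `C=self.C`, so the aggressiveness parameter actually reaches the SGD core
when the private "pa1"/"pa2" schedules are selected. A hedged sketch of what
this enables in an online setting, assuming this PR's private API::

    import numpy as np
    from sklearn.linear_model import SGDClassifier

    rng = np.random.RandomState(0)
    X = rng.randn(20, 3)
    y = (X[:, 0] > 0).astype(int)

    clf = SGDClassifier(
        penalty=None, alpha=1.0, eta0=1.0, learning_rate="pa1", loss="hinge"
    )
    clf.C = 0.5  # forwarded to _plain_sgd; caps each PA-I step at 0.5

    for start in range(0, len(X), 5):  # incremental updates, mini-batch style
        clf.partial_fit(X[start:start + 5], y[start:start + 5], classes=[0, 1])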
sklearn/linear_model/tests/test_passive_aggressive.py (30 changes: 29 additions & 1 deletion)
@@ -1,3 +1,5 @@
import warnings

import numpy as np
import pytest

@@ -69,6 +71,7 @@ def project(self, X):
return np.dot(X, self.w) + self.b


@pytest.mark.filterwarnings("ignore::FutureWarning")
@pytest.mark.parametrize("average", [False, True])
@pytest.mark.parametrize("fit_intercept", [True, False])
@pytest.mark.parametrize("csr_container", [None, *CSR_CONTAINERS])
@@ -92,6 +95,7 @@ def test_classifier_accuracy(csr_container, fit_intercept, average):
assert hasattr(clf, "_standard_coef")


@pytest.mark.filterwarnings("ignore::FutureWarning")
@pytest.mark.parametrize("average", [False, True])
@pytest.mark.parametrize("csr_container", [None, *CSR_CONTAINERS])
def test_classifier_partial_fit(csr_container, average):
@@ -109,6 +113,7 @@ def test_classifier_partial_fit(csr_container, average):
assert hasattr(clf, "_standard_coef")


@pytest.mark.filterwarnings("ignore::FutureWarning")
def test_classifier_refit():
# Classifier can be retrained on different labels and features.
clf = PassiveAggressiveClassifier(max_iter=5).fit(X, y)
@@ -118,6 +123,7 @@ def test_classifier_refit():
assert_array_equal(clf.classes_, iris.target_names)


@pytest.mark.filterwarnings("ignore::FutureWarning")
@pytest.mark.parametrize("csr_container", [None, *CSR_CONTAINERS])
@pytest.mark.parametrize("loss", ("hinge", "squared_hinge"))
def test_classifier_correctness(loss, csr_container):
@@ -134,6 +140,7 @@ def test_classifier_correctness(loss, csr_container):
assert_array_almost_equal(clf1.w, clf2.coef_.ravel(), decimal=2)


@pytest.mark.filterwarnings("ignore::FutureWarning")
@pytest.mark.parametrize(
"response_method", ["predict_proba", "predict_log_proba", "transform"]
)
@@ -143,6 +150,7 @@ def test_classifier_undefined_methods(response_method):
getattr(clf, response_method)


@pytest.mark.filterwarnings("ignore::FutureWarning")
def test_class_weights():
# Test class weights.
X2 = np.array([[-1.0, -1.0], [-1.0, 0], [-0.8, -1.0], [1.0, 1.0], [1.0, 0.0]])
@@ -165,13 +173,15 @@ def test_class_weights():
assert_array_equal(clf.predict([[0.2, -1.0]]), np.array([-1]))


@pytest.mark.filterwarnings("ignore::FutureWarning")
def test_partial_fit_weight_class_balanced():
# partial_fit with class_weight='balanced' not supported
clf = PassiveAggressiveClassifier(class_weight="balanced", max_iter=100)
with pytest.raises(ValueError):
clf.partial_fit(X, y, classes=np.unique(y))


@pytest.mark.filterwarnings("ignore::FutureWarning")
def test_equal_class_weight():
X2 = [[1, 0], [1, 0], [0, 1], [0, 1]]
y2 = [0, 0, 1, 1]
@@ -192,6 +202,7 @@ def test_equal_class_weight():
assert_almost_equal(clf.coef_, clf_balanced.coef_, decimal=2)


@pytest.mark.filterwarnings("ignore::FutureWarning")
def test_wrong_class_weight_label():
# ValueError due to wrong class_weight label.
X2 = np.array([[-1.0, -1.0], [-1.0, 0], [-0.8, -1.0], [1.0, 1.0], [1.0, 0.0]])
@@ -202,6 +213,7 @@ def test_wrong_class_weight_label():
clf.fit(X2, y2)


@pytest.mark.filterwarnings("ignore::FutureWarning")
@pytest.mark.parametrize("average", [False, True])
@pytest.mark.parametrize("fit_intercept", [True, False])
@pytest.mark.parametrize("csr_container", [None, *CSR_CONTAINERS])
@@ -227,6 +239,7 @@ def test_regressor_mse(csr_container, fit_intercept, average):
assert hasattr(reg, "_standard_coef")


@pytest.mark.filterwarnings("ignore::FutureWarning")
@pytest.mark.parametrize("average", [False, True])
@pytest.mark.parametrize("csr_container", [None, *CSR_CONTAINERS])
def test_regressor_partial_fit(csr_container, average):
@@ -246,6 +259,7 @@ def test_regressor_partial_fit(csr_container, average):
assert hasattr(reg, "_standard_coef")


@pytest.mark.filterwarnings("ignore::FutureWarning")
@pytest.mark.parametrize("csr_container", [None, *CSR_CONTAINERS])
@pytest.mark.parametrize("loss", ("epsilon_insensitive", "squared_epsilon_insensitive"))
def test_regressor_correctness(loss, csr_container):
@@ -262,6 +276,7 @@ def test_regressor_correctness(loss, csr_container):
assert_array_almost_equal(reg1.w, reg2.coef_.ravel(), decimal=2)


@pytest.mark.filterwarnings("ignore::FutureWarning")
def test_regressor_undefined_methods():
reg = PassiveAggressiveRegressor(max_iter=100)
with pytest.raises(AttributeError):
@@ -273,6 +288,19 @@ def test_regressor_undefined_methods():
"Estimator", [PassiveAggressiveClassifier, PassiveAggressiveRegressor]
)
def test_passive_aggressive_deprecated_average(Estimator):
est = Estimator(average=0)
with warnings.catch_warnings():
warnings.simplefilter("ignore")
est = Estimator(average=0)
with pytest.warns(FutureWarning, match="average=0"):
est.fit(X, y)


# TODO(1.8): remove
@pytest.mark.parametrize(
"Estimator", [PassiveAggressiveClassifier, PassiveAggressiveRegressor]
)
def test_class_deprecation(Estimator):
# Check that we raise the proper deprecation warning.

with pytest.warns(FutureWarning, match=f"Class {Estimator.__name__} is deprecated"):
Estimator()
sklearn/model_selection/tests/test_validation.py (5 changes: 2 additions & 3 deletions)
@@ -29,7 +29,6 @@
from sklearn.impute import SimpleImputer
from sklearn.linear_model import (
LogisticRegression,
PassiveAggressiveClassifier,
Ridge,
RidgeClassifier,
SGDClassifier,
@@ -1363,7 +1362,7 @@ def test_learning_curve_batch_and_incremental_learning_are_equal():
random_state=0,
)
train_sizes = np.linspace(0.2, 1.0, 5)
estimator = PassiveAggressiveClassifier(max_iter=1, tol=None, shuffle=False)
estimator = SGDClassifier(max_iter=1, tol=None, shuffle=False)

train_sizes_inc, train_scores_inc, test_scores_inc = learning_curve(
estimator,
@@ -1482,7 +1481,7 @@ def test_learning_curve_with_shuffle():
groups = np.array([1, 1, 1, 1, 1, 1, 3, 3, 3, 3, 3, 4, 4, 4, 4])
# Splits on these groups fail without shuffle as the first iteration
# of the learning curve doesn't contain label 4 in the training set.
estimator = PassiveAggressiveClassifier(max_iter=5, tol=None, shuffle=False)
estimator = SGDClassifier(max_iter=5, tol=None, shuffle=False, learning_rate="pa1")

cv = GroupKFold(n_splits=2)
train_sizes_batch, train_scores_batch, test_scores_batch = learning_curve(
sklearn/tests/test_multioutput.py (3 changes: 1 addition & 2 deletions)
@@ -25,7 +25,6 @@
LinearRegression,
LogisticRegression,
OrthogonalMatchingPursuit,
PassiveAggressiveClassifier,
Ridge,
SGDClassifier,
SGDRegressor,
@@ -851,7 +850,7 @@ def test_fit_params_no_routing(Cls, method):
underlying classifier.
"""
X, y = make_classification(n_samples=50)
clf = Cls(PassiveAggressiveClassifier())
clf = Cls(SGDClassifier())

with pytest.raises(ValueError, match="is only supported if"):
getattr(clf, method)(X, y, test=1)
sklearn/utils/deprecation.py (5 changes: 5 additions & 0 deletions)
@@ -3,6 +3,7 @@

import functools
import warnings
from inspect import signature

__all__ = ["deprecated"]

@@ -64,17 +65,21 @@ def _decorate_class(self, cls):
msg += "; %s" % self.extra

new = cls.__new__
sig = signature(cls)

def wrapped(cls, *args, **kwargs):
warnings.warn(msg, category=FutureWarning)
if new is object.__new__:
return object.__new__(cls)

return new(cls, *args, **kwargs)

cls.__new__ = wrapped

wrapped.__name__ = "__new__"
wrapped.deprecated_original = new
# Restore the original signature, see PEP 362.
cls.__signature__ = sig

return cls

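The restored `__signature__` matters because wrapping `__new__` with a bare
`(*args, **kwargs)` function would otherwise hide the constructor parameters
from `inspect.signature`, which scikit-learn's parameter introspection builds
on. A standalone sketch of the mechanism with a made-up class (not sklearn
code)::

    import inspect
    import warnings

    class Original:
        def __init__(self, alpha=1.0, fit_intercept=True):
            self.alpha = alpha
            self.fit_intercept = fit_intercept

    sig = inspect.signature(Original)  # capture before patching, as in the diff

    def wrapped(cls, *args, **kwargs):
        warnings.warn("Original is deprecated", FutureWarning)
        return object.__new__(cls)

    Original.__new__ = wrapped     # signature now appears as (*args, **kwargs)
    Original.__signature__ = sig   # PEP 362: inspect.signature honors this

    print(inspect.signature(Original))  # (alpha=1.0, fit_intercept=True)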