Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Appearance settings

TST Change assert from sklearn to pytest style in tests/test_discriminant.py #19558

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Mar 2, 2021
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
71 changes: 49 additions & 22 deletions 71 sklearn/tests/test_discriminant_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,10 @@
from scipy import linalg

from sklearn.utils import check_random_state
from sklearn.utils._testing import assert_array_equal, assert_no_warnings
from sklearn.utils._testing import assert_array_equal
from sklearn.utils._testing import assert_array_almost_equal
from sklearn.utils._testing import assert_allclose
from sklearn.utils._testing import assert_almost_equal
from sklearn.utils._testing import assert_raises
from sklearn.utils._testing import assert_raise_message
from sklearn.utils._testing import assert_warns
from sklearn.utils._testing import ignore_warnings

from sklearn.datasets import make_blobs
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
Expand Down Expand Up @@ -89,15 +85,22 @@ def test_lda_predict():

# Test invalid shrinkages
clf = LinearDiscriminantAnalysis(solver="lsqr", shrinkage=-0.2231)
assert_raises(ValueError, clf.fit, X, y)
with pytest.raises(ValueError):
clf.fit(X, y)

clf = LinearDiscriminantAnalysis(solver="eigen", shrinkage="dummy")
assert_raises(ValueError, clf.fit, X, y)
with pytest.raises(ValueError):
clf.fit(X, y)

clf = LinearDiscriminantAnalysis(solver="svd", shrinkage="auto")
assert_raises(NotImplementedError, clf.fit, X, y)
with pytest.raises(NotImplementedError):
clf.fit(X, y)

clf = LinearDiscriminantAnalysis(solver="lsqr", shrinkage=np.array([1, 2]))
with pytest.raises(TypeError,
match="shrinkage must be a float or a string"):
clf.fit(X, y)

clf = LinearDiscriminantAnalysis(solver="lsqr",
shrinkage=0.1,
covariance_estimator=ShrunkCovariance())
Expand All @@ -106,9 +109,11 @@ def test_lda_predict():
"parameters are not None. "
"Only one of the two can be set.")):
clf.fit(X, y)

# Test unknown solver
clf = LinearDiscriminantAnalysis(solver="dummy")
assert_raises(ValueError, clf.fit, X, y)
with pytest.raises(ValueError):
clf.fit(X, y)

# test bad solver with covariance_estimator
clf = LinearDiscriminantAnalysis(solver="svd",
Expand Down Expand Up @@ -199,7 +204,9 @@ def test_lda_priors():
priors = np.array([0.5, -0.5])
clf = LinearDiscriminantAnalysis(priors=priors)
msg = "priors must be non-negative"
assert_raise_message(ValueError, msg, clf.fit, X, y)

with pytest.raises(ValueError, match=msg):
clf.fit(X, y)

# Test that priors passed as a list are correctly handled (run to see if
# failure)
Expand All @@ -210,7 +217,10 @@ def test_lda_priors():
priors = np.array([0.5, 0.6])
prior_norm = np.array([0.45, 0.55])
clf = LinearDiscriminantAnalysis(priors=priors)
assert_warns(UserWarning, clf.fit, X, y)

with pytest.warns(UserWarning):
clf.fit(X, y)

assert_array_almost_equal(clf.priors_, prior_norm, 2)


Expand Down Expand Up @@ -247,7 +257,9 @@ def test_lda_transform():
clf = LinearDiscriminantAnalysis(solver="lsqr", n_components=1)
clf.fit(X, y)
msg = "transform not implemented for 'lsqr'"
assert_raise_message(NotImplementedError, msg, clf.transform, X)

with pytest.raises(NotImplementedError, match=msg):
clf.transform(X)


def test_lda_explained_variance_ratio():
Expand Down Expand Up @@ -424,7 +436,8 @@ def test_lda_dimension_warning(n_classes, n_features):
for n_components in [max_components - 1, None, max_components]:
# if n_components <= min(n_classes - 1, n_features), no warning
lda = LinearDiscriminantAnalysis(n_components=n_components)
assert_no_warnings(lda.fit, X, y)
with pytest.warns(None):
lda.fit(X, y)

for n_components in [max_components + 1,
max(n_features, n_classes - 1) + 1]:
Expand Down Expand Up @@ -486,7 +499,8 @@ def test_qda():
assert np.any(y_pred3 != y7)

# Classes should have at least 2 elements
assert_raises(ValueError, clf.fit, X6, y4)
with pytest.raises(ValueError):
clf.fit(X6, y4)


def test_qda_priors():
Expand Down Expand Up @@ -523,23 +537,36 @@ def test_qda_store_covariance():


def test_qda_regularization():
# the default is reg_param=0. and will cause issues
# when there is a constant variable
# The default is reg_param=0. and will cause issues when there is a
# constant variable.

# Fitting on data with constant variable triggers an UserWarning.
collinear_msg = "Variables are collinear"
clf = QuadraticDiscriminantAnalysis()
with ignore_warnings():
y_pred = clf.fit(X2, y6).predict(X2)
with pytest.warns(UserWarning, match=collinear_msg):
y_pred = clf.fit(X2, y6)

# XXX: RuntimeWarning is also raised at predict time because of divisions
# by zero when the model is fit with a constant feature and without
# regularization: should this be considered a bug? Either by the fit-time
# message more informative, raising and exception instead of a warning in
# this case or somehow changing predict to avoid division by zero.
with pytest.warns(RuntimeWarning, match="divide by zero"):
y_pred = clf.predict(X2)
assert np.any(y_pred != y6)

# adding a little regularization fixes the problem
# Adding a little regularization fixes the division by zero at predict
# time. But UserWarning will persist at fit time.
clf = QuadraticDiscriminantAnalysis(reg_param=0.01)
with ignore_warnings():
with pytest.warns(UserWarning, match=collinear_msg):
clf.fit(X2, y6)
y_pred = clf.predict(X2)
assert_array_equal(y_pred, y6)

# Case n_samples_in_a_class < n_features
# UserWarning should also be there for the n_samples_in_a_class <
# n_features case.
clf = QuadraticDiscriminantAnalysis(reg_param=0.1)
with ignore_warnings():
with pytest.warns(UserWarning, match=collinear_msg):
clf.fit(X5, y5)
y_pred5 = clf.predict(X5)
assert_array_equal(y_pred5, y5)
Expand Down
Morty Proxy This is a proxified and sanitized view of the page, visit original site.