Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Appearance settings

Commit f2773e8

Browse filesBrowse files
azihnaAlihan Zihna
and
Alihan Zihna
authored
TST replace assert_raise_* by pytest.raises in tests/test_multioutput.py (#19618)
Co-authored-by: Alihan Zihna <a.zihna@ckhgbdp.onmicrosoft.com>
1 parent 42e90e9 commit f2773e8
Copy full SHA for f2773e8

File tree

Expand file treeCollapse file tree

1 file changed

+23
-13
lines changed
Filter options
Expand file treeCollapse file tree

1 file changed

+23
-13
lines changed

‎sklearn/tests/test_multioutput.py

Copy file name to clipboardExpand all lines: sklearn/tests/test_multioutput.py
+23-13Lines changed: 23 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,6 @@
55
from joblib import cpu_count
66

77
from sklearn.utils._testing import assert_almost_equal
8-
from sklearn.utils._testing import assert_raises
9-
from sklearn.utils._testing import assert_raises_regex
10-
from sklearn.utils._testing import assert_raise_message
118
from sklearn.utils._testing import assert_array_equal
129
from sklearn.utils._testing import assert_array_almost_equal
1310
from sklearn import datasets
@@ -80,7 +77,9 @@ def test_multi_target_regression_one_target():
8077
# Test multi target regression raises
8178
X, y = datasets.make_regression(n_targets=1)
8279
rgr = MultiOutputRegressor(GradientBoostingRegressor(random_state=0))
83-
assert_raises(ValueError, rgr.fit, X, y)
80+
msg = 'at least two dimensions'
81+
with pytest.raises(ValueError, match=msg):
82+
rgr.fit(X, y)
8483

8584

8685
def test_multi_target_sparse_regression():
@@ -106,8 +105,9 @@ def test_multi_target_sample_weights_api():
106105
w = [0.8, 0.6]
107106

108107
rgr = MultiOutputRegressor(OrthogonalMatchingPursuit())
109-
assert_raises_regex(ValueError, "does not support sample weights",
110-
rgr.fit, X, y, w)
108+
msg = "does not support sample weights"
109+
with pytest.raises(ValueError, match=msg):
110+
rgr.fit(X, y, w)
111111

112112
# no exception should be raised if the base estimator supports weights
113113
rgr = MultiOutputRegressor(GradientBoostingRegressor(random_state=0))
@@ -252,9 +252,9 @@ def test_multi_output_classification_partial_fit():
252252
def test_multi_output_classification_partial_fit_no_first_classes_exception():
253253
sgd_linear_clf = SGDClassifier(loss='log', random_state=1, max_iter=5)
254254
multi_target_linear = MultiOutputClassifier(sgd_linear_clf)
255-
assert_raises_regex(ValueError, "classes must be passed on the first call "
256-
"to partial_fit.",
257-
multi_target_linear.partial_fit, X, y)
255+
msg = "classes must be passed on the first call to partial_fit."
256+
with pytest.raises(ValueError, match=msg):
257+
multi_target_linear.partial_fit(X, y)
258258

259259

260260
def test_multi_output_classification():
@@ -386,17 +386,27 @@ def test_multi_output_exceptions():
386386
# NotFittedError when fit is not done but score, predict and
387387
# and predict_proba are called
388388
moc = MultiOutputClassifier(LinearSVC(random_state=0))
389-
assert_raises(NotFittedError, moc.predict, y)
389+
390+
with pytest.raises(NotFittedError):
391+
moc.predict(y)
392+
390393
with pytest.raises(NotFittedError):
391394
moc.predict_proba
392-
assert_raises(NotFittedError, moc.score, X, y)
395+
396+
with pytest.raises(NotFittedError):
397+
moc.score(X, y)
398+
393399
# ValueError when number of outputs is different
394400
# for fit and score
395401
y_new = np.column_stack((y1, y2))
396402
moc.fit(X, y)
397-
assert_raises(ValueError, moc.score, X, y_new)
403+
with pytest.raises(ValueError):
404+
moc.score(X, y_new)
405+
398406
# ValueError when y is continuous
399-
assert_raise_message(ValueError, "Unknown label type", moc.fit, X, X[:, 1])
407+
msg = "Unknown label type"
408+
with pytest.raises(ValueError, match=msg):
409+
moc.fit(X, X[:, 1])
400410

401411

402412
def generate_multilabel_dataset_with_correlations():

0 commit comments

Comments
0 (0)
Morty Proxy This is a proxified and sanitized view of the page, visit original site.