Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Appearance settings

Commit f61dd6c

Browse filesBrowse files
authored
MAINT Clean-up utils.__init__: move tests into corresponding test files (scikit-learn#28842)
1 parent 0bdc754 commit f61dd6c
Copy full SHA for f61dd6c

File tree

Expand file treeCollapse file tree

4 files changed

+102
-147
lines changed
Filter options
Expand file treeCollapse file tree

4 files changed

+102
-147
lines changed

‎sklearn/utils/tests/test_mask.py

Copy file name to clipboard
+19Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
import pytest
2+
3+
from sklearn.utils._mask import safe_mask
4+
from sklearn.utils.fixes import CSR_CONTAINERS
5+
from sklearn.utils.validation import check_random_state
6+
7+
8+
@pytest.mark.parametrize("csr_container", CSR_CONTAINERS)
9+
def test_safe_mask(csr_container):
10+
random_state = check_random_state(0)
11+
X = random_state.rand(5, 4)
12+
X_csr = csr_container(X)
13+
mask = [False, False, True, True, True]
14+
15+
mask = safe_mask(X, mask)
16+
assert X[mask].shape[0] == 3
17+
18+
mask = safe_mask(X_csr, mask)
19+
assert X_csr[mask].shape[0] == 3

‎sklearn/utils/tests/test_missing.py

Copy file name to clipboard
+27Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
import numpy as np
2+
import pytest
3+
4+
from sklearn.utils._missing import is_scalar_nan
5+
6+
7+
@pytest.mark.parametrize(
8+
"value, result",
9+
[
10+
(float("nan"), True),
11+
(np.nan, True),
12+
(float(np.nan), True),
13+
(np.float32(np.nan), True),
14+
(np.float64(np.nan), True),
15+
(0, False),
16+
(0.0, False),
17+
(None, False),
18+
("", False),
19+
("nan", False),
20+
([np.nan], False),
21+
(9867966753463435747313673, False), # Python int that overflows with C type
22+
],
23+
)
24+
def test_is_scalar_nan(value, result):
25+
assert is_scalar_nan(value) is result
26+
# make sure that we are returning a Python bool
27+
assert isinstance(is_scalar_nan(value), bool)

‎sklearn/utils/tests/test_utils.py

Copy file name to clipboardExpand all lines: sklearn/utils/tests/test_utils.py
+1-147Lines changed: 1 addition & 147 deletions
Original file line numberDiff line numberDiff line change
@@ -1,153 +1,7 @@
1-
import warnings
2-
31
import joblib
4-
import numpy as np
52
import pytest
63

7-
from sklearn.utils import (
8-
check_random_state,
9-
column_or_1d,
10-
deprecated,
11-
parallel_backend,
12-
register_parallel_backend,
13-
safe_mask,
14-
tosequence,
15-
)
16-
from sklearn.utils._missing import is_scalar_nan
17-
from sklearn.utils._testing import assert_array_equal
18-
from sklearn.utils.fixes import CSR_CONTAINERS
19-
from sklearn.utils.validation import _is_polars_df
20-
21-
22-
def test_make_rng():
23-
# Check the check_random_state utility function behavior
24-
assert check_random_state(None) is np.random.mtrand._rand
25-
assert check_random_state(np.random) is np.random.mtrand._rand
26-
27-
rng_42 = np.random.RandomState(42)
28-
assert check_random_state(42).randint(100) == rng_42.randint(100)
29-
30-
rng_42 = np.random.RandomState(42)
31-
assert check_random_state(rng_42) is rng_42
32-
33-
rng_42 = np.random.RandomState(42)
34-
assert check_random_state(43).randint(100) != rng_42.randint(100)
35-
36-
with pytest.raises(ValueError):
37-
check_random_state("some invalid seed")
38-
39-
40-
def test_deprecated():
41-
# Test whether the deprecated decorator issues appropriate warnings
42-
# Copied almost verbatim from https://docs.python.org/library/warnings.html
43-
44-
# First a function...
45-
with warnings.catch_warnings(record=True) as w:
46-
warnings.simplefilter("always")
47-
48-
@deprecated()
49-
def ham():
50-
return "spam"
51-
52-
spam = ham()
53-
54-
assert spam == "spam" # function must remain usable
55-
56-
assert len(w) == 1
57-
assert issubclass(w[0].category, FutureWarning)
58-
assert "deprecated" in str(w[0].message).lower()
59-
60-
# ... then a class.
61-
with warnings.catch_warnings(record=True) as w:
62-
warnings.simplefilter("always")
63-
64-
@deprecated("don't use this")
65-
class Ham:
66-
SPAM = 1
67-
68-
ham = Ham()
69-
70-
assert hasattr(ham, "SPAM")
71-
72-
assert len(w) == 1
73-
assert issubclass(w[0].category, FutureWarning)
74-
assert "deprecated" in str(w[0].message).lower()
75-
76-
77-
@pytest.mark.parametrize("csr_container", CSR_CONTAINERS)
78-
def test_safe_mask(csr_container):
79-
random_state = check_random_state(0)
80-
X = random_state.rand(5, 4)
81-
X_csr = csr_container(X)
82-
mask = [False, False, True, True, True]
83-
84-
mask = safe_mask(X, mask)
85-
assert X[mask].shape[0] == 3
86-
87-
mask = safe_mask(X_csr, mask)
88-
assert X_csr[mask].shape[0] == 3
89-
90-
91-
def test_column_or_1d():
92-
EXAMPLES = [
93-
("binary", ["spam", "egg", "spam"]),
94-
("binary", [0, 1, 0, 1]),
95-
("continuous", np.arange(10) / 20.0),
96-
("multiclass", [1, 2, 3]),
97-
("multiclass", [0, 1, 2, 2, 0]),
98-
("multiclass", [[1], [2], [3]]),
99-
("multilabel-indicator", [[0, 1, 0], [0, 0, 1]]),
100-
("multiclass-multioutput", [[1, 2, 3]]),
101-
("multiclass-multioutput", [[1, 1], [2, 2], [3, 1]]),
102-
("multiclass-multioutput", [[5, 1], [4, 2], [3, 1]]),
103-
("multiclass-multioutput", [[1, 2, 3]]),
104-
("continuous-multioutput", np.arange(30).reshape((-1, 3))),
105-
]
106-
107-
for y_type, y in EXAMPLES:
108-
if y_type in ["binary", "multiclass", "continuous"]:
109-
assert_array_equal(column_or_1d(y), np.ravel(y))
110-
else:
111-
with pytest.raises(ValueError):
112-
column_or_1d(y)
113-
114-
115-
@pytest.mark.parametrize(
116-
"value, result",
117-
[
118-
(float("nan"), True),
119-
(np.nan, True),
120-
(float(np.nan), True),
121-
(np.float32(np.nan), True),
122-
(np.float64(np.nan), True),
123-
(0, False),
124-
(0.0, False),
125-
(None, False),
126-
("", False),
127-
("nan", False),
128-
([np.nan], False),
129-
(9867966753463435747313673, False), # Python int that overflows with C type
130-
],
131-
)
132-
def test_is_scalar_nan(value, result):
133-
assert is_scalar_nan(value) is result
134-
# make sure that we are returning a Python bool
135-
assert isinstance(is_scalar_nan(value), bool)
136-
137-
138-
def dummy_func():
139-
pass
140-
141-
142-
def test__is_polars_df():
143-
"""Check that _is_polars_df return False for non-dataframe objects."""
144-
145-
class LooksLikePolars:
146-
def __init__(self):
147-
self.columns = ["a", "b"]
148-
self.schema = ["a", "b"]
149-
150-
assert not _is_polars_df(LooksLikePolars())
4+
from sklearn.utils import parallel_backend, register_parallel_backend, tosequence
1515

1526

1537
# TODO(1.7): remove

‎sklearn/utils/tests/test_validation.py

Copy file name to clipboardExpand all lines: sklearn/utils/tests/test_validation.py
+55Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,11 +80,31 @@
8080
check_is_fitted,
8181
check_memory,
8282
check_non_negative,
83+
check_random_state,
8384
check_scalar,
85+
column_or_1d,
8486
has_fit_parameter,
8587
)
8688

8789

90+
def test_make_rng():
91+
# Check the check_random_state utility function behavior
92+
assert check_random_state(None) is np.random.mtrand._rand
93+
assert check_random_state(np.random) is np.random.mtrand._rand
94+
95+
rng_42 = np.random.RandomState(42)
96+
assert check_random_state(42).randint(100) == rng_42.randint(100)
97+
98+
rng_42 = np.random.RandomState(42)
99+
assert check_random_state(rng_42) is rng_42
100+
101+
rng_42 = np.random.RandomState(42)
102+
assert check_random_state(43).randint(100) != rng_42.randint(100)
103+
104+
with pytest.raises(ValueError):
105+
check_random_state("some invalid seed")
106+
107+
88108
def test_as_float_array():
89109
# Test function for as_float_array
90110
X = np.ones((3, 10), dtype=np.int32)
@@ -2061,3 +2081,38 @@ def test_to_object_array(sequence):
20612081
assert isinstance(out, np.ndarray)
20622082
assert out.dtype.kind == "O"
20632083
assert out.ndim == 1
2084+
2085+
2086+
def test_column_or_1d():
2087+
EXAMPLES = [
2088+
("binary", ["spam", "egg", "spam"]),
2089+
("binary", [0, 1, 0, 1]),
2090+
("continuous", np.arange(10) / 20.0),
2091+
("multiclass", [1, 2, 3]),
2092+
("multiclass", [0, 1, 2, 2, 0]),
2093+
("multiclass", [[1], [2], [3]]),
2094+
("multilabel-indicator", [[0, 1, 0], [0, 0, 1]]),
2095+
("multiclass-multioutput", [[1, 2, 3]]),
2096+
("multiclass-multioutput", [[1, 1], [2, 2], [3, 1]]),
2097+
("multiclass-multioutput", [[5, 1], [4, 2], [3, 1]]),
2098+
("multiclass-multioutput", [[1, 2, 3]]),
2099+
("continuous-multioutput", np.arange(30).reshape((-1, 3))),
2100+
]
2101+
2102+
for y_type, y in EXAMPLES:
2103+
if y_type in ["binary", "multiclass", "continuous"]:
2104+
assert_array_equal(column_or_1d(y), np.ravel(y))
2105+
else:
2106+
with pytest.raises(ValueError):
2107+
column_or_1d(y)
2108+
2109+
2110+
def test__is_polars_df():
2111+
"""Check that _is_polars_df return False for non-dataframe objects."""
2112+
2113+
class LooksLikePolars:
2114+
def __init__(self):
2115+
self.columns = ["a", "b"]
2116+
self.schema = ["a", "b"]
2117+
2118+
assert not _is_polars_df(LooksLikePolars())

0 commit comments

Comments
0 (0)
Morty Proxy This is a proxified and sanitized view of the page, visit original site.