Skip to content

Navigation Menu

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Appearance settings

Commit 2d66fd7

Browse filesBrowse files
authored
Fix BLAS_Order.RowMajor import and similar in test_cython_blas with Cython 3.1 (scikit-learn#31301)
1 parent f0c80e8 commit 2d66fd7
Copy full SHA for 2d66fd7

File tree

1 file changed

+27
-11
lines changed
Filter options

1 file changed

+27
-11
lines changed

‎sklearn/utils/tests/test_cython_blas.py

Copy file name to clipboardExpand all lines: sklearn/utils/tests/test_cython_blas.py
+27-11Lines changed: 27 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,8 @@
22
import pytest
33

44
from sklearn.utils._cython_blas import (
5-
ColMajor,
6-
NoTrans,
7-
RowMajor,
8-
Trans,
5+
BLAS_Order,
6+
BLAS_Trans,
97
_asum_memview,
108
_axpy_memview,
119
_copy_memview,
@@ -30,7 +28,7 @@ def _numpy_to_cython(dtype):
3028

3129

3230
RTOL = {np.float32: 1e-6, np.float64: 1e-12}
33-
ORDER = {RowMajor: "C", ColMajor: "F"}
31+
ORDER = {BLAS_Order.RowMajor: "C", BLAS_Order.ColMajor: "F"}
3432

3533

3634
def _no_op(x):
@@ -166,9 +164,15 @@ def test_rot(dtype):
166164

167165
@pytest.mark.parametrize("dtype", [np.float32, np.float64])
168166
@pytest.mark.parametrize(
169-
"opA, transA", [(_no_op, NoTrans), (np.transpose, Trans)], ids=["NoTrans", "Trans"]
167+
"opA, transA",
168+
[(_no_op, BLAS_Trans.NoTrans), (np.transpose, BLAS_Trans.Trans)],
169+
ids=["NoTrans", "Trans"],
170+
)
171+
@pytest.mark.parametrize(
172+
"order",
173+
[BLAS_Order.RowMajor, BLAS_Order.ColMajor],
174+
ids=["RowMajor", "ColMajor"],
170175
)
171-
@pytest.mark.parametrize("order", [RowMajor, ColMajor], ids=["RowMajor", "ColMajor"])
172176
def test_gemv(dtype, opA, transA, order):
173177
gemv = _gemv_memview[_numpy_to_cython(dtype)]
174178

@@ -187,7 +191,11 @@ def test_gemv(dtype, opA, transA, order):
187191

188192

189193
@pytest.mark.parametrize("dtype", [np.float32, np.float64])
190-
@pytest.mark.parametrize("order", [RowMajor, ColMajor], ids=["RowMajor", "ColMajor"])
194+
@pytest.mark.parametrize(
195+
"order",
196+
[BLAS_Order.RowMajor, BLAS_Order.ColMajor],
197+
ids=["BLAS_Order.RowMajor", "BLAS_Order.ColMajor"],
198+
)
191199
def test_ger(dtype, order):
192200
ger = _ger_memview[_numpy_to_cython(dtype)]
193201

@@ -207,12 +215,20 @@ def test_ger(dtype, order):
207215

208216
@pytest.mark.parametrize("dtype", [np.float32, np.float64])
209217
@pytest.mark.parametrize(
210-
"opB, transB", [(_no_op, NoTrans), (np.transpose, Trans)], ids=["NoTrans", "Trans"]
218+
"opB, transB",
219+
[(_no_op, BLAS_Trans.NoTrans), (np.transpose, BLAS_Trans.Trans)],
220+
ids=["NoTrans", "Trans"],
221+
)
222+
@pytest.mark.parametrize(
223+
"opA, transA",
224+
[(_no_op, BLAS_Trans.NoTrans), (np.transpose, BLAS_Trans.Trans)],
225+
ids=["NoTrans", "Trans"],
211226
)
212227
@pytest.mark.parametrize(
213-
"opA, transA", [(_no_op, NoTrans), (np.transpose, Trans)], ids=["NoTrans", "Trans"]
228+
"order",
229+
[BLAS_Order.RowMajor, BLAS_Order.ColMajor],
230+
ids=["BLAS_Order.RowMajor", "BLAS_Order.ColMajor"],
214231
)
215-
@pytest.mark.parametrize("order", [RowMajor, ColMajor], ids=["RowMajor", "ColMajor"])
216232
def test_gemm(dtype, opA, transA, opB, transB, order):
217233
gemm = _gemm_memview[_numpy_to_cython(dtype)]
218234

0 commit comments

Comments
0 (0)
Morty Proxy This is a proxified and sanitized view of the page, visit original site.