2
2
import pytest
3
3
4
4
from sklearn .utils ._cython_blas import (
5
- ColMajor ,
6
- NoTrans ,
7
- RowMajor ,
8
- Trans ,
5
+ BLAS_Order ,
6
+ BLAS_Trans ,
9
7
_asum_memview ,
10
8
_axpy_memview ,
11
9
_copy_memview ,
@@ -30,7 +28,7 @@ def _numpy_to_cython(dtype):
30
28
31
29
32
30
RTOL = {np .float32 : 1e-6 , np .float64 : 1e-12 }
33
- ORDER = {RowMajor : "C" , ColMajor : "F" }
31
+ ORDER = {BLAS_Order . RowMajor : "C" , BLAS_Order . ColMajor : "F" }
34
32
35
33
36
34
def _no_op (x ):
@@ -166,9 +164,15 @@ def test_rot(dtype):
166
164
167
165
@pytest .mark .parametrize ("dtype" , [np .float32 , np .float64 ])
168
166
@pytest .mark .parametrize (
169
- "opA, transA" , [(_no_op , NoTrans ), (np .transpose , Trans )], ids = ["NoTrans" , "Trans" ]
167
+ "opA, transA" ,
168
+ [(_no_op , BLAS_Trans .NoTrans ), (np .transpose , BLAS_Trans .Trans )],
169
+ ids = ["NoTrans" , "Trans" ],
170
+ )
171
+ @pytest .mark .parametrize (
172
+ "order" ,
173
+ [BLAS_Order .RowMajor , BLAS_Order .ColMajor ],
174
+ ids = ["RowMajor" , "ColMajor" ],
170
175
)
171
- @pytest .mark .parametrize ("order" , [RowMajor , ColMajor ], ids = ["RowMajor" , "ColMajor" ])
172
176
def test_gemv (dtype , opA , transA , order ):
173
177
gemv = _gemv_memview [_numpy_to_cython (dtype )]
174
178
@@ -187,7 +191,11 @@ def test_gemv(dtype, opA, transA, order):
187
191
188
192
189
193
@pytest .mark .parametrize ("dtype" , [np .float32 , np .float64 ])
190
- @pytest .mark .parametrize ("order" , [RowMajor , ColMajor ], ids = ["RowMajor" , "ColMajor" ])
194
+ @pytest .mark .parametrize (
195
+ "order" ,
196
+ [BLAS_Order .RowMajor , BLAS_Order .ColMajor ],
197
+ ids = ["BLAS_Order.RowMajor" , "BLAS_Order.ColMajor" ],
198
+ )
191
199
def test_ger (dtype , order ):
192
200
ger = _ger_memview [_numpy_to_cython (dtype )]
193
201
@@ -207,12 +215,20 @@ def test_ger(dtype, order):
207
215
208
216
@pytest .mark .parametrize ("dtype" , [np .float32 , np .float64 ])
209
217
@pytest .mark .parametrize (
210
- "opB, transB" , [(_no_op , NoTrans ), (np .transpose , Trans )], ids = ["NoTrans" , "Trans" ]
218
+ "opB, transB" ,
219
+ [(_no_op , BLAS_Trans .NoTrans ), (np .transpose , BLAS_Trans .Trans )],
220
+ ids = ["NoTrans" , "Trans" ],
221
+ )
222
+ @pytest .mark .parametrize (
223
+ "opA, transA" ,
224
+ [(_no_op , BLAS_Trans .NoTrans ), (np .transpose , BLAS_Trans .Trans )],
225
+ ids = ["NoTrans" , "Trans" ],
211
226
)
212
227
@pytest .mark .parametrize (
213
- "opA, transA" , [(_no_op , NoTrans ), (np .transpose , Trans )], ids = ["NoTrans" , "Trans" ]
228
+ "order" ,
229
+ [BLAS_Order .RowMajor , BLAS_Order .ColMajor ],
230
+ ids = ["BLAS_Order.RowMajor" , "BLAS_Order.ColMajor" ],
214
231
)
215
- @pytest .mark .parametrize ("order" , [RowMajor , ColMajor ], ids = ["RowMajor" , "ColMajor" ])
216
232
def test_gemm (dtype , opA , transA , opB , transB , order ):
217
233
gemm = _gemm_memview [_numpy_to_cython (dtype )]
218
234
0 commit comments