Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Appearance settings

Commit 16023ac

Browse filesBrowse files
committed
ENH: define matvec and vecmat gufuncs
Internally, they mostly just call the relevant matmul, blas, or vecdot routines.
1 parent cf2d77a commit 16023ac
Copy full SHA for 16023ac

File tree

Expand file treeCollapse file tree

11 files changed

+370
-64
lines changed
Filter options
Expand file treeCollapse file tree

11 files changed

+370
-64
lines changed

‎benchmarks/benchmarks/bench_ufunc.py

Copy file name to clipboardExpand all lines: benchmarks/benchmarks/bench_ufunc.py
+4-4Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,12 +16,12 @@
1616
'isinf', 'isnan', 'isnat', 'lcm', 'ldexp', 'left_shift', 'less',
1717
'less_equal', 'log', 'log10', 'log1p', 'log2', 'logaddexp',
1818
'logaddexp2', 'logical_and', 'logical_not', 'logical_or',
19-
'logical_xor', 'matmul', 'maximum', 'minimum', 'mod', 'modf',
20-
'multiply', 'negative', 'nextafter', 'not_equal', 'positive',
19+
'logical_xor', 'matmul', 'matvec', 'maximum', 'minimum', 'mod',
20+
'modf', 'multiply', 'negative', 'nextafter', 'not_equal', 'positive',
2121
'power', 'rad2deg', 'radians', 'reciprocal', 'remainder',
2222
'right_shift', 'rint', 'sign', 'signbit', 'sin',
2323
'sinh', 'spacing', 'sqrt', 'square', 'subtract', 'tan', 'tanh',
24-
'true_divide', 'trunc', 'vecdot']
24+
'true_divide', 'trunc', 'vecdot', 'vecmat']
2525
arrayfuncdisp = ['real', 'round']
2626

2727
for name in ufuncs:
@@ -597,7 +597,7 @@ def setup(self, dtype):
597597
N = 1000000
598598
self.a = np.random.randint(20, size=N).astype(dtype)
599599
self.b = np.random.randint(4, size=N).astype(dtype)
600-
600+
601601
def time_pow(self, dtype):
602602
np.power(self.a, self.b)
603603

+20Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
New functions for matrix-vector and vector-matrix products
2+
----------------------------------------------------------
3+
4+
Two new generalized ``ufunc``s were defined:
5+
6+
* `numpy.matvec` - matrix-vector product, treating the arguments as
7+
stacks of matrices and column vectors, respectively.
8+
9+
* `numpy.vecmat` - vector-matrix product, treating the arguments as
10+
stacks of column vectors and matrices, respectively. For complex
11+
vectors, the conjugate is taken.
12+
13+
These add to the existing `numpy.matmul` as well as to `numpy.vecdot`,
14+
which was added in numpy 2.0.
15+
16+
Note that `numpy.matmul` never takes a complex conjugate, not even
17+
when its left input is a vector, while both `numpy.vecdot` and
18+
`numpy.vecmat` do take the conjugate for complex vectors on the
19+
left-hand side (which are taken to be the ones that are transposed,
20+
following the physics convention).

‎doc/source/reference/routines.linalg.rst

Copy file name to clipboardExpand all lines: doc/source/reference/routines.linalg.rst
+2Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,8 @@ Matrix and vector products
6262
outer
6363
matmul
6464
linalg.matmul (Array API compatible location)
65+
matvec
66+
vecmat
6567
tensordot
6668
linalg.tensordot (Array API compatible location)
6769
einsum

‎numpy/__init__.py

Copy file name to clipboardExpand all lines: numpy/__init__.py
+18-18Lines changed: 18 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -151,10 +151,10 @@
151151
left_shift, less, less_equal, lexsort, linspace, little_endian, log,
152152
log10, log1p, log2, logaddexp, logaddexp2, logical_and, logical_not,
153153
logical_or, logical_xor, logspace, long, longdouble, longlong, matmul,
154-
matrix_transpose, max, maximum, may_share_memory, mean, memmap, min,
155-
min_scalar_type, minimum, mod, modf, moveaxis, multiply, nan, ndarray,
156-
ndim, nditer, negative, nested_iters, newaxis, nextafter, nonzero,
157-
not_equal, number, object_, ones, ones_like, outer, partition,
154+
matvec, matrix_transpose, max, maximum, may_share_memory, mean, memmap,
155+
min, min_scalar_type, minimum, mod, modf, moveaxis, multiply, nan,
156+
ndarray, ndim, nditer, negative, nested_iters, newaxis, nextafter,
157+
nonzero, not_equal, number, object_, ones, ones_like, outer, partition,
158158
permute_dims, pi, positive, pow, power, printoptions, prod,
159159
promote_types, ptp, put, putmask, rad2deg, radians, ravel, recarray,
160160
reciprocal, record, remainder, repeat, require, reshape, resize,
@@ -165,11 +165,11 @@
165165
str_, subtract, sum, swapaxes, take, tan, tanh, tensordot,
166166
timedelta64, trace, transpose, true_divide, trunc, typecodes, ubyte,
167167
ufunc, uint, uint16, uint32, uint64, uint8, uintc, uintp, ulong,
168-
ulonglong, unsignedinteger, unstack, ushort, var, vdot, vecdot, void,
169-
vstack, where, zeros, zeros_like
168+
ulonglong, unsignedinteger, unstack, ushort, var, vdot, vecdot,
169+
vecmat, void, vstack, where, zeros, zeros_like
170170
)
171171

172-
# NOTE: It's still under discussion whether these aliases
172+
# NOTE: It's still under discussion whether these aliases
173173
# should be removed.
174174
for ta in ["float96", "float128", "complex192", "complex256"]:
175175
try:
@@ -184,21 +184,21 @@
184184
histogram, histogram_bin_edges, histogramdd
185185
)
186186
from .lib._nanfunctions_impl import (
187-
nanargmax, nanargmin, nancumprod, nancumsum, nanmax, nanmean,
187+
nanargmax, nanargmin, nancumprod, nancumsum, nanmax, nanmean,
188188
nanmedian, nanmin, nanpercentile, nanprod, nanquantile, nanstd,
189189
nansum, nanvar
190190
)
191191
from .lib._function_base_impl import (
192-
select, piecewise, trim_zeros, copy, iterable, percentile, diff,
192+
select, piecewise, trim_zeros, copy, iterable, percentile, diff,
193193
gradient, angle, unwrap, sort_complex, flip, rot90, extract, place,
194194
vectorize, asarray_chkfinite, average, bincount, digitize, cov,
195195
corrcoef, median, sinc, hamming, hanning, bartlett, blackman,
196196
kaiser, trapezoid, trapz, i0, meshgrid, delete, insert, append,
197197
interp, quantile
198198
)
199199
from .lib._twodim_base_impl import (
200-
diag, diagflat, eye, fliplr, flipud, tri, triu, tril, vander,
201-
histogram2d, mask_indices, tril_indices, tril_indices_from,
200+
diag, diagflat, eye, fliplr, flipud, tri, triu, tril, vander,
201+
histogram2d, mask_indices, tril_indices, tril_indices_from,
202202
triu_indices, triu_indices_from
203203
)
204204
from .lib._shape_base_impl import (
@@ -207,7 +207,7 @@
207207
take_along_axis, tile, vsplit
208208
)
209209
from .lib._type_check_impl import (
210-
iscomplexobj, isrealobj, imag, iscomplex, isreal, nan_to_num, real,
210+
iscomplexobj, isrealobj, imag, iscomplex, isreal, nan_to_num, real,
211211
real_if_close, typename, mintypecode, common_type
212212
)
213213
from .lib._arraysetops_impl import (
@@ -232,7 +232,7 @@
232232
)
233233
from .lib._index_tricks_impl import (
234234
diag_indices_from, diag_indices, fill_diagonal, ndindex, ndenumerate,
235-
ix_, c_, r_, s_, ogrid, mgrid, unravel_index, ravel_multi_index,
235+
ix_, c_, r_, s_, ogrid, mgrid, unravel_index, ravel_multi_index,
236236
index_exp
237237
)
238238

@@ -246,7 +246,7 @@
246246
# (experimental label) are not added here, because `from numpy import *`
247247
# must not raise any warnings - that's too disruptive.
248248
__numpy_submodules__ = {
249-
"linalg", "fft", "dtypes", "random", "polynomial", "ma",
249+
"linalg", "fft", "dtypes", "random", "polynomial", "ma",
250250
"exceptions", "lib", "ctypeslib", "testing", "typing",
251251
"f2py", "test", "rec", "char", "core", "strings",
252252
}
@@ -395,7 +395,7 @@ def __getattr__(attr):
395395

396396
if attr in __former_attrs__:
397397
raise AttributeError(__former_attrs__[attr], name=None)
398-
398+
399399
if attr in __expired_attributes__:
400400
raise AttributeError(
401401
f"`np.{attr}` was removed in the NumPy 2.0 release. "
@@ -419,7 +419,7 @@ def __dir__():
419419
globals().keys() | __numpy_submodules__
420420
)
421421
public_symbols -= {
422-
"matrixlib", "matlib", "tests", "conftest", "version",
422+
"matrixlib", "matlib", "tests", "conftest", "version",
423423
"compat", "distutils", "array_api"
424424
}
425425
return list(public_symbols)
@@ -493,7 +493,7 @@ def _mac_os_check():
493493
def hugepage_setup():
494494
"""
495495
We usually use madvise hugepages support, but on some old kernels it
496-
is slow and thus better avoided. Specifically kernel version 4.6
496+
is slow and thus better avoided. Specifically kernel version 4.6
497497
had a bug fix which probably fixed this:
498498
https://github.com/torvalds/linux/commit/7cf91a98e607c2f935dbcc177d70011e95b8faff
499499
"""
@@ -502,7 +502,7 @@ def hugepage_setup():
502502
# If there is an issue with parsing the kernel version,
503503
# set use_hugepage to 0. Usage of LooseVersion will handle
504504
# the kernel version parsing better, but avoided since it
505-
# will increase the import time.
505+
# will increase the import time.
506506
# See: #16679 for related discussion.
507507
try:
508508
use_hugepage = 1

‎numpy/__init__.pyi

Copy file name to clipboardExpand all lines: numpy/__init__.pyi
+2Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3461,6 +3461,7 @@ logical_not: _UFunc_Nin1_Nout1[L['logical_not'], L[20], None]
34613461
logical_or: _UFunc_Nin2_Nout1[L['logical_or'], L[20], L[False]]
34623462
logical_xor: _UFunc_Nin2_Nout1[L['logical_xor'], L[19], L[False]]
34633463
matmul: _GUFunc_Nin2_Nout1[L['matmul'], L[19], None, L["(n?,k),(k,m?)->(n?,m?)"]]
3464+
matvec: _GUFunc_Nin2_Nout1[L['matvec'], L[19], None, L["(m,n),(n)->(m)"]]
34643465
maximum: _UFunc_Nin2_Nout1[L['maximum'], L[21], None]
34653466
minimum: _UFunc_Nin2_Nout1[L['minimum'], L[21], None]
34663467
mod: _UFunc_Nin2_Nout1[L['remainder'], L[16], None]
@@ -3490,6 +3491,7 @@ tanh: _UFunc_Nin1_Nout1[L['tanh'], L[8], None]
34903491
true_divide: _UFunc_Nin2_Nout1[L['true_divide'], L[11], None]
34913492
trunc: _UFunc_Nin1_Nout1[L['trunc'], L[7], None]
34923493
vecdot: _GUFunc_Nin2_Nout1[L['vecdot'], L[19], None, L["(n),(n)->()"]]
3494+
vecmat: _GUFunc_Nin2_Nout1[L['vecmat'], L[19], None, L["(n),(n,m)->(m)"]]
34933495

34943496
abs = absolute
34953497
acos = arccos

‎numpy/_core/code_generators/generate_umath.py

Copy file name to clipboardExpand all lines: numpy/_core/code_generators/generate_umath.py
+16Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1154,6 +1154,22 @@ def english_upper(s):
11541154
TD(O),
11551155
signature='(n),(n)->()',
11561156
),
1157+
'matvec':
1158+
Ufunc(2, 1, None,
1159+
docstrings.get('numpy._core.umath.matvec'),
1160+
"PyUFunc_SimpleUniformOperationTypeResolver",
1161+
TD(notimes_or_obj),
1162+
TD(O),
1163+
signature='(m,n),(n)->(m)',
1164+
),
1165+
'vecmat':
1166+
Ufunc(2, 1, None,
1167+
docstrings.get('numpy._core.umath.vecmat'),
1168+
"PyUFunc_SimpleUniformOperationTypeResolver",
1169+
TD(notimes_or_obj),
1170+
TD(O),
1171+
signature='(n),(n,m)->(m)',
1172+
),
11571173
'str_len':
11581174
Ufunc(1, 1, Zero,
11591175
docstrings.get('numpy._core.umath.str_len'),

0 commit comments

Comments
0 (0)
Morty Proxy This is a proxified and sanitized view of the page, visit original site.