Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Appearance settings

Commit 9312d5e

Browse filesBrowse files
authored
Merge pull request #27164 from seberg/einsum-argparse
MAINT: use npy_argparse for einsum
2 parents 32a2304 + bbf0ff4 commit 9312d5e
Copy full SHA for 9312d5e

File tree

1 file changed

+51
-96
lines changed
Filter options

1 file changed

+51
-96
lines changed

‎numpy/_core/src/multiarray/multiarraymodule.c

Copy file name to clipboardExpand all lines: numpy/_core/src/multiarray/multiarraymodule.c
+51-96Lines changed: 51 additions & 96 deletions
Original file line numberDiff line numberDiff line change
@@ -2704,13 +2704,13 @@ array_vdot(PyObject *NPY_UNUSED(dummy), PyObject *const *args, Py_ssize_t len_ar
27042704
}
27052705

27062706
static int
2707-
einsum_sub_op_from_str(PyObject *args, PyObject **str_obj, char **subscripts,
2708-
PyArrayObject **op)
2707+
einsum_sub_op_from_str(
2708+
Py_ssize_t nargs, PyObject *const *args,
2709+
PyObject **str_obj, char **subscripts, PyArrayObject **op)
27092710
{
2710-
int i, nop;
2711+
Py_ssize_t nop = nargs - 1;
27112712
PyObject *subscripts_str;
27122713

2713-
nop = PyTuple_GET_SIZE(args) - 1;
27142714
if (nop <= 0) {
27152715
PyErr_SetString(PyExc_ValueError,
27162716
"must specify the einstein sum subscripts string "
@@ -2723,7 +2723,7 @@ einsum_sub_op_from_str(PyObject *args, PyObject **str_obj, char **subscripts,
27232723
}
27242724

27252725
/* Get the subscripts string */
2726-
subscripts_str = PyTuple_GET_ITEM(args, 0);
2726+
subscripts_str = args[0];
27272727
if (PyUnicode_Check(subscripts_str)) {
27282728
*str_obj = PyUnicode_AsASCIIString(subscripts_str);
27292729
if (*str_obj == NULL) {
@@ -2740,15 +2740,13 @@ einsum_sub_op_from_str(PyObject *args, PyObject **str_obj, char **subscripts,
27402740
}
27412741

27422742
/* Set the operands to NULL */
2743-
for (i = 0; i < nop; ++i) {
2743+
for (Py_ssize_t i = 0; i < nop; ++i) {
27442744
op[i] = NULL;
27452745
}
27462746

27472747
/* Get the operands */
2748-
for (i = 0; i < nop; ++i) {
2749-
PyObject *obj = PyTuple_GET_ITEM(args, i+1);
2750-
2751-
op[i] = (PyArrayObject *)PyArray_FROM_OF(obj, NPY_ARRAY_ENSUREARRAY);
2748+
for (Py_ssize_t i = 0; i < nop; ++i) {
2749+
op[i] = (PyArrayObject *)PyArray_FROM_OF(args[i+1], NPY_ARRAY_ENSUREARRAY);
27522750
if (op[i] == NULL) {
27532751
goto fail;
27542752
}
@@ -2757,7 +2755,7 @@ einsum_sub_op_from_str(PyObject *args, PyObject **str_obj, char **subscripts,
27572755
return nop;
27582756

27592757
fail:
2760-
for (i = 0; i < nop; ++i) {
2758+
for (Py_ssize_t i = 0; i < nop; ++i) {
27612759
Py_XDECREF(op[i]);
27622760
op[i] = NULL;
27632761
}
@@ -2861,13 +2859,12 @@ einsum_list_to_subscripts(PyObject *obj, char *subscripts, int subsize)
28612859
* Returns -1 on error, number of operands placed in op otherwise.
28622860
*/
28632861
static int
2864-
einsum_sub_op_from_lists(PyObject *args,
2865-
char *subscripts, int subsize, PyArrayObject **op)
2862+
einsum_sub_op_from_lists(Py_ssize_t nargs, PyObject *const *args,
2863+
char *subscripts, int subsize, PyArrayObject **op)
28662864
{
28672865
int subindex = 0;
2868-
npy_intp i, nop;
28692866

2870-
nop = PyTuple_Size(args)/2;
2867+
Py_ssize_t nop = nargs / 2;
28712868

28722869
if (nop == 0) {
28732870
PyErr_SetString(PyExc_ValueError, "must provide at least an "
@@ -2880,15 +2877,12 @@ einsum_sub_op_from_lists(PyObject *args,
28802877
}
28812878

28822879
/* Set the operands to NULL */
2883-
for (i = 0; i < nop; ++i) {
2880+
for (Py_ssize_t i = 0; i < nop; ++i) {
28842881
op[i] = NULL;
28852882
}
28862883

28872884
/* Get the operands and build the subscript string */
2888-
for (i = 0; i < nop; ++i) {
2889-
PyObject *obj = PyTuple_GET_ITEM(args, 2*i);
2890-
int n;
2891-
2885+
for (Py_ssize_t i = 0; i < nop; ++i) {
28922886
/* Comma between the subscripts for each operand */
28932887
if (i != 0) {
28942888
subscripts[subindex++] = ',';
@@ -2899,25 +2893,21 @@ einsum_sub_op_from_lists(PyObject *args,
28992893
}
29002894
}
29012895

2902-
op[i] = (PyArrayObject *)PyArray_FROM_OF(obj, NPY_ARRAY_ENSUREARRAY);
2896+
op[i] = (PyArrayObject *)PyArray_FROM_OF(args[2*i], NPY_ARRAY_ENSUREARRAY);
29032897
if (op[i] == NULL) {
29042898
goto fail;
29052899
}
29062900

2907-
obj = PyTuple_GET_ITEM(args, 2*i+1);
2908-
n = einsum_list_to_subscripts(obj, subscripts+subindex,
2909-
subsize-subindex);
2901+
int n = einsum_list_to_subscripts(
2902+
args[2*i + 1], subscripts+subindex, subsize-subindex);
29102903
if (n < 0) {
29112904
goto fail;
29122905
}
29132906
subindex += n;
29142907
}
29152908

29162909
/* Add the '->' to the string if provided */
2917-
if (PyTuple_Size(args) == 2*nop+1) {
2918-
PyObject *obj;
2919-
int n;
2920-
2910+
if (nargs == 2*nop+1) {
29212911
if (subindex + 2 >= subsize) {
29222912
PyErr_SetString(PyExc_ValueError,
29232913
"subscripts list is too long");
@@ -2926,9 +2916,8 @@ einsum_sub_op_from_lists(PyObject *args,
29262916
subscripts[subindex++] = '-';
29272917
subscripts[subindex++] = '>';
29282918

2929-
obj = PyTuple_GET_ITEM(args, 2*nop);
2930-
n = einsum_list_to_subscripts(obj, subscripts+subindex,
2931-
subsize-subindex);
2919+
int n = einsum_list_to_subscripts(
2920+
args[2*nop], subscripts+subindex, subsize-subindex);
29322921
if (n < 0) {
29332922
goto fail;
29342923
}
@@ -2941,7 +2930,7 @@ einsum_sub_op_from_lists(PyObject *args,
29412930
return nop;
29422931

29432932
fail:
2944-
for (i = 0; i < nop; ++i) {
2933+
for (Py_ssize_t i = 0; i < nop; ++i) {
29452934
Py_XDECREF(op[i]);
29462935
op[i] = NULL;
29472936
}
@@ -2950,108 +2939,74 @@ einsum_sub_op_from_lists(PyObject *args,
29502939
}
29512940

29522941
static PyObject *
2953-
array_einsum(PyObject *NPY_UNUSED(dummy), PyObject *args, PyObject *kwds)
2942+
array_einsum(PyObject *NPY_UNUSED(dummy),
2943+
PyObject *const *args, Py_ssize_t nargsf, PyObject *kwnames)
29542944
{
29552945
char *subscripts = NULL, subscripts_buffer[256];
29562946
PyObject *str_obj = NULL, *str_key_obj = NULL;
2957-
PyObject *arg0;
2958-
int i, nop;
2947+
int nop;
29592948
PyArrayObject *op[NPY_MAXARGS];
29602949
NPY_ORDER order = NPY_KEEPORDER;
29612950
NPY_CASTING casting = NPY_SAFE_CASTING;
2951+
PyObject *out_obj = NULL;
29622952
PyArrayObject *out = NULL;
29632953
PyArray_Descr *dtype = NULL;
29642954
PyObject *ret = NULL;
2955+
NPY_PREPARE_ARGPARSER;
2956+
2957+
Py_ssize_t nargs = PyVectorcall_NARGS(nargsf);
29652958

2966-
if (PyTuple_GET_SIZE(args) < 1) {
2959+
if (nargs < 1) {
29672960
PyErr_SetString(PyExc_ValueError,
29682961
"must specify the einstein sum subscripts string "
29692962
"and at least one operand, or at least one operand "
29702963
"and its corresponding subscripts list");
29712964
return NULL;
29722965
}
2973-
arg0 = PyTuple_GET_ITEM(args, 0);
29742966

29752967
/* einsum('i,j', a, b), einsum('i,j->ij', a, b) */
2976-
if (PyBytes_Check(arg0) || PyUnicode_Check(arg0)) {
2977-
nop = einsum_sub_op_from_str(args, &str_obj, &subscripts, op);
2968+
if (PyBytes_Check(args[0]) || PyUnicode_Check(args[0])) {
2969+
nop = einsum_sub_op_from_str(nargs, args, &str_obj, &subscripts, op);
29782970
}
29792971
/* einsum(a, [0], b, [1]), einsum(a, [0], b, [1], [0,1]) */
29802972
else {
2981-
nop = einsum_sub_op_from_lists(args, subscripts_buffer,
2982-
sizeof(subscripts_buffer), op);
2973+
nop = einsum_sub_op_from_lists(nargs, args, subscripts_buffer,
2974+
sizeof(subscripts_buffer), op);
29832975
subscripts = subscripts_buffer;
29842976
}
29852977
if (nop <= 0) {
29862978
goto finish;
29872979
}
29882980

29892981
/* Get the keyword arguments */
2990-
if (kwds != NULL) {
2991-
PyObject *key, *value;
2992-
Py_ssize_t pos = 0;
2993-
while (PyDict_Next(kwds, &pos, &key, &value)) {
2994-
char *str = NULL;
2995-
2996-
Py_XDECREF(str_key_obj);
2997-
str_key_obj = PyUnicode_AsASCIIString(key);
2998-
if (str_key_obj != NULL) {
2999-
key = str_key_obj;
3000-
}
3001-
3002-
str = PyBytes_AsString(key);
3003-
3004-
if (str == NULL) {
3005-
PyErr_Clear();
3006-
PyErr_SetString(PyExc_TypeError, "invalid keyword");
3007-
goto finish;
3008-
}
3009-
3010-
if (strcmp(str,"out") == 0) {
3011-
if (PyArray_Check(value)) {
3012-
out = (PyArrayObject *)value;
3013-
}
3014-
else {
3015-
PyErr_SetString(PyExc_TypeError,
3016-
"keyword parameter out must be an "
3017-
"array for einsum");
3018-
goto finish;
3019-
}
3020-
}
3021-
else if (strcmp(str,"order") == 0) {
3022-
if (!PyArray_OrderConverter(value, &order)) {
3023-
goto finish;
3024-
}
3025-
}
3026-
else if (strcmp(str,"casting") == 0) {
3027-
if (!PyArray_CastingConverter(value, &casting)) {
3028-
goto finish;
3029-
}
3030-
}
3031-
else if (strcmp(str,"dtype") == 0) {
3032-
if (!PyArray_DescrConverter2(value, &dtype)) {
3033-
goto finish;
3034-
}
3035-
}
3036-
else {
3037-
PyErr_Format(PyExc_TypeError,
3038-
"'%s' is an invalid keyword for einsum",
3039-
str);
3040-
goto finish;
3041-
}
2982+
if (kwnames != NULL) {
2983+
if (npy_parse_arguments("einsum", args+nargs, 0, kwnames,
2984+
"$out", NULL, &out_obj,
2985+
"$order", &PyArray_OrderConverter, &order,
2986+
"$casting", &PyArray_CastingConverter, &casting,
2987+
"$dtype", &PyArray_DescrConverter2, &dtype,
2988+
NULL, NULL, NULL) < 0) {
2989+
goto finish;
30422990
}
2991+
if (out_obj != NULL && !PyArray_Check(out_obj)) {
2992+
PyErr_SetString(PyExc_TypeError,
2993+
"keyword parameter out must be an "
2994+
"array for einsum");
2995+
goto finish;
2996+
}
2997+
out = (PyArrayObject *)out_obj;
30432998
}
30442999

30453000
ret = (PyObject *)PyArray_EinsteinSum(subscripts, nop, op, dtype,
3046-
order, casting, out);
3001+
order, casting, out);
30473002

30483003
/* If no output was supplied, possibly convert to a scalar */
30493004
if (ret != NULL && out == NULL) {
30503005
ret = PyArray_Return((PyArrayObject *)ret);
30513006
}
30523007

30533008
finish:
3054-
for (i = 0; i < nop; ++i) {
3009+
for (Py_ssize_t i = 0; i < nop; ++i) {
30553010
Py_XDECREF(op[i]);
30563011
}
30573012
Py_XDECREF(dtype);
@@ -4518,7 +4473,7 @@ static struct PyMethodDef array_module_methods[] = {
45184473
METH_FASTCALL, NULL},
45194474
{"c_einsum",
45204475
(PyCFunction)array_einsum,
4521-
METH_VARARGS|METH_KEYWORDS, NULL},
4476+
METH_FASTCALL|METH_KEYWORDS, NULL},
45224477
{"correlate",
45234478
(PyCFunction)array_correlate,
45244479
METH_FASTCALL | METH_KEYWORDS, NULL},

0 commit comments

Comments
0 (0)
Morty Proxy This is a proxified and sanitized view of the page, visit original site.