Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Appearance settings

MAINT: use an atomic load/store and a mutex to initialize the argparse and runtime import caches #26780

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 8 commits into from
Jul 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions 1 numpy/_core/meson.build
Original file line number Diff line number Diff line change
Expand Up @@ -1041,6 +1041,7 @@ src_multiarray_umath_common = [
'src/common/mem_overlap.c',
'src/common/npy_argparse.c',
'src/common/npy_hashtable.c',
'src/common/npy_import.c',
'src/common/npy_longdouble.c',
'src/common/ucsnarrow.c',
'src/common/ufunc_override.c',
Expand Down
34 changes: 25 additions & 9 deletions 34 numpy/_core/src/common/npy_argparse.c
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,22 @@
#include "numpy/ndarraytypes.h"
#include "numpy/npy_2_compat.h"
#include "npy_argparse.h"

#include "npy_atomic.h"
#include "npy_import.h"

#include "arrayfunction_override.h"

static PyThread_type_lock argparse_mutex;

/*
 * Allocate the lock stored in the file-level `argparse_mutex`.
 *
 * Returns 0 when the lock was allocated. When PyThread_allocate_lock()
 * yields NULL, a MemoryError is set via PyErr_NoMemory() and -1 is returned.
 */
NPY_NO_EXPORT int
init_argparse_mutex(void) {
    argparse_mutex = PyThread_allocate_lock();
    if (argparse_mutex != NULL) {
        return 0;
    }
    /* Allocation failed: report it as an out-of-memory condition. */
    PyErr_NoMemory();
    return -1;
}

/**
* Small wrapper converting to array just like CPython does.
Expand Down Expand Up @@ -274,15 +285,20 @@ _npy_parse_arguments(const char *funcname,
/* ... is NULL, NULL, NULL terminated: name, converter, value */
...)
{
if (NPY_UNLIKELY(cache->npositional == -1)) {
va_list va;
va_start(va, kwnames);

int res = initialize_keywords(funcname, cache, va);
va_end(va);
if (res < 0) {
return -1;
if (!npy_atomic_load_uint8(&cache->initialized)) {
PyThread_acquire_lock(argparse_mutex, WAIT_LOCK);
if (!npy_atomic_load_uint8(&cache->initialized)) {
va_list va;
va_start(va, kwnames);
int res = initialize_keywords(funcname, cache, va);
va_end(va);
if (res < 0) {
PyThread_release_lock(argparse_mutex);
return -1;
}
npy_atomic_store_uint8(&cache->initialized, 1);
}
PyThread_release_lock(argparse_mutex);
}

if (NPY_UNLIKELY(len_args > cache->npositional)) {
Expand Down
5 changes: 3 additions & 2 deletions 5 numpy/_core/src/common/npy_argparse.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,24 +20,25 @@
NPY_NO_EXPORT int
PyArray_PythonPyIntFromInt(PyObject *obj, int *value);


#define _NPY_MAX_KWARGS 15

/*
 * Cache of argument-parsing information consumed by _npy_parse_arguments.
 * NPY_PREPARE_ARGPARSER declares one static instance named __argparse_cache
 * inside each function that uses it, so each instance starts with
 * `initialized` equal to 0.
 */
typedef struct {
/* Compared against the number of passed arguments in _npy_parse_arguments. */
/* NOTE(review): presumably the count of positionally-passable arguments -- confirm against initialize_keywords. */
int npositional;
/* NOTE(review): presumably the total number of declared arguments -- confirm. */
int nargs;
/* NOTE(review): presumably the number of positional-only arguments -- confirm. */
int npositional_only;
/* NOTE(review): presumably the number of required arguments -- confirm. */
int nrequired;
/*
 * Read with npy_atomic_load_uint8 and set to 1 with npy_atomic_store_uint8
 * by _npy_parse_arguments once initialize_keywords has succeeded.
 */
npy_uint8 initialized;
/* Null terminated list of keyword argument name strings */
/* Sized _NPY_MAX_KWARGS+1 so there is room for the terminating NULL entry. */
PyObject *kw_strings[_NPY_MAX_KWARGS+1];
} _NpyArgParserCache;

NPY_NO_EXPORT int init_argparse_mutex(void);

/*
* The sole purpose of this macro is to hide the argument parsing cache.
ngoldbaum marked this conversation as resolved.
Show resolved Hide resolved
* Since this cache must be static, this also removes a source of error.
*/
#define NPY_PREPARE_ARGPARSER static _NpyArgParserCache __argparse_cache = {-1}
#define NPY_PREPARE_ARGPARSER static _NpyArgParserCache __argparse_cache;

/**
* Macro to help with argument parsing.
Expand Down
99 changes: 99 additions & 0 deletions 99 numpy/_core/src/common/npy_atomic.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
/*
* Provides wrappers around C11 standard library atomics and MSVC intrinsics
* to provide basic atomic load and store functionality. This is based on
* code in CPython's pyatomic.h, pyatomic_std.h, and pyatomic_msc.h
ngoldbaum marked this conversation as resolved.
Show resolved Hide resolved
*/

#ifndef NUMPY_CORE_SRC_COMMON_NPY_ATOMIC_H_
#define NUMPY_CORE_SRC_COMMON_NPY_ATOMIC_H_

#include "numpy/npy_common.h"

#if __STDC_VERSION__ >= 201112L && !defined(__STDC_NO_ATOMICS__)
// TODO: support C++ atomics as well if this header is ever needed in C++
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The code to do this is in the CPython header this is cribbed from. It would be dead code if I included it.

#include <stdatomic.h>
#include <stdint.h>
#define STDC_ATOMICS
#elif _MSC_VER
#include <intrin.h>
#define MSC_ATOMICS
#if !defined(_M_X64) && !defined(_M_IX86) && !defined(_M_ARM64)
#error "Unsupported MSVC build configuration, neither x86 or ARM"
#endif
#elif defined(__GNUC__) && (__GNUC__ > 4)
#define GCC_ATOMICS
#elif defined(__clang__)
#if __has_builtin(__atomic_load)
#define GCC_ATOMICS
#endif
#else
#error "no supported atomic implementation for this platform/compiler"
#endif


/*
 * Atomically read the npy_uint8 stored at `obj` and return it.
 *
 * Exactly one backend is compiled in, chosen by the STDC_ATOMICS,
 * GCC_ATOMICS or MSC_ATOMICS macro. The C11 and GCC backends request
 * sequentially-consistent ordering; the MSVC backend uses a volatile read
 * on x86/x64 and the __ldar8 intrinsic otherwise.
 */
static inline npy_uint8
npy_atomic_load_uint8(const npy_uint8 *obj) {
#if defined(STDC_ATOMICS)
    return (npy_uint8)atomic_load((const _Atomic(uint8_t)*)obj);
#elif defined(GCC_ATOMICS)
    return __atomic_load_n(obj, __ATOMIC_SEQ_CST);
#elif defined(MSC_ATOMICS)
#if !defined(_M_X64) && !defined(_M_IX86)
    /* Not x86/x64, i.e. the ARM64 configuration. */
    return (npy_uint8)__ldar8((unsigned __int8 volatile *)obj);
#else
    return *(volatile npy_uint8 *)obj;
#endif
#endif
}

/*
 * Atomically read the pointer stored at `obj` and return it.
 *
 * `obj` is the address of a pointer-sized slot; every backend reinterprets
 * it as a pointer to a pointer (or to a pointer-sized integer). Exactly one
 * backend is compiled in, chosen by the STDC_ATOMICS, MSC_ATOMICS or
 * GCC_ATOMICS macro. The C11 and GCC backends request
 * sequentially-consistent ordering.
 *
 * The MSVC backend reads the slot as an integer of pointer width, so the
 * value is explicitly converted back to `void *`; returning the bare integer
 * from a function returning `void *` is not valid C and draws a diagnostic.
 */
static inline void*
npy_atomic_load_ptr(const void *obj) {
#ifdef STDC_ATOMICS
    return atomic_load((const _Atomic(void *)*)obj);
#elif defined(MSC_ATOMICS)
#if SIZEOF_VOID_P == 8
#if defined(_M_X64) || defined(_M_IX86)
    return (void *)*(volatile uint64_t *)obj;
#elif defined(_M_ARM64)
    return (void *)__ldar64((unsigned __int64 volatile *)obj);
#endif
#else
    /* Pointers are not 8 bytes wide: read a 32-bit slot instead. */
#if defined(_M_X64) || defined(_M_IX86)
    return (void *)*(volatile uint32_t *)obj;
#elif defined(_M_ARM64)
    return (void *)__ldar32((unsigned __int32 volatile *)obj);
#endif
#endif
#elif defined(GCC_ATOMICS)
    return (void *)__atomic_load_n((void * const *)obj, __ATOMIC_SEQ_CST);
#endif
}

/*
 * Atomically write `value` into the npy_uint8 at `obj`.
 *
 * Exactly one backend is compiled in, chosen by the GCC_ATOMICS,
 * MSC_ATOMICS or STDC_ATOMICS macro. The C11 and GCC backends request
 * sequentially-consistent ordering; the MSVC backend uses the
 * _InterlockedExchange8 intrinsic and discards the previous value.
 */
static inline void
npy_atomic_store_uint8(npy_uint8 *obj, npy_uint8 value) {
#if defined(GCC_ATOMICS)
    __atomic_store_n(obj, value, __ATOMIC_SEQ_CST);
#elif defined(MSC_ATOMICS)
    _InterlockedExchange8((volatile char *)obj, (char)value);
#elif defined(STDC_ATOMICS)
    atomic_store((_Atomic(uint8_t)*)obj, value);
#endif
}

/*
 * Atomically write the pointer `value` into the pointer-sized slot at `obj`.
 *
 * `obj` is the address of the slot; each backend reinterprets it as a
 * pointer to a pointer. Exactly one backend is compiled in, chosen by the
 * GCC_ATOMICS, MSC_ATOMICS or STDC_ATOMICS macro. The C11 and GCC backends
 * request sequentially-consistent ordering; the MSVC backend uses
 * _InterlockedExchangePointer and discards the previous value.
 */
static inline void
npy_atomic_store_ptr(void *obj, void *value)
{
#if defined(GCC_ATOMICS)
    __atomic_store_n((void **)obj, value, __ATOMIC_SEQ_CST);
#elif defined(MSC_ATOMICS)
    _InterlockedExchangePointer((void * volatile *)obj, (void *)value);
#elif defined(STDC_ATOMICS)
    atomic_store((_Atomic(void *)*)obj, value);
#endif
}

#undef MSC_ATOMICS
#undef STDC_ATOMICS
#undef GCC_ATOMICS

#endif // NUMPY_CORE_SRC_COMMON_NPY_ATOMIC_H_
11 changes: 6 additions & 5 deletions 11 numpy/_core/src/common/npy_ctypes.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,14 +21,15 @@ npy_ctypes_check(PyTypeObject *obj)
PyObject *ret_obj;
int ret;

npy_cache_import("numpy._core._internal", "npy_ctypes_check",
&npy_thread_unsafe_state.npy_ctypes_check);
if (npy_thread_unsafe_state.npy_ctypes_check == NULL) {

if (npy_cache_import_runtime(
"numpy._core._internal", "npy_ctypes_check",
&npy_runtime_imports.npy_ctypes_check) == -1) {
goto fail;
}

ret_obj = PyObject_CallFunctionObjArgs(npy_thread_unsafe_state.npy_ctypes_check,
(PyObject *)obj, NULL);
ret_obj = PyObject_CallFunctionObjArgs(
npy_runtime_imports.npy_ctypes_check, (PyObject *)obj, NULL);
if (ret_obj == NULL) {
goto fail;
}
Expand Down
19 changes: 19 additions & 0 deletions 19 numpy/_core/src/common/npy_import.c
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
#define NPY_NO_DEPRECATED_API NPY_API_VERSION
#define _MULTIARRAYMODULE

#include "numpy/ndarraytypes.h"
#include "npy_import.h"
#include "npy_atomic.h"


NPY_VISIBILITY_HIDDEN npy_runtime_imports_struct npy_runtime_imports;

/*
 * Allocate the lock stored in npy_runtime_imports.import_mutex.
 *
 * Returns 0 when the lock was allocated. When PyThread_allocate_lock()
 * yields NULL, a MemoryError is set via PyErr_NoMemory() and -1 is returned.
 */
NPY_NO_EXPORT int
init_import_mutex(void) {
    npy_runtime_imports.import_mutex = PyThread_allocate_lock();
    if (npy_runtime_imports.import_mutex != NULL) {
        return 0;
    }
    /* Allocation failed: report it as an out-of-memory condition. */
    PyErr_NoMemory();
    return -1;
}
90 changes: 80 additions & 10 deletions 90 numpy/_core/src/common/npy_import.h
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,70 @@

#include <Python.h>

/*! \brief Fetch and cache Python function.
#include "numpy/npy_common.h"
#include "npy_atomic.h"

/*
* Cached references to objects obtained via an import. All of these
* can be initialized at any time by npy_cache_import_runtime.
*/
/*
 * Storage for the lock and the cache slots used by npy_cache_import_runtime.
 * A single instance, npy_runtime_imports, is defined in npy_import.c.
 */
typedef struct npy_runtime_imports_struct {
/*
 * Allocated by init_import_mutex; npy_cache_import_runtime holds it while
 * storing a freshly imported object into one of the slots below.
 */
PyThread_type_lock import_mutex;
/*
 * Cache slots, one PyObject pointer per imported attribute. A slot is passed
 * by address to npy_cache_import_runtime, which fills it when it reads as
 * NULL. NOTE(review): each slot presumably mirrors a Python-level attribute
 * of the same name -- confirm at the individual call sites.
 */
PyObject *_add_dtype_helper;
PyObject *_all;
PyObject *_amax;
PyObject *_amin;
PyObject *_any;
PyObject *array_function_errmsg_formatter;
PyObject *array_ufunc_errmsg_formatter;
PyObject *_clip;
PyObject *_commastring;
PyObject *_convert_to_stringdtype_kwargs;
PyObject *_default_array_repr;
PyObject *_default_array_str;
PyObject *_dump;
PyObject *_dumps;
PyObject *_getfield_is_safe;
PyObject *internal_gcd_func;
PyObject *_mean;
PyObject *NO_NEP50_WARNING;
/* Filled from "numpy._core._internal", attribute "npy_ctypes_check". */
PyObject *npy_ctypes_check;
PyObject *numpy_matrix;
PyObject *_prod;
PyObject *_promote_fields;
PyObject *_std;
PyObject *_sum;
PyObject *_ufunc_doc_signature_formatter;
PyObject *_var;
PyObject *_view_is_safe;
PyObject *_void_scalar_to_string;
} npy_runtime_imports_struct;

NPY_VISIBILITY_HIDDEN extern npy_runtime_imports_struct npy_runtime_imports;

/*! \brief Import a Python object.

* This function imports the Python object specified by
* \a module and \a attr and returns a new reference to it.
* On error, returns NULL.
*
* @param module Absolute module name.
* @param attr module attribute to import.
*/
static inline PyObject*
npy_import(const char *module, const char *attr)
{
    PyObject *imported = PyImport_ImportModule(module);
    if (imported == NULL) {
        /* The module import failed; there is nothing to look up. */
        return NULL;
    }

    /* NULL here (attribute lookup failed) is passed straight to the caller. */
    PyObject *value = PyObject_GetAttrString(imported, attr);
    /* The module reference is no longer needed once the lookup is done. */
    Py_DECREF(imported);
    return value;
}

/*! \brief Fetch and cache Python object at runtime.
*
* Import a Python function and cache it for use. The function checks if
* cache is NULL, and if so imports the Python function specified by
Expand All @@ -16,17 +79,24 @@
* @param attr module attribute to cache.
* @param cache Storage location for imported function.
*/
static inline void
npy_cache_import(const char *module, const char *attr, PyObject **cache)
{
if (NPY_UNLIKELY(*cache == NULL)) {
PyObject *mod = PyImport_ImportModule(module);

if (mod != NULL) {
*cache = PyObject_GetAttrString(mod, attr);
Py_DECREF(mod);
static inline int
npy_cache_import_runtime(const char *module, const char *attr, PyObject **obj) {
if (!npy_atomic_load_ptr(obj)) {
PyObject* value = npy_import(module, attr);
if (value == NULL) {
return -1;
}
PyThread_acquire_lock(npy_runtime_imports.import_mutex, WAIT_LOCK);
if (!npy_atomic_load_ptr(obj)) {
npy_atomic_store_ptr(obj, Py_NewRef(value));
}
PyThread_release_lock(npy_runtime_imports.import_mutex);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Move the release before the error check, or we will dead-lock on error!

Copy link

@colesbury colesbury Jul 9, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good catch! Also, the obj == NULL check doesn't make sense -- it would need to be *obj == NULL to check the result.

But I think it would be better not to hold the lock around the npy_import:

  • Imports in general can run arbitrary code, which may reentrantly trigger npy_cache_import_runtime.
  • Imports (i.e., PyImport_ImportModule) are already thread-safe internally

So I'd suggest writing this as:

if (!npy_atomic_load_ptr(obj)) {
    PyObject *value = npy_import(module, attr);
    if (value == NULL) {
        return -1;
    }
    PyThread_acquire_lock(npy_runtime_imports.import_mutex, WAIT_LOCK);
    if (!npy_atomic_load_ptr(obj)) {
        npy_atomic_store_ptr(obj, Py_NewRef(value));
    }
    PyThread_release_lock(npy_runtime_imports.import_mutex);
    Py_DECREF(value);
}

Or you can get rid of the lock if you implement compare-exchange:

if (!npy_atomic_load_ptr(obj)) {
    PyObject *value = npy_import(module, attr);
    if (value == NULL) {
        return -1;
    }
    PyObject *expected = NULL;
    if (!npy_atomic_compare_exchange_ptr(obj, &expected, value)) {
        Py_DECREF(value);
    }
}

Copy link
Member Author

@ngoldbaum ngoldbaum Jul 9, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks! I kept the lock since the next iteration will use PyMutex and I'd like to make sure this gets in for NumPy 2.1. Thankfully deadsnakes was updated over the weekend so we can test against a version with a public PyMutex now.

Py_DECREF(value);
}
return 0;
}

NPY_NO_EXPORT int
init_import_mutex(void);

#endif /* NUMPY_CORE_SRC_COMMON_NPY_IMPORT_H_ */
28 changes: 28 additions & 0 deletions 28 numpy/_core/src/multiarray/_multiarray_tests.c.src
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,28 @@ argparse_example_function(PyObject *NPY_UNUSED(mod),
Py_RETURN_NONE;
}

/*
* Tests that argparse cache creation is thread-safe. *must* be called only
* by the python-level test_thread_safe_argparse_cache function, otherwise
* the cache might be created before the test to make sure cache creation is
* thread-safe runs
*/
static PyObject *
threaded_argparse_example_function(PyObject *NPY_UNUSED(mod),
        PyObject *const *args, Py_ssize_t len_args, PyObject *kwnames)
{
    /* Declares the static argument-parsing cache used by the call below. */
    NPY_PREPARE_ARGPARSER;
    PyObject *arg2;
    int arg1;

    /*
     * Only the parse itself matters here: the parsed values are discarded
     * and None is returned whenever parsing succeeds.
     */
    int status = npy_parse_arguments("thread_func", args, len_args, kwnames,
            "$arg1", &PyArray_PythonPyIntFromInt, &arg1,
            "$arg2", NULL, &arg2,
            NULL, NULL, NULL);
    if (status < 0) {
        return NULL;
    }
    Py_RETURN_NONE;
}

/* test PyArray_IsPythonScalar, before including private py3 compat header */
static PyObject *
IsPythonScalar(PyObject * dummy, PyObject *args)
Expand Down Expand Up @@ -2205,6 +2227,9 @@ static PyMethodDef Multiarray_TestsMethods[] = {
{"argparse_example_function",
(PyCFunction)argparse_example_function,
METH_KEYWORDS | METH_FASTCALL, NULL},
{"threaded_argparse_example_function",
(PyCFunction)threaded_argparse_example_function,
METH_KEYWORDS | METH_FASTCALL, NULL},
{"IsPythonScalar",
IsPythonScalar,
METH_VARARGS, NULL},
Expand Down Expand Up @@ -2407,6 +2432,9 @@ PyMODINIT_FUNC PyInit__multiarray_tests(void)
return m;
}
import_array();
if (init_argparse_mutex() < 0) {
return NULL;
}
if (PyErr_Occurred()) {
PyErr_SetString(PyExc_RuntimeError,
"cannot load _multiarray_tests module.");
Expand Down
Loading
Loading
Morty Proxy This is a proxified and sanitized view of the page, visit original site.