From 03cee3f9a9db3afa27ca48172826838f9e8abc4c Mon Sep 17 00:00:00 2001 From: Emma Smith Date: Thu, 22 May 2025 20:30:10 -0700 Subject: [PATCH] gh-133885: Use locks instead of critical sections for _zstd (gh-134289) Move from using critical sections to locks for the (de)compression methods. Since the methods allow other threads to run, we should use a lock rather than a critical section. (cherry picked from commit 8dbc11971974a725dc8a11c0dc65d8f6fcb4d902) Co-authored-by: Emma Smith --- Lib/test/test_zstd.py | 56 +++++++++- Modules/_zstd/clinic/decompressor.c.h | 11 +- Modules/_zstd/clinic/zstddict.c.h | 27 +---- Modules/_zstd/compressor.c | 151 ++++++++++++++------------ Modules/_zstd/decompressor.c | 150 +++++++++++++------------ Modules/_zstd/zstddict.c | 13 ++- Modules/_zstd/zstddict.h | 3 + 7 files changed, 229 insertions(+), 182 deletions(-) diff --git a/Lib/test/test_zstd.py b/Lib/test/test_zstd.py index 53ca592ea38828..084f8f24fc009c 100644 --- a/Lib/test/test_zstd.py +++ b/Lib/test/test_zstd.py @@ -2430,10 +2430,8 @@ def test_buffer_protocol(self): self.assertEqual(f.write(arr), LENGTH) self.assertEqual(f.tell(), LENGTH) -@unittest.skip("it fails for now, see gh-133885") class FreeThreadingMethodTests(unittest.TestCase): - @unittest.skipUnless(Py_GIL_DISABLED, 'this test can only possibly fail with GIL disabled') @threading_helper.reap_threads @threading_helper.requires_working_threading() def test_compress_locking(self): @@ -2470,7 +2468,6 @@ def run_method(method, input_data, output_data): actual = b''.join(output) + rest2 self.assertEqual(expected, actual) - @unittest.skipUnless(Py_GIL_DISABLED, 'this test can only possibly fail with GIL disabled') @threading_helper.reap_threads @threading_helper.requires_working_threading() def test_decompress_locking(self): @@ -2506,6 +2503,59 @@ def run_method(method, input_data, output_data): actual = b''.join(output) self.assertEqual(expected, actual) + @threading_helper.reap_threads + @threading_helper.requires_working_threading() + def test_compress_shared_dict(self): + num_threads = 8 + + def run_method(b): + level = threading.get_ident() % 4 + # sync threads to increase chance of contention on + # capsule storing dictionary levels + b.wait() + ZstdCompressor(level=level, + zstd_dict=TRAINED_DICT.as_digested_dict) + b.wait() + ZstdCompressor(level=level, + zstd_dict=TRAINED_DICT.as_undigested_dict) + b.wait() + ZstdCompressor(level=level, + zstd_dict=TRAINED_DICT.as_prefix) + threads = [] + + b = threading.Barrier(num_threads) + for i in range(num_threads): + thread = threading.Thread(target=run_method, args=(b,)) + + threads.append(thread) + + with threading_helper.start_threads(threads): + pass + + @threading_helper.reap_threads + @threading_helper.requires_working_threading() + def test_decompress_shared_dict(self): + num_threads = 8 + + def run_method(b): + # sync threads to increase chance of contention on + # decompression dictionary + b.wait() + ZstdDecompressor(zstd_dict=TRAINED_DICT.as_digested_dict) + b.wait() + ZstdDecompressor(zstd_dict=TRAINED_DICT.as_undigested_dict) + b.wait() + ZstdDecompressor(zstd_dict=TRAINED_DICT.as_prefix) + threads = [] + + b = threading.Barrier(num_threads) + for i in range(num_threads): + thread = threading.Thread(target=run_method, args=(b,)) + + threads.append(thread) + + with threading_helper.start_threads(threads): + pass if __name__ == "__main__": diff --git a/Modules/_zstd/clinic/decompressor.c.h b/Modules/_zstd/clinic/decompressor.c.h index 4ecb19e9bde6ed..c6fdae74ab0447 100644 --- a/Modules/_zstd/clinic/decompressor.c.h +++ b/Modules/_zstd/clinic/decompressor.c.h @@ -7,7 +7,6 @@ preserve # include "pycore_runtime.h" // _Py_ID() #endif #include "pycore_abstract.h" // _PyNumber_Index() -#include "pycore_critical_section.h"// Py_BEGIN_CRITICAL_SECTION() #include "pycore_modsupport.h" // _PyArg_UnpackKeywords() PyDoc_STRVAR(_zstd_ZstdDecompressor_new__doc__, @@ -114,13 +113,7 @@ _zstd_ZstdDecompressor_unused_data_get_impl(ZstdDecompressor *self); static PyObject * _zstd_ZstdDecompressor_unused_data_get(PyObject *self, void *Py_UNUSED(context)) { - PyObject *return_value = NULL; - - Py_BEGIN_CRITICAL_SECTION(self); - return_value = _zstd_ZstdDecompressor_unused_data_get_impl((ZstdDecompressor *)self); - Py_END_CRITICAL_SECTION(); - - return return_value; + return _zstd_ZstdDecompressor_unused_data_get_impl((ZstdDecompressor *)self); } PyDoc_STRVAR(_zstd_ZstdDecompressor_decompress__doc__, @@ -227,4 +220,4 @@ _zstd_ZstdDecompressor_decompress(PyObject *self, PyObject *const *args, Py_ssiz return return_value; } -/*[clinic end generated code: output=7a4d278f9244e684 input=a9049054013a1b77]*/ +/*[clinic end generated code: output=30c12ef047027ede input=a9049054013a1b77]*/ diff --git a/Modules/_zstd/clinic/zstddict.c.h b/Modules/_zstd/clinic/zstddict.c.h index 34e0e4b3ecfe72..aaa29e491bc1bb 100644 --- a/Modules/_zstd/clinic/zstddict.c.h +++ b/Modules/_zstd/clinic/zstddict.c.h @@ -6,7 +6,6 @@ preserve # include "pycore_gc.h" // PyGC_Head # include "pycore_runtime.h" // _Py_ID() #endif -#include "pycore_critical_section.h"// Py_BEGIN_CRITICAL_SECTION() #include "pycore_modsupport.h" // _PyArg_UnpackKeywords() PyDoc_STRVAR(_zstd_ZstdDict_new__doc__, @@ -118,13 +117,7 @@ _zstd_ZstdDict_as_digested_dict_get_impl(ZstdDict *self); static PyObject * _zstd_ZstdDict_as_digested_dict_get(PyObject *self, void *Py_UNUSED(context)) { - PyObject *return_value = NULL; - - Py_BEGIN_CRITICAL_SECTION(self); - return_value = _zstd_ZstdDict_as_digested_dict_get_impl((ZstdDict *)self); - Py_END_CRITICAL_SECTION(); - - return return_value; + return _zstd_ZstdDict_as_digested_dict_get_impl((ZstdDict *)self); } PyDoc_STRVAR(_zstd_ZstdDict_as_undigested_dict__doc__, @@ -156,13 +149,7 @@ _zstd_ZstdDict_as_undigested_dict_get_impl(ZstdDict *self); static PyObject * _zstd_ZstdDict_as_undigested_dict_get(PyObject *self, void *Py_UNUSED(context)) { - PyObject *return_value = NULL; - - Py_BEGIN_CRITICAL_SECTION(self); - return_value = _zstd_ZstdDict_as_undigested_dict_get_impl((ZstdDict *)self); - Py_END_CRITICAL_SECTION(); - - return return_value; + return _zstd_ZstdDict_as_undigested_dict_get_impl((ZstdDict *)self); } PyDoc_STRVAR(_zstd_ZstdDict_as_prefix__doc__, @@ -194,12 +181,6 @@ _zstd_ZstdDict_as_prefix_get_impl(ZstdDict *self); static PyObject * _zstd_ZstdDict_as_prefix_get(PyObject *self, void *Py_UNUSED(context)) { - PyObject *return_value = NULL; - - Py_BEGIN_CRITICAL_SECTION(self); - return_value = _zstd_ZstdDict_as_prefix_get_impl((ZstdDict *)self); - Py_END_CRITICAL_SECTION(); - - return return_value; + return _zstd_ZstdDict_as_prefix_get_impl((ZstdDict *)self); } -/*[clinic end generated code: output=bfb31c1187477afd input=a9049054013a1b77]*/ +/*[clinic end generated code: output=8692eabee4e0d1fe input=a9049054013a1b77]*/ diff --git a/Modules/_zstd/compressor.c b/Modules/_zstd/compressor.c index 38baee2be1e95b..8f934858ef784f 100644 --- a/Modules/_zstd/compressor.c +++ b/Modules/_zstd/compressor.c @@ -17,6 +17,7 @@ class _zstd.ZstdCompressor "ZstdCompressor *" "&zstd_compressor_type_spec" #include "_zstdmodule.h" #include "buffer.h" #include "zstddict.h" +#include "internal/pycore_lock.h" // PyMutex_IsLocked #include // offsetof() #include // ZSTD_*() @@ -38,6 +39,9 @@ typedef struct { /* Compression level */ int compression_level; + + /* Lock to protect the compression context */ + PyMutex lock; } ZstdCompressor; #define ZstdCompressor_CAST(op) ((ZstdCompressor *)op) @@ -149,12 +153,12 @@ capsule_free_cdict(PyObject *capsule) ZSTD_CDict * _get_CDict(ZstdDict *self, int compressionLevel) { + assert(PyMutex_IsLocked(&self->lock)); PyObject *level = NULL; - PyObject *capsule; + PyObject *capsule = NULL; ZSTD_CDict *cdict; + int ret; - // TODO(emmatyping): refactor critical section code into a lock_held function - Py_BEGIN_CRITICAL_SECTION(self); /* int level object */ level = PyLong_FromLong(compressionLevel); @@ -163,12 +167,11 @@ _get_CDict(ZstdDict *self, int compressionLevel) } /* Get PyCapsule object from self->c_dicts */ - capsule = PyDict_GetItemWithError(self->c_dicts, level); + ret = PyDict_GetItemRef(self->c_dicts, level, &capsule); + if (ret < 0) { + goto error; + } if (capsule == NULL) { - if (PyErr_Occurred()) { - goto error; - } - /* Create ZSTD_CDict instance */ char *dict_buffer = PyBytes_AS_STRING(self->dict_content); Py_ssize_t dict_len = Py_SIZE(self->dict_content); @@ -196,11 +199,10 @@ _get_CDict(ZstdDict *self, int compressionLevel) } /* Add PyCapsule object to self->c_dicts */ - if (PyDict_SetItem(self->c_dicts, level, capsule) < 0) { - Py_DECREF(capsule); + ret = PyDict_SetItem(self->c_dicts, level, capsule); + if (ret < 0) { goto error; } - Py_DECREF(capsule); } else { /* ZSTD_CDict instance already exists */ @@ -212,15 +214,55 @@ _get_CDict(ZstdDict *self, int compressionLevel) cdict = NULL; success: Py_XDECREF(level); - Py_END_CRITICAL_SECTION(); + Py_XDECREF(capsule); return cdict; } static int -_zstd_load_c_dict(ZstdCompressor *self, PyObject *dict) +_zstd_load_impl(ZstdCompressor *self, ZstdDict *zd, + _zstd_state *mod_state, int type) { - size_t zstd_ret; + if (type == DICT_TYPE_DIGESTED) { + /* Get ZSTD_CDict */ + ZSTD_CDict *c_dict = _get_CDict(zd, self->compression_level); + if (c_dict == NULL) { + return -1; + } + /* Reference a prepared dictionary. + It overrides some compression context's parameters. */ + zstd_ret = ZSTD_CCtx_refCDict(self->cctx, c_dict); + } + else if (type == DICT_TYPE_UNDIGESTED) { + /* Load a dictionary. + It doesn't override compression context's parameters. */ + zstd_ret = ZSTD_CCtx_loadDictionary( + self->cctx, + PyBytes_AS_STRING(zd->dict_content), + Py_SIZE(zd->dict_content)); + } + else if (type == DICT_TYPE_PREFIX) { + /* Load a prefix */ + zstd_ret = ZSTD_CCtx_refPrefix( + self->cctx, + PyBytes_AS_STRING(zd->dict_content), + Py_SIZE(zd->dict_content)); + } + else { + Py_UNREACHABLE(); + } + + /* Check error */ + if (ZSTD_isError(zstd_ret)) { + set_zstd_error(mod_state, ERR_LOAD_C_DICT, zstd_ret); + return -1; + } + return 0; +} + +static int +_zstd_load_c_dict(ZstdCompressor *self, PyObject *dict) +{ _zstd_state* const mod_state = PyType_GetModuleState(Py_TYPE(self)); if (mod_state == NULL) { return -1; @@ -237,7 +279,10 @@ _zstd_load_c_dict(ZstdCompressor *self, PyObject *dict) /* When compressing, use undigested dictionary by default. */ zd = (ZstdDict*)dict; type = DICT_TYPE_UNDIGESTED; - goto load; + PyMutex_Lock(&zd->lock); + ret = _zstd_load_impl(self, zd, mod_state, type); + PyMutex_Unlock(&zd->lock); + return ret; } /* Check (ZstdDict, type) */ @@ -257,7 +302,10 @@ _zstd_load_c_dict(ZstdCompressor *self, PyObject *dict) { assert(type >= 0); zd = (ZstdDict*)PyTuple_GET_ITEM(dict, 0); - goto load; + PyMutex_Lock(&zd->lock); + ret = _zstd_load_impl(self, zd, mod_state, type); + PyMutex_Unlock(&zd->lock); + return ret; } } } @@ -266,49 +314,6 @@ _zstd_load_c_dict(ZstdCompressor *self, PyObject *dict) PyErr_SetString(PyExc_TypeError, "zstd_dict argument should be ZstdDict object."); return -1; - -load: - if (type == DICT_TYPE_DIGESTED) { - /* Get ZSTD_CDict */ - ZSTD_CDict *c_dict = _get_CDict(zd, self->compression_level); - if (c_dict == NULL) { - return -1; - } - /* Reference a prepared dictionary. - It overrides some compression context's parameters. */ - Py_BEGIN_CRITICAL_SECTION(self); - zstd_ret = ZSTD_CCtx_refCDict(self->cctx, c_dict); - Py_END_CRITICAL_SECTION(); - } - else if (type == DICT_TYPE_UNDIGESTED) { - /* Load a dictionary. - It doesn't override compression context's parameters. */ - Py_BEGIN_CRITICAL_SECTION2(self, zd); - zstd_ret = ZSTD_CCtx_loadDictionary( - self->cctx, - PyBytes_AS_STRING(zd->dict_content), - Py_SIZE(zd->dict_content)); - Py_END_CRITICAL_SECTION2(); - } - else if (type == DICT_TYPE_PREFIX) { - /* Load a prefix */ - Py_BEGIN_CRITICAL_SECTION2(self, zd); - zstd_ret = ZSTD_CCtx_refPrefix( - self->cctx, - PyBytes_AS_STRING(zd->dict_content), - Py_SIZE(zd->dict_content)); - Py_END_CRITICAL_SECTION2(); - } - else { - Py_UNREACHABLE(); - } - - /* Check error */ - if (ZSTD_isError(zstd_ret)) { - set_zstd_error(mod_state, ERR_LOAD_C_DICT, zstd_ret); - return -1; - } - return 0; } /*[clinic input] @@ -339,6 +344,7 @@ _zstd_ZstdCompressor_new_impl(PyTypeObject *type, PyObject *level, self->use_multithread = 0; self->dict = NULL; + self->lock = (PyMutex){0}; /* Compression context */ self->cctx = ZSTD_createCCtx(); @@ -403,6 +409,8 @@ ZstdCompressor_dealloc(PyObject *ob) ZSTD_freeCCtx(self->cctx); } + assert(!PyMutex_IsLocked(&self->lock)); + /* Py_XDECREF the dict after free the compression context */ Py_CLEAR(self->dict); @@ -412,9 +420,10 @@ ZstdCompressor_dealloc(PyObject *ob) } static PyObject * -compress_impl(ZstdCompressor *self, Py_buffer *data, - ZSTD_EndDirective end_directive) +compress_lock_held(ZstdCompressor *self, Py_buffer *data, + ZSTD_EndDirective end_directive) { + assert(PyMutex_IsLocked(&self->lock)); ZSTD_inBuffer in; ZSTD_outBuffer out; _BlocksOutputBuffer buffer = {.list = NULL}; @@ -495,8 +504,9 @@ mt_continue_should_break(ZSTD_inBuffer *in, ZSTD_outBuffer *out) #endif static PyObject * -compress_mt_continue_impl(ZstdCompressor *self, Py_buffer *data) +compress_mt_continue_lock_held(ZstdCompressor *self, Py_buffer *data) { + assert(PyMutex_IsLocked(&self->lock)); ZSTD_inBuffer in; ZSTD_outBuffer out; _BlocksOutputBuffer buffer = {.list = NULL}; @@ -529,7 +539,7 @@ compress_mt_continue_impl(ZstdCompressor *self, Py_buffer *data) goto error; } - /* Like compress_impl(), output as much as possible. */ + /* Like compress_lock_held(), output as much as possible. */ if (out.pos == out.size) { if (_OutputBuffer_Grow(&buffer, &out) < 0) { goto error; @@ -588,14 +598,14 @@ _zstd_ZstdCompressor_compress_impl(ZstdCompressor *self, Py_buffer *data, } /* Thread-safe code */ - Py_BEGIN_CRITICAL_SECTION(self); + PyMutex_Lock(&self->lock); /* Compress */ if (self->use_multithread && mode == ZSTD_e_continue) { - ret = compress_mt_continue_impl(self, data); + ret = compress_mt_continue_lock_held(self, data); } else { - ret = compress_impl(self, data, mode); + ret = compress_lock_held(self, data, mode); } if (ret) { @@ -607,7 +617,7 @@ _zstd_ZstdCompressor_compress_impl(ZstdCompressor *self, Py_buffer *data, /* Resetting cctx's session never fail */ ZSTD_CCtx_reset(self->cctx, ZSTD_reset_session_only); } - Py_END_CRITICAL_SECTION(); + PyMutex_Unlock(&self->lock); return ret; } @@ -642,8 +652,9 @@ _zstd_ZstdCompressor_flush_impl(ZstdCompressor *self, int mode) } /* Thread-safe code */ - Py_BEGIN_CRITICAL_SECTION(self); - ret = compress_impl(self, NULL, mode); + PyMutex_Lock(&self->lock); + + ret = compress_lock_held(self, NULL, mode); if (ret) { self->last_mode = mode; @@ -654,7 +665,7 @@ _zstd_ZstdCompressor_flush_impl(ZstdCompressor *self, int mode) /* Resetting cctx's session never fail */ ZSTD_CCtx_reset(self->cctx, ZSTD_reset_session_only); } - Py_END_CRITICAL_SECTION(); + PyMutex_Unlock(&self->lock); return ret; } diff --git a/Modules/_zstd/decompressor.c b/Modules/_zstd/decompressor.c index 58f9c9f804e549..e299f73b071353 100644 --- a/Modules/_zstd/decompressor.c +++ b/Modules/_zstd/decompressor.c @@ -17,6 +17,7 @@ class _zstd.ZstdDecompressor "ZstdDecompressor *" "&zstd_decompressor_type_spec" #include "_zstdmodule.h" #include "buffer.h" #include "zstddict.h" +#include "internal/pycore_lock.h" // PyMutex_IsLocked #include // bool #include // offsetof() @@ -45,6 +46,9 @@ typedef struct { /* For ZstdDecompressor, 0 or 1. 1 means the end of the first frame has been reached. */ bool eof; + + /* Lock to protect the decompression context */ + PyMutex lock; } ZstdDecompressor; #define ZstdDecompressor_CAST(op) ((ZstdDecompressor *)op) @@ -54,6 +58,7 @@ typedef struct { static inline ZSTD_DDict * _get_DDict(ZstdDict *self) { + assert(PyMutex_IsLocked(&self->lock)); ZSTD_DDict *ret; /* Already created */ @@ -61,15 +66,14 @@ _get_DDict(ZstdDict *self) return self->d_dict; } - Py_BEGIN_CRITICAL_SECTION(self); if (self->d_dict == NULL) { /* Create ZSTD_DDict instance from dictionary content */ char *dict_buffer = PyBytes_AS_STRING(self->dict_content); Py_ssize_t dict_len = Py_SIZE(self->dict_content); Py_BEGIN_ALLOW_THREADS - self->d_dict = ZSTD_createDDict(dict_buffer, - dict_len); + ret = ZSTD_createDDict(dict_buffer, dict_len); Py_END_ALLOW_THREADS + self->d_dict = ret; if (self->d_dict == NULL) { _zstd_state* const mod_state = PyType_GetModuleState(Py_TYPE(self)); @@ -81,11 +85,7 @@ _get_DDict(ZstdDict *self) } } - /* Don't lose any exception */ - ret = self->d_dict; - Py_END_CRITICAL_SECTION(); - - return ret; + return self->d_dict; } /* Set decompression parameters to decompression context */ @@ -134,9 +134,7 @@ _zstd_set_d_parameters(ZstdDecompressor *self, PyObject *options) } /* Set parameter to compression context */ - Py_BEGIN_CRITICAL_SECTION(self); zstd_ret = ZSTD_DCtx_setParameter(self->dctx, key_v, value_v); - Py_END_CRITICAL_SECTION(); /* Check error */ if (ZSTD_isError(zstd_ret)) { @@ -147,11 +145,53 @@ _zstd_set_d_parameters(ZstdDecompressor *self, PyObject *options) return 0; } +static int +_zstd_load_impl(ZstdDecompressor *self, ZstdDict *zd, + _zstd_state *mod_state, int type) +{ + size_t zstd_ret; + if (type == DICT_TYPE_DIGESTED) { + /* Get ZSTD_DDict */ + ZSTD_DDict *d_dict = _get_DDict(zd); + if (d_dict == NULL) { + return -1; + } + /* Reference a prepared dictionary */ + zstd_ret = ZSTD_DCtx_refDDict(self->dctx, d_dict); + } + else if (type == DICT_TYPE_UNDIGESTED) { + /* Load a dictionary */ + zstd_ret = ZSTD_DCtx_loadDictionary( + self->dctx, + PyBytes_AS_STRING(zd->dict_content), + Py_SIZE(zd->dict_content)); + } + else if (type == DICT_TYPE_PREFIX) { + /* Load a prefix */ + zstd_ret = ZSTD_DCtx_refPrefix( + self->dctx, + PyBytes_AS_STRING(zd->dict_content), + Py_SIZE(zd->dict_content)); + } + else { + /* Impossible code path */ + PyErr_SetString(PyExc_SystemError, + "load_d_dict() impossible code path"); + return -1; + } + + /* Check error */ + if (ZSTD_isError(zstd_ret)) { + set_zstd_error(mod_state, ERR_LOAD_D_DICT, zstd_ret); + return -1; + } + return 0; +} + /* Load dictionary or prefix to decompression context */ static int _zstd_load_d_dict(ZstdDecompressor *self, PyObject *dict) { - size_t zstd_ret; _zstd_state* const mod_state = PyType_GetModuleState(Py_TYPE(self)); if (mod_state == NULL) { return -1; @@ -168,7 +208,10 @@ _zstd_load_d_dict(ZstdDecompressor *self, PyObject *dict) /* When decompressing, use digested dictionary by default. */ zd = (ZstdDict*)dict; type = DICT_TYPE_DIGESTED; - goto load; + PyMutex_Lock(&zd->lock); + ret = _zstd_load_impl(self, zd, mod_state, type); + PyMutex_Unlock(&zd->lock); + return ret; } /* Check (ZstdDict, type) */ @@ -188,7 +231,10 @@ _zstd_load_d_dict(ZstdDecompressor *self, PyObject *dict) { assert(type >= 0); zd = (ZstdDict*)PyTuple_GET_ITEM(dict, 0); - goto load; + PyMutex_Lock(&zd->lock); + ret = _zstd_load_impl(self, zd, mod_state, type); + PyMutex_Unlock(&zd->lock); + return ret; } } } @@ -197,50 +243,6 @@ _zstd_load_d_dict(ZstdDecompressor *self, PyObject *dict) PyErr_SetString(PyExc_TypeError, "zstd_dict argument should be ZstdDict object."); return -1; - -load: - if (type == DICT_TYPE_DIGESTED) { - /* Get ZSTD_DDict */ - ZSTD_DDict *d_dict = _get_DDict(zd); - if (d_dict == NULL) { - return -1; - } - /* Reference a prepared dictionary */ - Py_BEGIN_CRITICAL_SECTION(self); - zstd_ret = ZSTD_DCtx_refDDict(self->dctx, d_dict); - Py_END_CRITICAL_SECTION(); - } - else if (type == DICT_TYPE_UNDIGESTED) { - /* Load a dictionary */ - Py_BEGIN_CRITICAL_SECTION2(self, zd); - zstd_ret = ZSTD_DCtx_loadDictionary( - self->dctx, - PyBytes_AS_STRING(zd->dict_content), - Py_SIZE(zd->dict_content)); - Py_END_CRITICAL_SECTION2(); - } - else if (type == DICT_TYPE_PREFIX) { - /* Load a prefix */ - Py_BEGIN_CRITICAL_SECTION2(self, zd); - zstd_ret = ZSTD_DCtx_refPrefix( - self->dctx, - PyBytes_AS_STRING(zd->dict_content), - Py_SIZE(zd->dict_content)); - Py_END_CRITICAL_SECTION2(); - } - else { - /* Impossible code path */ - PyErr_SetString(PyExc_SystemError, - "load_d_dict() impossible code path"); - return -1; - } - - /* Check error */ - if (ZSTD_isError(zstd_ret)) { - set_zstd_error(mod_state, ERR_LOAD_D_DICT, zstd_ret); - return -1; - } - return 0; } /* @@ -268,8 +270,8 @@ _zstd_load_d_dict(ZstdDecompressor *self, PyObject *dict) Note, decompressing "an empty input" in any case will make it > 0. */ static PyObject * -decompress_impl(ZstdDecompressor *self, ZSTD_inBuffer *in, - Py_ssize_t max_length) +decompress_lock_held(ZstdDecompressor *self, ZSTD_inBuffer *in, + Py_ssize_t max_length) { size_t zstd_ret; ZSTD_outBuffer out; @@ -339,10 +341,9 @@ decompress_impl(ZstdDecompressor *self, ZSTD_inBuffer *in, } static void -decompressor_reset_session(ZstdDecompressor *self) +decompressor_reset_session_lock_held(ZstdDecompressor *self) { - // TODO(emmatyping): use _Py_CRITICAL_SECTION_ASSERT_OBJECT_LOCKED here - // and ensure lock is always held + assert(PyMutex_IsLocked(&self->lock)); /* Reset variables */ self->in_begin = 0; @@ -359,8 +360,10 @@ decompressor_reset_session(ZstdDecompressor *self) } static PyObject * -stream_decompress(ZstdDecompressor *self, Py_buffer *data, Py_ssize_t max_length) +stream_decompress_lock_held(ZstdDecompressor *self, Py_buffer *data, + Py_ssize_t max_length) { + assert(PyMutex_IsLocked(&self->lock)); ZSTD_inBuffer in; PyObject *ret = NULL; int use_input_buffer; @@ -456,7 +459,7 @@ stream_decompress(ZstdDecompressor *self, Py_buffer *data, Py_ssize_t max_length assert(in.pos == 0); /* Decompress */ - ret = decompress_impl(self, &in, max_length); + ret = decompress_lock_held(self, &in, max_length); if (ret == NULL) { goto error; } @@ -517,7 +520,7 @@ stream_decompress(ZstdDecompressor *self, Py_buffer *data, Py_ssize_t max_length error: /* Reset decompressor's states/session */ - decompressor_reset_session(self); + decompressor_reset_session_lock_held(self); Py_CLEAR(ret); return NULL; @@ -555,6 +558,7 @@ _zstd_ZstdDecompressor_new_impl(PyTypeObject *type, PyObject *zstd_dict, self->unused_data = NULL; self->eof = 0; self->dict = NULL; + self->lock = (PyMutex){0}; /* needs_input flag */ self->needs_input = 1; @@ -608,6 +612,8 @@ ZstdDecompressor_dealloc(PyObject *ob) ZSTD_freeDCtx(self->dctx); } + assert(!PyMutex_IsLocked(&self->lock)); + /* Py_CLEAR the dict after free decompression context */ Py_CLEAR(self->dict); @@ -623,7 +629,6 @@ ZstdDecompressor_dealloc(PyObject *ob) } /*[clinic input] -@critical_section @getter _zstd.ZstdDecompressor.unused_data @@ -635,11 +640,14 @@ decompressed, unused input data after the frame. Otherwise this will be b''. static PyObject * _zstd_ZstdDecompressor_unused_data_get_impl(ZstdDecompressor *self) -/*[clinic end generated code: output=f3a20940f11b6b09 input=5233800bef00df04]*/ +/*[clinic end generated code: output=f3a20940f11b6b09 input=54d41ecd681a3444]*/ { PyObject *ret; + PyMutex_Lock(&self->lock); + if (!self->eof) { + PyMutex_Unlock(&self->lock); return Py_GetConstant(Py_CONSTANT_EMPTY_BYTES); } else { @@ -656,6 +664,7 @@ _zstd_ZstdDecompressor_unused_data_get_impl(ZstdDecompressor *self) } } + PyMutex_Unlock(&self->lock); return ret; } @@ -693,10 +702,9 @@ _zstd_ZstdDecompressor_decompress_impl(ZstdDecompressor *self, { PyObject *ret; /* Thread-safe code */ - Py_BEGIN_CRITICAL_SECTION(self); - - ret = stream_decompress(self, data, max_length); - Py_END_CRITICAL_SECTION(); + PyMutex_Lock(&self->lock); + ret = stream_decompress_lock_held(self, data, max_length); + PyMutex_Unlock(&self->lock); return ret; } diff --git a/Modules/_zstd/zstddict.c b/Modules/_zstd/zstddict.c index 7df187a6fa69d7..39828c9b36b5c2 100644 --- a/Modules/_zstd/zstddict.c +++ b/Modules/_zstd/zstddict.c @@ -17,6 +17,7 @@ class _zstd.ZstdDict "ZstdDict *" "&zstd_dict_type_spec" #include "_zstdmodule.h" #include "zstddict.h" #include "clinic/zstddict.c.h" +#include "internal/pycore_lock.h" // PyMutex_IsLocked #include // ZSTD_freeDDict(), ZSTD_getDictID_fromDict() @@ -53,6 +54,7 @@ _zstd_ZstdDict_new_impl(PyTypeObject *type, PyObject *dict_content, self->dict_content = NULL; self->d_dict = NULL; self->dict_id = 0; + self->lock = (PyMutex){0}; /* ZSTD_CDict dict */ self->c_dicts = PyDict_New(); @@ -109,6 +111,8 @@ ZstdDict_dealloc(PyObject *ob) ZSTD_freeDDict(self->d_dict); } + assert(!PyMutex_IsLocked(&self->lock)); + /* Release dict_content after Free ZSTD_CDict/ZSTD_DDict instances */ Py_CLEAR(self->dict_content); Py_CLEAR(self->c_dicts); @@ -143,7 +147,6 @@ static PyMemberDef ZstdDict_members[] = { }; /*[clinic input] -@critical_section @getter _zstd.ZstdDict.as_digested_dict @@ -160,13 +163,12 @@ Pass this attribute as zstd_dict argument: compress(dat, zstd_dict=zd.as_digeste static PyObject * _zstd_ZstdDict_as_digested_dict_get_impl(ZstdDict *self) -/*[clinic end generated code: output=09b086e7a7320dbb input=585448c79f31f74a]*/ +/*[clinic end generated code: output=09b086e7a7320dbb input=10cd2b6165931b77]*/ { return Py_BuildValue("Oi", self, DICT_TYPE_DIGESTED); } /*[clinic input] -@critical_section @getter _zstd.ZstdDict.as_undigested_dict @@ -181,13 +183,12 @@ Pass this attribute as zstd_dict argument: compress(dat, zstd_dict=zd.as_undiges static PyObject * _zstd_ZstdDict_as_undigested_dict_get_impl(ZstdDict *self) -/*[clinic end generated code: output=43c7a989e6d4253a input=022b0829ffb1c220]*/ +/*[clinic end generated code: output=43c7a989e6d4253a input=11e5f5df690a85b4]*/ { return Py_BuildValue("Oi", self, DICT_TYPE_UNDIGESTED); } /*[clinic input] -@critical_section @getter _zstd.ZstdDict.as_prefix @@ -202,7 +203,7 @@ Pass this attribute as zstd_dict argument: compress(dat, zstd_dict=zd.as_prefix) static PyObject * _zstd_ZstdDict_as_prefix_get_impl(ZstdDict *self) -/*[clinic end generated code: output=6f7130c356595a16 input=09fb82a6a5407e87]*/ +/*[clinic end generated code: output=6f7130c356595a16 input=b028e0ae6ec4292b]*/ { return Py_BuildValue("Oi", self, DICT_TYPE_PREFIX); } diff --git a/Modules/_zstd/zstddict.h b/Modules/_zstd/zstddict.h index e8a55a3670b869..dcba0f21852087 100644 --- a/Modules/_zstd/zstddict.h +++ b/Modules/_zstd/zstddict.h @@ -19,6 +19,9 @@ typedef struct { PyObject *dict_content; /* Dictionary id */ uint32_t dict_id; + + /* Lock to protect the digested dictionaries */ + PyMutex lock; } ZstdDict; #endif // !ZSTD_DICT_H