Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Appearance settings

Commit c64a214

Browse files
authored
gh-132983: Refactor shared code in train_dict and finalize_dict (GH-134432)
Refactor shared code in train_dict and finalize_dict
1 parent 0a68068 commit c64a214
Copy full SHA for c64a214

File tree

Expand file treeCollapse file tree

1 file changed

+55
-68
lines changed
Filter options
Expand file treeCollapse file tree

1 file changed

+55
-68
lines changed

‎Modules/_zstd/_zstdmodule.c

Copy file name to clipboardExpand all lines: Modules/_zstd/_zstdmodule.c
+55-68Lines changed: 55 additions & 68 deletions
Original file line number | Diff line number | Diff line change
@@ -172,6 +172,49 @@ get_zstd_state(PyObject *module)
172172
return (_zstd_state *)state;
173173
}
174174

175+
static Py_ssize_t
176+
calculate_samples_stats(PyBytesObject *samples_bytes, PyObject *samples_sizes,
177+
size_t **chunk_sizes)
178+
{
179+
Py_ssize_t chunks_number;
180+
Py_ssize_t sizes_sum;
181+
Py_ssize_t i;
182+
183+
chunks_number = Py_SIZE(samples_sizes);
184+
if ((size_t) chunks_number > UINT32_MAX) {
185+
PyErr_Format(PyExc_ValueError,
186+
"The number of samples should be <= %u.", UINT32_MAX);
187+
return -1;
188+
}
189+
190+
/* Prepare chunk_sizes */
191+
*chunk_sizes = PyMem_New(size_t, chunks_number);
192+
if (*chunk_sizes == NULL) {
193+
PyErr_NoMemory();
194+
return -1;
195+
}
196+
197+
sizes_sum = 0;
198+
for (i = 0; i < chunks_number; i++) {
199+
PyObject *size = PyTuple_GetItem(samples_sizes, i);
200+
(*chunk_sizes)[i] = PyLong_AsSize_t(size);
201+
if ((*chunk_sizes)[i] == (size_t)-1 && PyErr_Occurred()) {
202+
PyErr_Format(PyExc_ValueError,
203+
"Items in samples_sizes should be an int "
204+
"object, with a value between 0 and %u.", SIZE_MAX);
205+
return -1;
206+
}
207+
sizes_sum += (*chunk_sizes)[i];
208+
}
209+
210+
if (sizes_sum != Py_SIZE(samples_bytes)) {
211+
PyErr_SetString(PyExc_ValueError,
212+
"The samples size tuple doesn't match the concatenation's size.");
213+
return -1;
214+
}
215+
return chunks_number;
216+
}
217+
175218

176219
/*[clinic input]
177220
_zstd.train_dict
@@ -192,54 +235,25 @@ _zstd_train_dict_impl(PyObject *module, PyBytesObject *samples_bytes,
192235
PyObject *samples_sizes, Py_ssize_t dict_size)
193236
/*[clinic end generated code: output=8e87fe43935e8f77 input=d20dedb21c72cb62]*/
194237
{
195-
// TODO(emmatyping): The preamble and suffix to this function and _finalize_dict
196-
// are pretty similar. We should see if we can refactor them to share that code.
197-
Py_ssize_t chunks_number;
198-
size_t *chunk_sizes = NULL;
199238
PyObject *dst_dict_bytes = NULL;
239+
size_t *chunk_sizes = NULL;
240+
Py_ssize_t chunks_number;
200241
size_t zstd_ret;
201-
Py_ssize_t sizes_sum;
202-
Py_ssize_t i;
203242

204243
/* Check arguments */
205244
if (dict_size <= 0) {
206245
PyErr_SetString(PyExc_ValueError, "dict_size argument should be positive number.");
207246
return NULL;
208247
}
209248

210-
chunks_number = Py_SIZE(samples_sizes);
211-
if ((size_t) chunks_number > UINT32_MAX) {
212-
PyErr_Format(PyExc_ValueError,
213-
"The number of samples should be <= %u.", UINT32_MAX);
249+
/* Check that the samples are valid and get their sizes */
250+
chunks_number = calculate_samples_stats(samples_bytes, samples_sizes,
251+
&chunk_sizes);
252+
if (chunks_number < 0)
253+
{
214254
return NULL;
215255
}
216256

217-
/* Prepare chunk_sizes */
218-
chunk_sizes = PyMem_New(size_t, chunks_number);
219-
if (chunk_sizes == NULL) {
220-
PyErr_NoMemory();
221-
goto error;
222-
}
223-
224-
sizes_sum = 0;
225-
for (i = 0; i < chunks_number; i++) {
226-
PyObject *size = PyTuple_GetItem(samples_sizes, i);
227-
chunk_sizes[i] = PyLong_AsSize_t(size);
228-
if (chunk_sizes[i] == (size_t)-1 && PyErr_Occurred()) {
229-
PyErr_Format(PyExc_ValueError,
230-
"Items in samples_sizes should be an int "
231-
"object, with a value between 0 and %u.", SIZE_MAX);
232-
goto error;
233-
}
234-
sizes_sum += chunk_sizes[i];
235-
}
236-
237-
if (sizes_sum != Py_SIZE(samples_bytes)) {
238-
PyErr_SetString(PyExc_ValueError,
239-
"The samples size tuple doesn't match the concatenation's size.");
240-
goto error;
241-
}
242-
243257
/* Allocate dict buffer */
244258
dst_dict_bytes = PyBytes_FromStringAndSize(NULL, dict_size);
245259
if (dst_dict_bytes == NULL) {
@@ -307,48 +321,21 @@ _zstd_finalize_dict_impl(PyObject *module, PyBytesObject *custom_dict_bytes,
307321
PyObject *dst_dict_bytes = NULL;
308322
size_t zstd_ret;
309323
ZDICT_params_t params;
310-
Py_ssize_t sizes_sum;
311-
Py_ssize_t i;
312324

313325
/* Check arguments */
314326
if (dict_size <= 0) {
315327
PyErr_SetString(PyExc_ValueError, "dict_size argument should be positive number.");
316328
return NULL;
317329
}
318330

319-
chunks_number = Py_SIZE(samples_sizes);
320-
if ((size_t) chunks_number > UINT32_MAX) {
321-
PyErr_Format(PyExc_ValueError,
322-
"The number of samples should be <= %u.", UINT32_MAX);
331+
/* Check that the samples are valid and get their sizes */
332+
chunks_number = calculate_samples_stats(samples_bytes, samples_sizes,
333+
&chunk_sizes);
334+
if (chunks_number < 0)
335+
{
323336
return NULL;
324337
}
325338

326-
/* Prepare chunk_sizes */
327-
chunk_sizes = PyMem_New(size_t, chunks_number);
328-
if (chunk_sizes == NULL) {
329-
PyErr_NoMemory();
330-
goto error;
331-
}
332-
333-
sizes_sum = 0;
334-
for (i = 0; i < chunks_number; i++) {
335-
PyObject *size = PyTuple_GetItem(samples_sizes, i);
336-
chunk_sizes[i] = PyLong_AsSize_t(size);
337-
if (chunk_sizes[i] == (size_t)-1 && PyErr_Occurred()) {
338-
PyErr_Format(PyExc_ValueError,
339-
"Items in samples_sizes should be an int "
340-
"object, with a value between 0 and %u.", SIZE_MAX);
341-
goto error;
342-
}
343-
sizes_sum += chunk_sizes[i];
344-
}
345-
346-
if (sizes_sum != Py_SIZE(samples_bytes)) {
347-
PyErr_SetString(PyExc_ValueError,
348-
"The samples size tuple doesn't match the concatenation's size.");
349-
goto error;
350-
}
351-
352339
/* Allocate dict buffer */
353340
dst_dict_bytes = PyBytes_FromStringAndSize(NULL, dict_size);
354341
if (dst_dict_bytes == NULL) {

0 commit comments

Comments
0 (0)
Morty Proxy This is a proxified and sanitized view of the page, visit original site.