rustpython_vm/protocol/buffer.rs

1//! Buffer protocol
2//! https://docs.python.org/3/c-api/buffer.html
3
4use crate::{
5    common::{
6        borrow::{BorrowedValue, BorrowedValueMut},
7        lock::{MapImmutable, PyMutex, PyMutexGuard},
8    },
9    object::PyObjectPayload,
10    sliceable::SequenceIndexOp,
11    types::Unconstructible,
12    Py, PyObject, PyObjectRef, PyPayload, PyRef, PyResult, TryFromBorrowedObject, VirtualMachine,
13};
14use itertools::Itertools;
15use std::{borrow::Cow, fmt::Debug, ops::Range};
16
/// Table of callbacks through which a [`PyBuffer`] reaches the memory of the
/// object that exported it.
///
/// A `PyBuffer` stores a `&'static` reference to one of these tables
/// (for example `VEC_BUFFER_METHODS` for `VecBuffer`).
pub struct BufferMethods {
    /// Borrow the exporter's underlying bytes for reading.
    pub obj_bytes: fn(&PyBuffer) -> BorrowedValue<[u8]>,
    /// Borrow the exporter's underlying bytes for writing.
    pub obj_bytes_mut: fn(&PyBuffer) -> BorrowedValueMut<[u8]>,
    /// Invoked by `PyBuffer::release`, which the `Drop` impl of `PyBuffer` calls.
    pub release: fn(&PyBuffer),
    /// Invoked by `PyBuffer::retain`, which `PyBuffer::new` calls once.
    pub retain: fn(&PyBuffer),
}
23
24impl Debug for BufferMethods {
25    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
26        f.debug_struct("BufferMethods")
27            .field("obj_bytes", &(self.obj_bytes as usize))
28            .field("obj_bytes_mut", &(self.obj_bytes_mut as usize))
29            .field("release", &(self.release as usize))
30            .field("retain", &(self.retain as usize))
31            .finish()
32    }
33}
34
/// A handle to an exported buffer.
///
/// It holds the exporting object, a description of the memory layout, and the
/// exporter's callback table. `PyBuffer::new` calls the `retain` callback once.
/// Dropping the handle calls `release`.
///
/// NOTE(review): `Clone` is derived, so cloning does NOT call `retain`, yet every
/// clone calls `release` when dropped. Code that clones a `PyBuffer` presumably
/// calls `retain` itself to keep the pair balanced — confirm at call sites.
#[derive(Debug, Clone, Traverse)]
pub struct PyBuffer {
    /// The object whose memory is exported.
    pub obj: PyObjectRef,
    /// Layout of the exported memory (length, item size, format, per-dimension
    /// shape/stride/suboffset). Not a Python object, so GC traversal skips it.
    #[pytraverse(skip)]
    pub desc: BufferDescriptor,
    /// Exporter-specific callbacks. Not a Python object, so GC traversal skips it.
    #[pytraverse(skip)]
    methods: &'static BufferMethods,
}
43
44impl PyBuffer {
45    pub fn new(obj: PyObjectRef, desc: BufferDescriptor, methods: &'static BufferMethods) -> Self {
46        let zelf = Self {
47            obj,
48            desc: desc.validate(),
49            methods,
50        };
51        zelf.retain();
52        zelf
53    }
54
55    pub fn as_contiguous(&self) -> Option<BorrowedValue<[u8]>> {
56        self.desc
57            .is_contiguous()
58            .then(|| unsafe { self.contiguous_unchecked() })
59    }
60
61    pub fn as_contiguous_mut(&self) -> Option<BorrowedValueMut<[u8]>> {
62        (!self.desc.readonly && self.desc.is_contiguous())
63            .then(|| unsafe { self.contiguous_mut_unchecked() })
64    }
65
66    pub fn from_byte_vector(bytes: Vec<u8>, vm: &VirtualMachine) -> Self {
67        let bytes_len = bytes.len();
68        PyBuffer::new(
69            PyPayload::into_pyobject(VecBuffer::from(bytes), vm),
70            BufferDescriptor::simple(bytes_len, true),
71            &VEC_BUFFER_METHODS,
72        )
73    }
74
75    /// # Safety
76    /// assume the buffer is contiguous
77    pub unsafe fn contiguous_unchecked(&self) -> BorrowedValue<[u8]> {
78        self.obj_bytes()
79    }
80
81    /// # Safety
82    /// assume the buffer is contiguous and writable
83    pub unsafe fn contiguous_mut_unchecked(&self) -> BorrowedValueMut<[u8]> {
84        self.obj_bytes_mut()
85    }
86
87    pub fn append_to(&self, buf: &mut Vec<u8>) {
88        if let Some(bytes) = self.as_contiguous() {
89            buf.extend_from_slice(&bytes);
90        } else {
91            let bytes = &*self.obj_bytes();
92            self.desc.for_each_segment(true, |range| {
93                buf.extend_from_slice(&bytes[range.start as usize..range.end as usize])
94            });
95        }
96    }
97
98    pub fn contiguous_or_collect<R, F: FnOnce(&[u8]) -> R>(&self, f: F) -> R {
99        let borrowed;
100        let mut collected;
101        let v = if let Some(bytes) = self.as_contiguous() {
102            borrowed = bytes;
103            &*borrowed
104        } else {
105            collected = vec![];
106            self.append_to(&mut collected);
107            &collected
108        };
109        f(v)
110    }
111
112    pub fn obj_as<T: PyObjectPayload>(&self) -> &Py<T> {
113        unsafe { self.obj.downcast_unchecked_ref() }
114    }
115
116    pub fn obj_bytes(&self) -> BorrowedValue<[u8]> {
117        (self.methods.obj_bytes)(self)
118    }
119
120    pub fn obj_bytes_mut(&self) -> BorrowedValueMut<[u8]> {
121        (self.methods.obj_bytes_mut)(self)
122    }
123
124    pub fn release(&self) {
125        (self.methods.release)(self)
126    }
127
128    pub fn retain(&self) {
129        (self.methods.retain)(self)
130    }
131
132    // drop PyBuffer without calling release
133    // after this function, the owner should use forget()
134    // or wrap PyBuffer in the ManaullyDrop to prevent drop()
135    pub(crate) unsafe fn drop_without_release(&mut self) {
136        std::ptr::drop_in_place(&mut self.obj);
137        std::ptr::drop_in_place(&mut self.desc);
138    }
139}
140
141impl<'a> TryFromBorrowedObject<'a> for PyBuffer {
142    fn try_from_borrowed_object(vm: &VirtualMachine, obj: &'a PyObject) -> PyResult<Self> {
143        let cls = obj.class();
144        let as_buffer = cls.mro_find_map(|cls| cls.slots.as_buffer);
145        if let Some(f) = as_buffer {
146            return f(obj, vm);
147        }
148        Err(vm.new_type_error(format!(
149            "a bytes-like object is required, not '{}'",
150            cls.name()
151        )))
152    }
153}
154
155impl Drop for PyBuffer {
156    fn drop(&mut self) {
157        self.release();
158    }
159}
160
/// Describes the memory layout of an exported buffer.
#[derive(Debug, Clone)]
pub struct BufferDescriptor {
    /// `product(shape) * itemsize`: the logical length in bytes.
    /// This is not necessarily the length of `obj_bytes()`, even when the buffer
    /// is contiguous.
    pub len: usize,
    /// When `true`, mutable access (`as_contiguous_mut`) is refused.
    pub readonly: bool,
    /// Size in bytes of a single item.
    pub itemsize: usize,
    /// Item format string (`"B"` for the byte buffers built by `simple`).
    pub format: Cow<'static, str>,
    /// `(shape, stride, suboffset)` for each dimension; strides are in bytes.
    pub dim_desc: Vec<(usize, isize, isize)>,
    // TODO: flags
}
173
174impl BufferDescriptor {
175    pub fn simple(bytes_len: usize, readonly: bool) -> Self {
176        Self {
177            len: bytes_len,
178            readonly,
179            itemsize: 1,
180            format: Cow::Borrowed("B"),
181            dim_desc: vec![(bytes_len, 1, 0)],
182        }
183    }
184
185    pub fn format(
186        bytes_len: usize,
187        readonly: bool,
188        itemsize: usize,
189        format: Cow<'static, str>,
190    ) -> Self {
191        Self {
192            len: bytes_len,
193            readonly,
194            itemsize,
195            format,
196            dim_desc: vec![(bytes_len / itemsize, itemsize as isize, 0)],
197        }
198    }
199
200    #[cfg(debug_assertions)]
201    pub fn validate(self) -> Self {
202        assert!(self.itemsize != 0);
203        assert!(self.ndim() != 0);
204        let mut shape_product = 1;
205        for (shape, stride, suboffset) in self.dim_desc.iter().cloned() {
206            shape_product *= shape;
207            assert!(suboffset >= 0);
208            assert!(stride != 0);
209        }
210        assert!(shape_product * self.itemsize == self.len);
211        self
212    }
213
214    #[cfg(not(debug_assertions))]
215    pub fn validate(self) -> Self {
216        self
217    }
218
219    pub fn ndim(&self) -> usize {
220        self.dim_desc.len()
221    }
222
223    pub fn is_contiguous(&self) -> bool {
224        if self.len == 0 {
225            return true;
226        }
227        let mut sd = self.itemsize;
228        for (shape, stride, _) in self.dim_desc.iter().cloned().rev() {
229            if shape > 1 && stride != sd as isize {
230                return false;
231            }
232            sd *= shape;
233        }
234        true
235    }
236
237    /// this function do not check the bound
238    /// panic if indices.len() != ndim
239    pub fn fast_position(&self, indices: &[usize]) -> isize {
240        let mut pos = 0;
241        for (i, (_, stride, suboffset)) in indices
242            .iter()
243            .cloned()
244            .zip_eq(self.dim_desc.iter().cloned())
245        {
246            pos += i as isize * stride + suboffset;
247        }
248        pos
249    }
250
251    /// panic if indices.len() != ndim
252    pub fn position(&self, indices: &[isize], vm: &VirtualMachine) -> PyResult<isize> {
253        let mut pos = 0;
254        for (i, (shape, stride, suboffset)) in indices
255            .iter()
256            .cloned()
257            .zip_eq(self.dim_desc.iter().cloned())
258        {
259            let i = i.wrapped_at(shape).ok_or_else(|| {
260                vm.new_index_error(format!("index out of bounds on dimension {i}"))
261            })?;
262            pos += i as isize * stride + suboffset;
263        }
264        Ok(pos)
265    }
266
267    pub fn for_each_segment<F>(&self, try_conti: bool, mut f: F)
268    where
269        F: FnMut(Range<isize>),
270    {
271        if self.ndim() == 0 {
272            f(0..self.itemsize as isize);
273            return;
274        }
275        if try_conti && self.is_last_dim_contiguous() {
276            self._for_each_segment::<_, true>(0, 0, &mut f);
277        } else {
278            self._for_each_segment::<_, false>(0, 0, &mut f);
279        }
280    }
281
282    fn _for_each_segment<F, const CONTI: bool>(&self, mut index: isize, dim: usize, f: &mut F)
283    where
284        F: FnMut(Range<isize>),
285    {
286        let (shape, stride, suboffset) = self.dim_desc[dim];
287        if dim + 1 == self.ndim() {
288            if CONTI {
289                f(index..index + (shape * self.itemsize) as isize);
290            } else {
291                for _ in 0..shape {
292                    let pos = index + suboffset;
293                    f(pos..pos + self.itemsize as isize);
294                    index += stride;
295                }
296            }
297            return;
298        }
299        for _ in 0..shape {
300            self._for_each_segment::<F, CONTI>(index + suboffset, dim + 1, f);
301            index += stride;
302        }
303    }
304
305    /// zip two BufferDescriptor with the same shape
306    pub fn zip_eq<F>(&self, other: &Self, try_conti: bool, mut f: F)
307    where
308        F: FnMut(Range<isize>, Range<isize>) -> bool,
309    {
310        if self.ndim() == 0 {
311            f(0..self.itemsize as isize, 0..other.itemsize as isize);
312            return;
313        }
314        if try_conti && self.is_last_dim_contiguous() {
315            self._zip_eq::<_, true>(other, 0, 0, 0, &mut f);
316        } else {
317            self._zip_eq::<_, false>(other, 0, 0, 0, &mut f);
318        }
319    }
320
321    fn _zip_eq<F, const CONTI: bool>(
322        &self,
323        other: &Self,
324        mut a_index: isize,
325        mut b_index: isize,
326        dim: usize,
327        f: &mut F,
328    ) where
329        F: FnMut(Range<isize>, Range<isize>) -> bool,
330    {
331        let (shape, a_stride, a_suboffset) = self.dim_desc[dim];
332        let (_b_shape, b_stride, b_suboffset) = other.dim_desc[dim];
333        debug_assert_eq!(shape, _b_shape);
334        if dim + 1 == self.ndim() {
335            if CONTI {
336                if f(
337                    a_index..a_index + (shape * self.itemsize) as isize,
338                    b_index..b_index + (shape * other.itemsize) as isize,
339                ) {
340                    return;
341                }
342            } else {
343                for _ in 0..shape {
344                    let a_pos = a_index + a_suboffset;
345                    let b_pos = b_index + b_suboffset;
346                    if f(
347                        a_pos..a_pos + self.itemsize as isize,
348                        b_pos..b_pos + other.itemsize as isize,
349                    ) {
350                        return;
351                    }
352                    a_index += a_stride;
353                    b_index += b_stride;
354                }
355            }
356            return;
357        }
358
359        for _ in 0..shape {
360            self._zip_eq::<F, CONTI>(
361                other,
362                a_index + a_suboffset,
363                b_index + b_suboffset,
364                dim + 1,
365                f,
366            );
367            a_index += a_stride;
368            b_index += b_stride;
369        }
370    }
371
372    fn is_last_dim_contiguous(&self) -> bool {
373        let (_, stride, suboffset) = self.dim_desc[self.ndim() - 1];
374        suboffset == 0 && stride == self.itemsize as isize
375    }
376
377    pub fn is_zero_in_shape(&self) -> bool {
378        for (shape, _, _) in self.dim_desc.iter().cloned() {
379            if shape == 0 {
380                return true;
381            }
382        }
383        false
384    }
385
386    // TODO: support fortain order
387}
388
389pub trait BufferResizeGuard {
390    type Resizable<'a>: 'a
391    where
392        Self: 'a;
393    fn try_resizable_opt(&self) -> Option<Self::Resizable<'_>>;
394    fn try_resizable(&self, vm: &VirtualMachine) -> PyResult<Self::Resizable<'_>> {
395        self.try_resizable_opt().ok_or_else(|| {
396            vm.new_buffer_error("Existing exports of data: object cannot be re-sized".to_owned())
397        })
398    }
399}
400
/// A Python object (`vec_buffer`) that owns a byte vector and exports it as a
/// buffer through `VEC_BUFFER_METHODS`.
#[pyclass(module = false, name = "vec_buffer")]
#[derive(Debug, PyPayload)]
pub struct VecBuffer {
    /// The owned bytes; buffer accessors lock this mutex to borrow them.
    data: PyMutex<Vec<u8>>,
}
406
407#[pyclass(flags(BASETYPE), with(Unconstructible))]
408impl VecBuffer {
409    pub fn take(&self) -> Vec<u8> {
410        std::mem::take(&mut self.data.lock())
411    }
412}
413
414impl From<Vec<u8>> for VecBuffer {
415    fn from(data: Vec<u8>) -> Self {
416        Self {
417            data: PyMutex::new(data),
418        }
419    }
420}
421
// Opts `VecBuffer` into `Unconstructible`. Presumably this makes calling
// `vec_buffer(...)` from Python fail, so instances are only built from Rust
// (e.g. via `From<Vec<u8>>`) — confirm against `types::Unconstructible`.
impl Unconstructible for VecBuffer {}
423
424impl PyRef<VecBuffer> {
425    pub fn into_pybuffer(self, readonly: bool) -> PyBuffer {
426        let len = self.data.lock().len();
427        PyBuffer::new(
428            self.into(),
429            BufferDescriptor::simple(len, readonly),
430            &VEC_BUFFER_METHODS,
431        )
432    }
433
434    pub fn into_pybuffer_with_descriptor(self, desc: BufferDescriptor) -> PyBuffer {
435        PyBuffer::new(self.into(), desc, &VEC_BUFFER_METHODS)
436    }
437}
438
/// Buffer callbacks for `VecBuffer`.
///
/// Byte access locks the inner mutex and maps the guard to a slice.
/// `retain` and `release` are no-ops, so no export count is tracked.
///
/// NOTE(review): because nothing is tracked, `VecBuffer::take` can empty the
/// vector while a `PyBuffer` whose `desc.len` was computed earlier still exists.
/// Later slicing by that stale length would panic — confirm callers avoid this.
static VEC_BUFFER_METHODS: BufferMethods = BufferMethods {
    // Shared access: lock the vector and expose it as `&[u8]`.
    obj_bytes: |buffer| {
        PyMutexGuard::map_immutable(buffer.obj_as::<VecBuffer>().data.lock(), |x| x.as_slice())
            .into()
    },
    // Exclusive access: lock the vector and expose it as `&mut [u8]`.
    obj_bytes_mut: |buffer| {
        PyMutexGuard::map(buffer.obj_as::<VecBuffer>().data.lock(), |x| {
            x.as_mut_slice()
        })
        .into()
    },
    release: |_| {},
    retain: |_| {},
};
Morty Proxy This is a proxified and sanitized view of the page, visit original site.