1use crate::{
5 common::{
6 borrow::{BorrowedValue, BorrowedValueMut},
7 lock::{MapImmutable, PyMutex, PyMutexGuard},
8 },
9 object::PyObjectPayload,
10 sliceable::SequenceIndexOp,
11 types::Unconstructible,
12 Py, PyObject, PyObjectRef, PyPayload, PyRef, PyResult, TryFromBorrowedObject, VirtualMachine,
13};
14use itertools::Itertools;
15use std::{borrow::Cow, fmt::Debug, ops::Range};
16
/// Table of callbacks that connects a [`PyBuffer`] to the object exporting its memory.
///
/// `PyBuffer` stores a `&'static` reference to one of these (e.g. `VEC_BUFFER_METHODS`
/// for `VecBuffer`) and dispatches every byte access and retain/release through it.
pub struct BufferMethods {
    /// Returns a shared borrow of the exporter's underlying bytes.
    pub obj_bytes: fn(&PyBuffer) -> BorrowedValue<[u8]>,
    /// Returns a mutable borrow of the exporter's underlying bytes.
    pub obj_bytes_mut: fn(&PyBuffer) -> BorrowedValueMut<[u8]>,
    /// Invoked by `PyBuffer::release`, which `PyBuffer`'s `Drop` impl calls.
    pub release: fn(&PyBuffer),
    /// Invoked by `PyBuffer::retain`, which `PyBuffer::new` calls once.
    pub retain: fn(&PyBuffer),
}
23
24impl Debug for BufferMethods {
25 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
26 f.debug_struct("BufferMethods")
27 .field("obj_bytes", &(self.obj_bytes as usize))
28 .field("obj_bytes_mut", &(self.obj_bytes_mut as usize))
29 .field("release", &(self.release as usize))
30 .field("retain", &(self.retain as usize))
31 .finish()
32 }
33}
34
/// A view over memory exported by a Python object.
///
/// Bundles the exporting object, a [`BufferDescriptor`] describing the layout, and the
/// [`BufferMethods`] callbacks used to reach the bytes and to signal retain/release.
///
/// NOTE(review): `Clone` is derived, so cloning does NOT call `retain`, yet every copy
/// calls `release` when dropped (see the `Drop` impl). Callers that clone presumably
/// must call `retain` on the copy themselves to keep the callbacks balanced — confirm.
#[derive(Debug, Clone, Traverse)]
pub struct PyBuffer {
    /// The object that exports the memory.
    pub obj: PyObjectRef,
    /// Layout of the exported memory (not visited by `Traverse`).
    #[pytraverse(skip)]
    pub desc: BufferDescriptor,
    /// Byte-access and retain/release callbacks for `obj` (not visited by `Traverse`).
    #[pytraverse(skip)]
    methods: &'static BufferMethods,
}
43
44impl PyBuffer {
45 pub fn new(obj: PyObjectRef, desc: BufferDescriptor, methods: &'static BufferMethods) -> Self {
46 let zelf = Self {
47 obj,
48 desc: desc.validate(),
49 methods,
50 };
51 zelf.retain();
52 zelf
53 }
54
55 pub fn as_contiguous(&self) -> Option<BorrowedValue<[u8]>> {
56 self.desc
57 .is_contiguous()
58 .then(|| unsafe { self.contiguous_unchecked() })
59 }
60
61 pub fn as_contiguous_mut(&self) -> Option<BorrowedValueMut<[u8]>> {
62 (!self.desc.readonly && self.desc.is_contiguous())
63 .then(|| unsafe { self.contiguous_mut_unchecked() })
64 }
65
66 pub fn from_byte_vector(bytes: Vec<u8>, vm: &VirtualMachine) -> Self {
67 let bytes_len = bytes.len();
68 PyBuffer::new(
69 PyPayload::into_pyobject(VecBuffer::from(bytes), vm),
70 BufferDescriptor::simple(bytes_len, true),
71 &VEC_BUFFER_METHODS,
72 )
73 }
74
75 pub unsafe fn contiguous_unchecked(&self) -> BorrowedValue<[u8]> {
78 self.obj_bytes()
79 }
80
81 pub unsafe fn contiguous_mut_unchecked(&self) -> BorrowedValueMut<[u8]> {
84 self.obj_bytes_mut()
85 }
86
87 pub fn append_to(&self, buf: &mut Vec<u8>) {
88 if let Some(bytes) = self.as_contiguous() {
89 buf.extend_from_slice(&bytes);
90 } else {
91 let bytes = &*self.obj_bytes();
92 self.desc.for_each_segment(true, |range| {
93 buf.extend_from_slice(&bytes[range.start as usize..range.end as usize])
94 });
95 }
96 }
97
98 pub fn contiguous_or_collect<R, F: FnOnce(&[u8]) -> R>(&self, f: F) -> R {
99 let borrowed;
100 let mut collected;
101 let v = if let Some(bytes) = self.as_contiguous() {
102 borrowed = bytes;
103 &*borrowed
104 } else {
105 collected = vec![];
106 self.append_to(&mut collected);
107 &collected
108 };
109 f(v)
110 }
111
112 pub fn obj_as<T: PyObjectPayload>(&self) -> &Py<T> {
113 unsafe { self.obj.downcast_unchecked_ref() }
114 }
115
116 pub fn obj_bytes(&self) -> BorrowedValue<[u8]> {
117 (self.methods.obj_bytes)(self)
118 }
119
120 pub fn obj_bytes_mut(&self) -> BorrowedValueMut<[u8]> {
121 (self.methods.obj_bytes_mut)(self)
122 }
123
124 pub fn release(&self) {
125 (self.methods.release)(self)
126 }
127
128 pub fn retain(&self) {
129 (self.methods.retain)(self)
130 }
131
132 pub(crate) unsafe fn drop_without_release(&mut self) {
136 std::ptr::drop_in_place(&mut self.obj);
137 std::ptr::drop_in_place(&mut self.desc);
138 }
139}
140
141impl<'a> TryFromBorrowedObject<'a> for PyBuffer {
142 fn try_from_borrowed_object(vm: &VirtualMachine, obj: &'a PyObject) -> PyResult<Self> {
143 let cls = obj.class();
144 let as_buffer = cls.mro_find_map(|cls| cls.slots.as_buffer);
145 if let Some(f) = as_buffer {
146 return f(obj, vm);
147 }
148 Err(vm.new_type_error(format!(
149 "a bytes-like object is required, not '{}'",
150 cls.name()
151 )))
152 }
153}
154
155impl Drop for PyBuffer {
156 fn drop(&mut self) {
157 self.release();
158 }
159}
160
/// Layout description of the memory behind a [`PyBuffer`].
#[derive(Debug, Clone)]
pub struct BufferDescriptor {
    /// Total number of bytes in the view. In debug builds `validate` asserts that it
    /// equals the product of all shapes times `itemsize`.
    pub len: usize,
    /// When `true`, `PyBuffer::as_contiguous_mut` returns `None`.
    pub readonly: bool,
    /// Size in bytes of one item.
    pub itemsize: usize,
    /// Item format string (`"B"` for descriptors built by `simple`).
    pub format: Cow<'static, str>,
    /// One `(shape, stride, suboffset)` entry per dimension, outermost first.
    /// Stride and suboffset are byte offsets added when computing item positions.
    pub dim_desc: Vec<(usize, isize, isize)>,
}
173
174impl BufferDescriptor {
175 pub fn simple(bytes_len: usize, readonly: bool) -> Self {
176 Self {
177 len: bytes_len,
178 readonly,
179 itemsize: 1,
180 format: Cow::Borrowed("B"),
181 dim_desc: vec![(bytes_len, 1, 0)],
182 }
183 }
184
185 pub fn format(
186 bytes_len: usize,
187 readonly: bool,
188 itemsize: usize,
189 format: Cow<'static, str>,
190 ) -> Self {
191 Self {
192 len: bytes_len,
193 readonly,
194 itemsize,
195 format,
196 dim_desc: vec![(bytes_len / itemsize, itemsize as isize, 0)],
197 }
198 }
199
200 #[cfg(debug_assertions)]
201 pub fn validate(self) -> Self {
202 assert!(self.itemsize != 0);
203 assert!(self.ndim() != 0);
204 let mut shape_product = 1;
205 for (shape, stride, suboffset) in self.dim_desc.iter().cloned() {
206 shape_product *= shape;
207 assert!(suboffset >= 0);
208 assert!(stride != 0);
209 }
210 assert!(shape_product * self.itemsize == self.len);
211 self
212 }
213
214 #[cfg(not(debug_assertions))]
215 pub fn validate(self) -> Self {
216 self
217 }
218
219 pub fn ndim(&self) -> usize {
220 self.dim_desc.len()
221 }
222
223 pub fn is_contiguous(&self) -> bool {
224 if self.len == 0 {
225 return true;
226 }
227 let mut sd = self.itemsize;
228 for (shape, stride, _) in self.dim_desc.iter().cloned().rev() {
229 if shape > 1 && stride != sd as isize {
230 return false;
231 }
232 sd *= shape;
233 }
234 true
235 }
236
237 pub fn fast_position(&self, indices: &[usize]) -> isize {
240 let mut pos = 0;
241 for (i, (_, stride, suboffset)) in indices
242 .iter()
243 .cloned()
244 .zip_eq(self.dim_desc.iter().cloned())
245 {
246 pos += i as isize * stride + suboffset;
247 }
248 pos
249 }
250
251 pub fn position(&self, indices: &[isize], vm: &VirtualMachine) -> PyResult<isize> {
253 let mut pos = 0;
254 for (i, (shape, stride, suboffset)) in indices
255 .iter()
256 .cloned()
257 .zip_eq(self.dim_desc.iter().cloned())
258 {
259 let i = i.wrapped_at(shape).ok_or_else(|| {
260 vm.new_index_error(format!("index out of bounds on dimension {i}"))
261 })?;
262 pos += i as isize * stride + suboffset;
263 }
264 Ok(pos)
265 }
266
267 pub fn for_each_segment<F>(&self, try_conti: bool, mut f: F)
268 where
269 F: FnMut(Range<isize>),
270 {
271 if self.ndim() == 0 {
272 f(0..self.itemsize as isize);
273 return;
274 }
275 if try_conti && self.is_last_dim_contiguous() {
276 self._for_each_segment::<_, true>(0, 0, &mut f);
277 } else {
278 self._for_each_segment::<_, false>(0, 0, &mut f);
279 }
280 }
281
282 fn _for_each_segment<F, const CONTI: bool>(&self, mut index: isize, dim: usize, f: &mut F)
283 where
284 F: FnMut(Range<isize>),
285 {
286 let (shape, stride, suboffset) = self.dim_desc[dim];
287 if dim + 1 == self.ndim() {
288 if CONTI {
289 f(index..index + (shape * self.itemsize) as isize);
290 } else {
291 for _ in 0..shape {
292 let pos = index + suboffset;
293 f(pos..pos + self.itemsize as isize);
294 index += stride;
295 }
296 }
297 return;
298 }
299 for _ in 0..shape {
300 self._for_each_segment::<F, CONTI>(index + suboffset, dim + 1, f);
301 index += stride;
302 }
303 }
304
305 pub fn zip_eq<F>(&self, other: &Self, try_conti: bool, mut f: F)
307 where
308 F: FnMut(Range<isize>, Range<isize>) -> bool,
309 {
310 if self.ndim() == 0 {
311 f(0..self.itemsize as isize, 0..other.itemsize as isize);
312 return;
313 }
314 if try_conti && self.is_last_dim_contiguous() {
315 self._zip_eq::<_, true>(other, 0, 0, 0, &mut f);
316 } else {
317 self._zip_eq::<_, false>(other, 0, 0, 0, &mut f);
318 }
319 }
320
321 fn _zip_eq<F, const CONTI: bool>(
322 &self,
323 other: &Self,
324 mut a_index: isize,
325 mut b_index: isize,
326 dim: usize,
327 f: &mut F,
328 ) where
329 F: FnMut(Range<isize>, Range<isize>) -> bool,
330 {
331 let (shape, a_stride, a_suboffset) = self.dim_desc[dim];
332 let (_b_shape, b_stride, b_suboffset) = other.dim_desc[dim];
333 debug_assert_eq!(shape, _b_shape);
334 if dim + 1 == self.ndim() {
335 if CONTI {
336 if f(
337 a_index..a_index + (shape * self.itemsize) as isize,
338 b_index..b_index + (shape * other.itemsize) as isize,
339 ) {
340 return;
341 }
342 } else {
343 for _ in 0..shape {
344 let a_pos = a_index + a_suboffset;
345 let b_pos = b_index + b_suboffset;
346 if f(
347 a_pos..a_pos + self.itemsize as isize,
348 b_pos..b_pos + other.itemsize as isize,
349 ) {
350 return;
351 }
352 a_index += a_stride;
353 b_index += b_stride;
354 }
355 }
356 return;
357 }
358
359 for _ in 0..shape {
360 self._zip_eq::<F, CONTI>(
361 other,
362 a_index + a_suboffset,
363 b_index + b_suboffset,
364 dim + 1,
365 f,
366 );
367 a_index += a_stride;
368 b_index += b_stride;
369 }
370 }
371
372 fn is_last_dim_contiguous(&self) -> bool {
373 let (_, stride, suboffset) = self.dim_desc[self.ndim() - 1];
374 suboffset == 0 && stride == self.itemsize as isize
375 }
376
377 pub fn is_zero_in_shape(&self) -> bool {
378 for (shape, _, _) in self.dim_desc.iter().cloned() {
379 if shape == 0 {
380 return true;
381 }
382 }
383 false
384 }
385
386 }
388
389pub trait BufferResizeGuard {
390 type Resizable<'a>: 'a
391 where
392 Self: 'a;
393 fn try_resizable_opt(&self) -> Option<Self::Resizable<'_>>;
394 fn try_resizable(&self, vm: &VirtualMachine) -> PyResult<Self::Resizable<'_>> {
395 self.try_resizable_opt().ok_or_else(|| {
396 vm.new_buffer_error("Existing exports of data: object cannot be re-sized".to_owned())
397 })
398 }
399}
400
401#[pyclass(module = false, name = "vec_buffer")]
402#[derive(Debug, PyPayload)]
403pub struct VecBuffer {
404 data: PyMutex<Vec<u8>>,
405}
406
407#[pyclass(flags(BASETYPE), with(Unconstructible))]
408impl VecBuffer {
409 pub fn take(&self) -> Vec<u8> {
410 std::mem::take(&mut self.data.lock())
411 }
412}
413
414impl From<Vec<u8>> for VecBuffer {
415 fn from(data: Vec<u8>) -> Self {
416 Self {
417 data: PyMutex::new(data),
418 }
419 }
420}
421
// Marks `VecBuffer` as unconstructible (presumably its Python-level constructor is
// rejected — confirm against `Unconstructible`); instances are built in Rust via
// `From<Vec<u8>>`.
impl Unconstructible for VecBuffer {}
423
424impl PyRef<VecBuffer> {
425 pub fn into_pybuffer(self, readonly: bool) -> PyBuffer {
426 let len = self.data.lock().len();
427 PyBuffer::new(
428 self.into(),
429 BufferDescriptor::simple(len, readonly),
430 &VEC_BUFFER_METHODS,
431 )
432 }
433
434 pub fn into_pybuffer_with_descriptor(self, desc: BufferDescriptor) -> PyBuffer {
435 PyBuffer::new(self.into(), desc, &VEC_BUFFER_METHODS)
436 }
437}
438
439static VEC_BUFFER_METHODS: BufferMethods = BufferMethods {
440 obj_bytes: |buffer| {
441 PyMutexGuard::map_immutable(buffer.obj_as::<VecBuffer>().data.lock(), |x| x.as_slice())
442 .into()
443 },
444 obj_bytes_mut: |buffer| {
445 PyMutexGuard::map(buffer.obj_as::<VecBuffer>().data.lock(), |x| {
446 x.as_mut_slice()
447 })
448 .into()
449 },
450 release: |_| {},
451 retain: |_| {},
452};