rustpython_vm/protocol/sequence.rs

1use crate::{
2    builtins::{type_::PointerSlot, PyList, PyListRef, PySlice, PyTuple, PyTupleRef},
3    convert::ToPyObject,
4    function::PyArithmeticValue,
5    object::{Traverse, TraverseFn},
6    protocol::{PyMapping, PyNumberBinaryOp},
7    PyObject, PyObjectRef, PyPayload, PyResult, VirtualMachine,
8};
9use crossbeam_utils::atomic::AtomicCell;
10use itertools::Itertools;
11use std::fmt::Debug;
12
13// Sequence Protocol
14// https://docs.python.org/3/c-api/sequence.html
15
16impl PyObject {
17    #[inline]
18    pub fn to_sequence(&self) -> PySequence<'_> {
19        static GLOBAL_NOT_IMPLEMENTED: PySequenceMethods = PySequenceMethods::NOT_IMPLEMENTED;
20        PySequence {
21            obj: self,
22            methods: PySequence::find_methods(self)
23                .map_or(&GLOBAL_NOT_IMPLEMENTED, |x| unsafe { x.borrow_static() }),
24        }
25    }
26}
27
28#[allow(clippy::type_complexity)]
29#[derive(Default)]
30pub struct PySequenceMethods {
31    pub length: AtomicCell<Option<fn(PySequence, &VirtualMachine) -> PyResult<usize>>>,
32    pub concat: AtomicCell<Option<fn(PySequence, &PyObject, &VirtualMachine) -> PyResult>>,
33    pub repeat: AtomicCell<Option<fn(PySequence, isize, &VirtualMachine) -> PyResult>>,
34    pub item: AtomicCell<Option<fn(PySequence, isize, &VirtualMachine) -> PyResult>>,
35    pub ass_item: AtomicCell<
36        Option<fn(PySequence, isize, Option<PyObjectRef>, &VirtualMachine) -> PyResult<()>>,
37    >,
38    pub contains: AtomicCell<Option<fn(PySequence, &PyObject, &VirtualMachine) -> PyResult<bool>>>,
39    pub inplace_concat: AtomicCell<Option<fn(PySequence, &PyObject, &VirtualMachine) -> PyResult>>,
40    pub inplace_repeat: AtomicCell<Option<fn(PySequence, isize, &VirtualMachine) -> PyResult>>,
41}
42
43impl Debug for PySequenceMethods {
44    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
45        write!(f, "Sequence Methods")
46    }
47}
48
49impl PySequenceMethods {
50    #[allow(clippy::declare_interior_mutable_const)]
51    pub const NOT_IMPLEMENTED: PySequenceMethods = PySequenceMethods {
52        length: AtomicCell::new(None),
53        concat: AtomicCell::new(None),
54        repeat: AtomicCell::new(None),
55        item: AtomicCell::new(None),
56        ass_item: AtomicCell::new(None),
57        contains: AtomicCell::new(None),
58        inplace_concat: AtomicCell::new(None),
59        inplace_repeat: AtomicCell::new(None),
60    };
61}
62
63#[derive(Copy, Clone)]
64pub struct PySequence<'a> {
65    pub obj: &'a PyObject,
66    pub methods: &'static PySequenceMethods,
67}
68
69unsafe impl Traverse for PySequence<'_> {
70    fn traverse(&self, tracer_fn: &mut TraverseFn) {
71        self.obj.traverse(tracer_fn)
72    }
73}
74
75impl<'a> PySequence<'a> {
76    #[inline]
77    pub fn with_methods(obj: &'a PyObject, methods: &'static PySequenceMethods) -> Self {
78        Self { obj, methods }
79    }
80
81    pub fn try_protocol(obj: &'a PyObject, vm: &VirtualMachine) -> PyResult<Self> {
82        let seq = obj.to_sequence();
83        if seq.check() {
84            Ok(seq)
85        } else {
86            Err(vm.new_type_error(format!("'{}' is not a sequence", obj.class())))
87        }
88    }
89}
90
91impl PySequence<'_> {
92    pub fn check(&self) -> bool {
93        self.methods.item.load().is_some()
94    }
95
96    pub fn find_methods(obj: &PyObject) -> Option<PointerSlot<PySequenceMethods>> {
97        let cls = obj.class();
98        cls.mro_find_map(|x| x.slots.as_sequence.load())
99    }
100
101    pub fn length_opt(self, vm: &VirtualMachine) -> Option<PyResult<usize>> {
102        self.methods.length.load().map(|f| f(self, vm))
103    }
104
105    pub fn length(self, vm: &VirtualMachine) -> PyResult<usize> {
106        self.length_opt(vm).ok_or_else(|| {
107            vm.new_type_error(format!(
108                "'{}' is not a sequence or has no len()",
109                self.obj.class()
110            ))
111        })?
112    }
113
114    pub fn concat(self, other: &PyObject, vm: &VirtualMachine) -> PyResult {
115        if let Some(f) = self.methods.concat.load() {
116            return f(self, other, vm);
117        }
118
119        // if both arguments apear to be sequences, try fallback to __add__
120        if self.check() && other.to_sequence().check() {
121            let ret = vm.binary_op1(self.obj, other, PyNumberBinaryOp::Add)?;
122            if let PyArithmeticValue::Implemented(ret) = PyArithmeticValue::from_object(vm, ret) {
123                return Ok(ret);
124            }
125        }
126
127        Err(vm.new_type_error(format!(
128            "'{}' object can't be concatenated",
129            self.obj.class()
130        )))
131    }
132
133    pub fn repeat(self, n: isize, vm: &VirtualMachine) -> PyResult {
134        if let Some(f) = self.methods.repeat.load() {
135            return f(self, n, vm);
136        }
137
138        // fallback to __mul__
139        if self.check() {
140            let ret = vm.binary_op1(self.obj, &n.to_pyobject(vm), PyNumberBinaryOp::Multiply)?;
141            if let PyArithmeticValue::Implemented(ret) = PyArithmeticValue::from_object(vm, ret) {
142                return Ok(ret);
143            }
144        }
145
146        Err(vm.new_type_error(format!("'{}' object can't be repeated", self.obj.class())))
147    }
148
149    pub fn inplace_concat(self, other: &PyObject, vm: &VirtualMachine) -> PyResult {
150        if let Some(f) = self.methods.inplace_concat.load() {
151            return f(self, other, vm);
152        }
153        if let Some(f) = self.methods.concat.load() {
154            return f(self, other, vm);
155        }
156
157        // if both arguments apear to be sequences, try fallback to __iadd__
158        if self.check() && other.to_sequence().check() {
159            let ret = vm._iadd(self.obj, other)?;
160            if let PyArithmeticValue::Implemented(ret) = PyArithmeticValue::from_object(vm, ret) {
161                return Ok(ret);
162            }
163        }
164
165        Err(vm.new_type_error(format!(
166            "'{}' object can't be concatenated",
167            self.obj.class()
168        )))
169    }
170
171    pub fn inplace_repeat(self, n: isize, vm: &VirtualMachine) -> PyResult {
172        if let Some(f) = self.methods.inplace_repeat.load() {
173            return f(self, n, vm);
174        }
175        if let Some(f) = self.methods.repeat.load() {
176            return f(self, n, vm);
177        }
178
179        if self.check() {
180            let ret = vm._imul(self.obj, &n.to_pyobject(vm))?;
181            if let PyArithmeticValue::Implemented(ret) = PyArithmeticValue::from_object(vm, ret) {
182                return Ok(ret);
183            }
184        }
185
186        Err(vm.new_type_error(format!("'{}' object can't be repeated", self.obj.class())))
187    }
188
189    pub fn get_item(self, i: isize, vm: &VirtualMachine) -> PyResult {
190        if let Some(f) = self.methods.item.load() {
191            return f(self, i, vm);
192        }
193        Err(vm.new_type_error(format!(
194            "'{}' is not a sequence or does not support indexing",
195            self.obj.class()
196        )))
197    }
198
199    fn _ass_item(self, i: isize, value: Option<PyObjectRef>, vm: &VirtualMachine) -> PyResult<()> {
200        if let Some(f) = self.methods.ass_item.load() {
201            return f(self, i, value, vm);
202        }
203        Err(vm.new_type_error(format!(
204            "'{}' is not a sequence or doesn't support item {}",
205            self.obj.class(),
206            if value.is_some() {
207                "assignment"
208            } else {
209                "deletion"
210            }
211        )))
212    }
213
214    pub fn set_item(self, i: isize, value: PyObjectRef, vm: &VirtualMachine) -> PyResult<()> {
215        self._ass_item(i, Some(value), vm)
216    }
217
218    pub fn del_item(self, i: isize, vm: &VirtualMachine) -> PyResult<()> {
219        self._ass_item(i, None, vm)
220    }
221
222    pub fn get_slice(&self, start: isize, stop: isize, vm: &VirtualMachine) -> PyResult {
223        if let Ok(mapping) = PyMapping::try_protocol(self.obj, vm) {
224            let slice = PySlice {
225                start: Some(start.to_pyobject(vm)),
226                stop: stop.to_pyobject(vm),
227                step: None,
228            };
229            mapping.subscript(&slice.into_pyobject(vm), vm)
230        } else {
231            Err(vm.new_type_error(format!("'{}' object is unsliceable", self.obj.class())))
232        }
233    }
234
235    fn _ass_slice(
236        &self,
237        start: isize,
238        stop: isize,
239        value: Option<PyObjectRef>,
240        vm: &VirtualMachine,
241    ) -> PyResult<()> {
242        let mapping = self.obj.to_mapping();
243        if let Some(f) = mapping.methods.ass_subscript.load() {
244            let slice = PySlice {
245                start: Some(start.to_pyobject(vm)),
246                stop: stop.to_pyobject(vm),
247                step: None,
248            };
249            f(mapping, &slice.into_pyobject(vm), value, vm)
250        } else {
251            Err(vm.new_type_error(format!(
252                "'{}' object doesn't support slice {}",
253                self.obj.class(),
254                if value.is_some() {
255                    "assignment"
256                } else {
257                    "deletion"
258                }
259            )))
260        }
261    }
262
263    pub fn set_slice(
264        &self,
265        start: isize,
266        stop: isize,
267        value: PyObjectRef,
268        vm: &VirtualMachine,
269    ) -> PyResult<()> {
270        self._ass_slice(start, stop, Some(value), vm)
271    }
272
273    pub fn del_slice(&self, start: isize, stop: isize, vm: &VirtualMachine) -> PyResult<()> {
274        self._ass_slice(start, stop, None, vm)
275    }
276
277    pub fn tuple(&self, vm: &VirtualMachine) -> PyResult<PyTupleRef> {
278        if let Some(tuple) = self.obj.downcast_ref_if_exact::<PyTuple>(vm) {
279            Ok(tuple.to_owned())
280        } else if let Some(list) = self.obj.downcast_ref_if_exact::<PyList>(vm) {
281            Ok(vm.ctx.new_tuple(list.borrow_vec().to_vec()))
282        } else {
283            let iter = self.obj.to_owned().get_iter(vm)?;
284            let iter = iter.iter(vm)?;
285            Ok(vm.ctx.new_tuple(iter.try_collect()?))
286        }
287    }
288
289    pub fn list(&self, vm: &VirtualMachine) -> PyResult<PyListRef> {
290        let list = vm.ctx.new_list(self.obj.try_to_value(vm)?);
291        Ok(list)
292    }
293
294    pub fn count(&self, target: &PyObject, vm: &VirtualMachine) -> PyResult<usize> {
295        let mut n = 0;
296
297        let iter = self.obj.to_owned().get_iter(vm)?;
298        let iter = iter.iter::<PyObjectRef>(vm)?;
299
300        for elem in iter {
301            let elem = elem?;
302            if vm.bool_eq(&elem, target)? {
303                if n == isize::MAX as usize {
304                    return Err(vm.new_overflow_error("index exceeds C integer size".to_string()));
305                }
306                n += 1;
307            }
308        }
309
310        Ok(n)
311    }
312
313    pub fn index(&self, target: &PyObject, vm: &VirtualMachine) -> PyResult<usize> {
314        let mut index: isize = -1;
315
316        let iter = self.obj.to_owned().get_iter(vm)?;
317        let iter = iter.iter::<PyObjectRef>(vm)?;
318
319        for elem in iter {
320            if index == isize::MAX {
321                return Err(vm.new_overflow_error("index exceeds C integer size".to_string()));
322            }
323            index += 1;
324
325            let elem = elem?;
326            if vm.bool_eq(&elem, target)? {
327                return Ok(index as usize);
328            }
329        }
330
331        Err(vm.new_value_error("sequence.index(x): x not in sequence".to_string()))
332    }
333
334    pub fn extract<F, R>(&self, mut f: F, vm: &VirtualMachine) -> PyResult<Vec<R>>
335    where
336        F: FnMut(&PyObject) -> PyResult<R>,
337    {
338        if let Some(tuple) = self.obj.payload_if_exact::<PyTuple>(vm) {
339            tuple.iter().map(|x| f(x.as_ref())).collect()
340        } else if let Some(list) = self.obj.payload_if_exact::<PyList>(vm) {
341            list.borrow_vec().iter().map(|x| f(x.as_ref())).collect()
342        } else {
343            let iter = self.obj.to_owned().get_iter(vm)?;
344            let iter = iter.iter::<PyObjectRef>(vm)?;
345            let len = self.length(vm).unwrap_or(0);
346            let mut v = Vec::with_capacity(len);
347            for x in iter {
348                v.push(f(x?.as_ref())?);
349            }
350            v.shrink_to_fit();
351            Ok(v)
352        }
353    }
354
355    pub fn contains(self, target: &PyObject, vm: &VirtualMachine) -> PyResult<bool> {
356        if let Some(f) = self.methods.contains.load() {
357            return f(self, target, vm);
358        }
359
360        let iter = self.obj.to_owned().get_iter(vm)?;
361        let iter = iter.iter::<PyObjectRef>(vm)?;
362
363        for elem in iter {
364            let elem = elem?;
365            if vm.bool_eq(&elem, target)? {
366                return Ok(true);
367            }
368        }
369        Ok(false)
370    }
371}
Morty Proxy This is a proxified and sanitized view of the page, visit original site.