rustpython_vm/builtins/iter.rs

/*
 * Iterator types: shared position-based iterator state plus the built-in
 * `iterator` and `callable_iterator` classes.
 */
4
5use super::{PyInt, PyTupleRef, PyType};
6use crate::{
7    class::PyClassImpl,
8    function::ArgCallable,
9    object::{Traverse, TraverseFn},
10    protocol::{PyIterReturn, PySequence, PySequenceMethods},
11    types::{IterNext, Iterable, SelfIter},
12    Context, Py, PyObject, PyObjectRef, PyPayload, PyResult, VirtualMachine,
13};
14use rustpython_common::{
15    lock::{PyMutex, PyRwLock, PyRwLockUpgradableReadGuard},
16    static_cell,
17};
18
/// Whether an iterator can still produce items.
///
/// `T` is whatever the iterator needs while it is live (for example the object
/// being iterated, or the callable being invoked). Switching to `Exhausted`
/// drops that payload.
#[derive(Clone, Debug)]
pub enum IterStatus<T> {
    /// Still live: the iterator has not raised `StopIteration` yet.
    Active(T),
    /// Finished: the iterator has raised `StopIteration`.
    Exhausted,
}
27
28unsafe impl<T: Traverse> Traverse for IterStatus<T> {
29    fn traverse(&self, tracer_fn: &mut TraverseFn) {
30        match self {
31            IterStatus::Active(ref r) => r.traverse(tracer_fn),
32            IterStatus::Exhausted => (),
33        }
34    }
35}
36
37#[derive(Debug)]
38pub struct PositionIterInternal<T> {
39    pub status: IterStatus<T>,
40    pub position: usize,
41}
42
43unsafe impl<T: Traverse> Traverse for PositionIterInternal<T> {
44    fn traverse(&self, tracer_fn: &mut TraverseFn) {
45        self.status.traverse(tracer_fn)
46    }
47}
48
49impl<T> PositionIterInternal<T> {
50    pub fn new(obj: T, position: usize) -> Self {
51        Self {
52            status: IterStatus::Active(obj),
53            position,
54        }
55    }
56
57    pub fn set_state<F>(&mut self, state: PyObjectRef, f: F, vm: &VirtualMachine) -> PyResult<()>
58    where
59        F: FnOnce(&T, usize) -> usize,
60    {
61        if let IterStatus::Active(obj) = &self.status {
62            if let Some(i) = state.payload::<PyInt>() {
63                let i = i.try_to_primitive(vm).unwrap_or(0);
64                self.position = f(obj, i);
65                Ok(())
66            } else {
67                Err(vm.new_type_error("an integer is required.".to_owned()))
68            }
69        } else {
70            Ok(())
71        }
72    }
73
74    fn _reduce<F>(&self, func: PyObjectRef, f: F, vm: &VirtualMachine) -> PyTupleRef
75    where
76        F: FnOnce(&T) -> PyObjectRef,
77    {
78        if let IterStatus::Active(obj) = &self.status {
79            vm.new_tuple((func, (f(obj),), self.position))
80        } else {
81            vm.new_tuple((func, (vm.ctx.new_list(Vec::new()),)))
82        }
83    }
84
85    pub fn builtins_iter_reduce<F>(&self, f: F, vm: &VirtualMachine) -> PyTupleRef
86    where
87        F: FnOnce(&T) -> PyObjectRef,
88    {
89        let iter = builtins_iter(vm).to_owned();
90        self._reduce(iter, f, vm)
91    }
92
93    pub fn builtins_reversed_reduce<F>(&self, f: F, vm: &VirtualMachine) -> PyTupleRef
94    where
95        F: FnOnce(&T) -> PyObjectRef,
96    {
97        let reversed = builtins_reversed(vm).to_owned();
98        self._reduce(reversed, f, vm)
99    }
100
101    fn _next<F, OP>(&mut self, f: F, op: OP) -> PyResult<PyIterReturn>
102    where
103        F: FnOnce(&T, usize) -> PyResult<PyIterReturn>,
104        OP: FnOnce(&mut Self),
105    {
106        if let IterStatus::Active(obj) = &self.status {
107            let ret = f(obj, self.position);
108            if let Ok(PyIterReturn::Return(_)) = ret {
109                op(self);
110            } else {
111                self.status = IterStatus::Exhausted;
112            }
113            ret
114        } else {
115            Ok(PyIterReturn::StopIteration(None))
116        }
117    }
118
119    pub fn next<F>(&mut self, f: F) -> PyResult<PyIterReturn>
120    where
121        F: FnOnce(&T, usize) -> PyResult<PyIterReturn>,
122    {
123        self._next(f, |zelf| zelf.position += 1)
124    }
125
126    pub fn rev_next<F>(&mut self, f: F) -> PyResult<PyIterReturn>
127    where
128        F: FnOnce(&T, usize) -> PyResult<PyIterReturn>,
129    {
130        self._next(f, |zelf| {
131            if zelf.position == 0 {
132                zelf.status = IterStatus::Exhausted;
133            } else {
134                zelf.position -= 1;
135            }
136        })
137    }
138
139    pub fn length_hint<F>(&self, f: F) -> usize
140    where
141        F: FnOnce(&T) -> usize,
142    {
143        if let IterStatus::Active(obj) = &self.status {
144            f(obj).saturating_sub(self.position)
145        } else {
146            0
147        }
148    }
149
150    pub fn rev_length_hint<F>(&self, f: F) -> usize
151    where
152        F: FnOnce(&T) -> usize,
153    {
154        if let IterStatus::Active(obj) = &self.status {
155            if self.position <= f(obj) {
156                return self.position + 1;
157            }
158        }
159        0
160    }
161}
162
163pub fn builtins_iter(vm: &VirtualMachine) -> &PyObject {
164    static_cell! {
165        static INSTANCE: PyObjectRef;
166    }
167    INSTANCE.get_or_init(|| vm.builtins.get_attr("iter", vm).unwrap())
168}
169
170pub fn builtins_reversed(vm: &VirtualMachine) -> &PyObject {
171    static_cell! {
172        static INSTANCE: PyObjectRef;
173    }
174    INSTANCE.get_or_init(|| vm.builtins.get_attr("reversed", vm).unwrap())
175}
176
177#[pyclass(module = false, name = "iterator", traverse)]
178#[derive(Debug)]
179pub struct PySequenceIterator {
180    // cached sequence methods
181    #[pytraverse(skip)]
182    seq_methods: &'static PySequenceMethods,
183    internal: PyMutex<PositionIterInternal<PyObjectRef>>,
184}
185
186impl PyPayload for PySequenceIterator {
187    fn class(ctx: &Context) -> &'static Py<PyType> {
188        ctx.types.iter_type
189    }
190}
191
192#[pyclass(with(IterNext, Iterable))]
193impl PySequenceIterator {
194    pub fn new(obj: PyObjectRef, vm: &VirtualMachine) -> PyResult<Self> {
195        let seq = PySequence::try_protocol(&obj, vm)?;
196        Ok(Self {
197            seq_methods: seq.methods,
198            internal: PyMutex::new(PositionIterInternal::new(obj, 0)),
199        })
200    }
201
202    #[pymethod(magic)]
203    fn length_hint(&self, vm: &VirtualMachine) -> PyObjectRef {
204        let internal = self.internal.lock();
205        if let IterStatus::Active(obj) = &internal.status {
206            let seq = PySequence {
207                obj,
208                methods: self.seq_methods,
209            };
210            seq.length(vm)
211                .map(|x| PyInt::from(x).into_pyobject(vm))
212                .unwrap_or_else(|_| vm.ctx.not_implemented())
213        } else {
214            PyInt::from(0).into_pyobject(vm)
215        }
216    }
217
218    #[pymethod(magic)]
219    fn reduce(&self, vm: &VirtualMachine) -> PyTupleRef {
220        self.internal.lock().builtins_iter_reduce(|x| x.clone(), vm)
221    }
222
223    #[pymethod(magic)]
224    fn setstate(&self, state: PyObjectRef, vm: &VirtualMachine) -> PyResult<()> {
225        self.internal.lock().set_state(state, |_, pos| pos, vm)
226    }
227}
228
229impl SelfIter for PySequenceIterator {}
230impl IterNext for PySequenceIterator {
231    fn next(zelf: &Py<Self>, vm: &VirtualMachine) -> PyResult<PyIterReturn> {
232        zelf.internal.lock().next(|obj, pos| {
233            let seq = PySequence {
234                obj,
235                methods: zelf.seq_methods,
236            };
237            PyIterReturn::from_getitem_result(seq.get_item(pos as isize, vm), vm)
238        })
239    }
240}
241
242#[pyclass(module = false, name = "callable_iterator", traverse)]
243#[derive(Debug)]
244pub struct PyCallableIterator {
245    sentinel: PyObjectRef,
246    status: PyRwLock<IterStatus<ArgCallable>>,
247}
248
249impl PyPayload for PyCallableIterator {
250    fn class(ctx: &Context) -> &'static Py<PyType> {
251        ctx.types.callable_iterator
252    }
253}
254
255#[pyclass(with(IterNext, Iterable))]
256impl PyCallableIterator {
257    pub fn new(callable: ArgCallable, sentinel: PyObjectRef) -> Self {
258        Self {
259            sentinel,
260            status: PyRwLock::new(IterStatus::Active(callable)),
261        }
262    }
263}
264
265impl SelfIter for PyCallableIterator {}
266impl IterNext for PyCallableIterator {
267    fn next(zelf: &Py<Self>, vm: &VirtualMachine) -> PyResult<PyIterReturn> {
268        let status = zelf.status.upgradable_read();
269        let next = if let IterStatus::Active(callable) = &*status {
270            let ret = callable.invoke((), vm)?;
271            if vm.bool_eq(&ret, &zelf.sentinel)? {
272                *PyRwLockUpgradableReadGuard::upgrade(status) = IterStatus::Exhausted;
273                PyIterReturn::StopIteration(None)
274            } else {
275                PyIterReturn::Return(ret)
276            }
277        } else {
278            PyIterReturn::StopIteration(None)
279        };
280        Ok(next)
281    }
282}
283
284pub fn init(context: &Context) {
285    PySequenceIterator::extend_class(context, context.types.iter_type);
286    PyCallableIterator::extend_class(context, context.types.callable_iterator);
287}
Morty Proxy This is a proxified and sanitized view of the page, visit original site.