rustpython_vm/builtins/
enumerate.rs

1use super::{
2    IterStatus, PositionIterInternal, PyGenericAlias, PyIntRef, PyTupleRef, PyType, PyTypeRef,
3};
4use crate::common::lock::{PyMutex, PyRwLock};
5use crate::{
6    class::PyClassImpl,
7    convert::ToPyObject,
8    function::OptionalArg,
9    protocol::{PyIter, PyIterReturn},
10    types::{Constructor, IterNext, Iterable, SelfIter},
11    AsObject, Context, Py, PyObjectRef, PyPayload, PyResult, VirtualMachine,
12};
13use malachite_bigint::BigInt;
14use num_traits::Zero;
15
16#[pyclass(module = false, name = "enumerate", traverse)]
17#[derive(Debug)]
18pub struct PyEnumerate {
19    #[pytraverse(skip)]
20    counter: PyRwLock<BigInt>,
21    iterator: PyIter,
22}
23
24impl PyPayload for PyEnumerate {
25    fn class(ctx: &Context) -> &'static Py<PyType> {
26        ctx.types.enumerate_type
27    }
28}
29
30#[derive(FromArgs)]
31pub struct EnumerateArgs {
32    iterator: PyIter,
33    #[pyarg(any, optional)]
34    start: OptionalArg<PyIntRef>,
35}
36
37impl Constructor for PyEnumerate {
38    type Args = EnumerateArgs;
39
40    fn py_new(
41        cls: PyTypeRef,
42        Self::Args { iterator, start }: Self::Args,
43        vm: &VirtualMachine,
44    ) -> PyResult {
45        let counter = start.map_or_else(BigInt::zero, |start| start.as_bigint().clone());
46        PyEnumerate {
47            counter: PyRwLock::new(counter),
48            iterator,
49        }
50        .into_ref_with_type(vm, cls)
51        .map(Into::into)
52    }
53}
54
55#[pyclass(with(Py, IterNext, Iterable, Constructor), flags(BASETYPE))]
56impl PyEnumerate {
57    #[pyclassmethod(magic)]
58    fn class_getitem(cls: PyTypeRef, args: PyObjectRef, vm: &VirtualMachine) -> PyGenericAlias {
59        PyGenericAlias::new(cls, args, vm)
60    }
61}
62
63#[pyclass]
64impl Py<PyEnumerate> {
65    #[pymethod(magic)]
66    fn reduce(&self) -> (PyTypeRef, (PyIter, BigInt)) {
67        (
68            self.class().to_owned(),
69            (self.iterator.clone(), self.counter.read().clone()),
70        )
71    }
72}
73
74impl SelfIter for PyEnumerate {}
75impl IterNext for PyEnumerate {
76    fn next(zelf: &Py<Self>, vm: &VirtualMachine) -> PyResult<PyIterReturn> {
77        let next_obj = match zelf.iterator.next(vm)? {
78            PyIterReturn::StopIteration(v) => return Ok(PyIterReturn::StopIteration(v)),
79            PyIterReturn::Return(obj) => obj,
80        };
81        let mut counter = zelf.counter.write();
82        let position = counter.clone();
83        *counter += 1;
84        Ok(PyIterReturn::Return((position, next_obj).to_pyobject(vm)))
85    }
86}
87
88#[pyclass(module = false, name = "reversed", traverse)]
89#[derive(Debug)]
90pub struct PyReverseSequenceIterator {
91    internal: PyMutex<PositionIterInternal<PyObjectRef>>,
92}
93
94impl PyPayload for PyReverseSequenceIterator {
95    fn class(ctx: &Context) -> &'static Py<PyType> {
96        ctx.types.reverse_iter_type
97    }
98}
99
100#[pyclass(with(IterNext, Iterable))]
101impl PyReverseSequenceIterator {
102    pub fn new(obj: PyObjectRef, len: usize) -> Self {
103        let position = len.saturating_sub(1);
104        Self {
105            internal: PyMutex::new(PositionIterInternal::new(obj, position)),
106        }
107    }
108
109    #[pymethod(magic)]
110    fn length_hint(&self, vm: &VirtualMachine) -> PyResult<usize> {
111        let internal = self.internal.lock();
112        if let IterStatus::Active(obj) = &internal.status {
113            if internal.position <= obj.length(vm)? {
114                return Ok(internal.position + 1);
115            }
116        }
117        Ok(0)
118    }
119
120    #[pymethod(magic)]
121    fn setstate(&self, state: PyObjectRef, vm: &VirtualMachine) -> PyResult<()> {
122        self.internal.lock().set_state(state, |_, pos| pos, vm)
123    }
124
125    #[pymethod(magic)]
126    fn reduce(&self, vm: &VirtualMachine) -> PyTupleRef {
127        self.internal
128            .lock()
129            .builtins_reversed_reduce(|x| x.clone(), vm)
130    }
131}
132
133impl SelfIter for PyReverseSequenceIterator {}
134impl IterNext for PyReverseSequenceIterator {
135    fn next(zelf: &Py<Self>, vm: &VirtualMachine) -> PyResult<PyIterReturn> {
136        zelf.internal
137            .lock()
138            .rev_next(|obj, pos| PyIterReturn::from_getitem_result(obj.get_item(&pos, vm), vm))
139    }
140}
141
142pub fn init(context: &Context) {
143    PyEnumerate::extend_class(context, context.types.enumerate_type);
144    PyReverseSequenceIterator::extend_class(context, context.types.reverse_iter_type);
145}
Morty Proxy This is a proxified and sanitized view of the page, visit original site.