rustpython_vm/builtins/
enumerate.rs1use super::{
2 IterStatus, PositionIterInternal, PyGenericAlias, PyIntRef, PyTupleRef, PyType, PyTypeRef,
3};
4use crate::common::lock::{PyMutex, PyRwLock};
5use crate::{
6 class::PyClassImpl,
7 convert::ToPyObject,
8 function::OptionalArg,
9 protocol::{PyIter, PyIterReturn},
10 types::{Constructor, IterNext, Iterable, SelfIter},
11 AsObject, Context, Py, PyObjectRef, PyPayload, PyResult, VirtualMachine,
12};
13use malachite_bigint::BigInt;
14use num_traits::Zero;
15
16#[pyclass(module = false, name = "enumerate", traverse)]
17#[derive(Debug)]
18pub struct PyEnumerate {
19 #[pytraverse(skip)]
20 counter: PyRwLock<BigInt>,
21 iterator: PyIter,
22}
23
24impl PyPayload for PyEnumerate {
25 fn class(ctx: &Context) -> &'static Py<PyType> {
26 ctx.types.enumerate_type
27 }
28}
29
30#[derive(FromArgs)]
31pub struct EnumerateArgs {
32 iterator: PyIter,
33 #[pyarg(any, optional)]
34 start: OptionalArg<PyIntRef>,
35}
36
37impl Constructor for PyEnumerate {
38 type Args = EnumerateArgs;
39
40 fn py_new(
41 cls: PyTypeRef,
42 Self::Args { iterator, start }: Self::Args,
43 vm: &VirtualMachine,
44 ) -> PyResult {
45 let counter = start.map_or_else(BigInt::zero, |start| start.as_bigint().clone());
46 PyEnumerate {
47 counter: PyRwLock::new(counter),
48 iterator,
49 }
50 .into_ref_with_type(vm, cls)
51 .map(Into::into)
52 }
53}
54
55#[pyclass(with(Py, IterNext, Iterable, Constructor), flags(BASETYPE))]
56impl PyEnumerate {
57 #[pyclassmethod(magic)]
58 fn class_getitem(cls: PyTypeRef, args: PyObjectRef, vm: &VirtualMachine) -> PyGenericAlias {
59 PyGenericAlias::new(cls, args, vm)
60 }
61}
62
63#[pyclass]
64impl Py<PyEnumerate> {
65 #[pymethod(magic)]
66 fn reduce(&self) -> (PyTypeRef, (PyIter, BigInt)) {
67 (
68 self.class().to_owned(),
69 (self.iterator.clone(), self.counter.read().clone()),
70 )
71 }
72}
73
74impl SelfIter for PyEnumerate {}
75impl IterNext for PyEnumerate {
76 fn next(zelf: &Py<Self>, vm: &VirtualMachine) -> PyResult<PyIterReturn> {
77 let next_obj = match zelf.iterator.next(vm)? {
78 PyIterReturn::StopIteration(v) => return Ok(PyIterReturn::StopIteration(v)),
79 PyIterReturn::Return(obj) => obj,
80 };
81 let mut counter = zelf.counter.write();
82 let position = counter.clone();
83 *counter += 1;
84 Ok(PyIterReturn::Return((position, next_obj).to_pyobject(vm)))
85 }
86}
87
88#[pyclass(module = false, name = "reversed", traverse)]
89#[derive(Debug)]
90pub struct PyReverseSequenceIterator {
91 internal: PyMutex<PositionIterInternal<PyObjectRef>>,
92}
93
94impl PyPayload for PyReverseSequenceIterator {
95 fn class(ctx: &Context) -> &'static Py<PyType> {
96 ctx.types.reverse_iter_type
97 }
98}
99
100#[pyclass(with(IterNext, Iterable))]
101impl PyReverseSequenceIterator {
102 pub fn new(obj: PyObjectRef, len: usize) -> Self {
103 let position = len.saturating_sub(1);
104 Self {
105 internal: PyMutex::new(PositionIterInternal::new(obj, position)),
106 }
107 }
108
109 #[pymethod(magic)]
110 fn length_hint(&self, vm: &VirtualMachine) -> PyResult<usize> {
111 let internal = self.internal.lock();
112 if let IterStatus::Active(obj) = &internal.status {
113 if internal.position <= obj.length(vm)? {
114 return Ok(internal.position + 1);
115 }
116 }
117 Ok(0)
118 }
119
120 #[pymethod(magic)]
121 fn setstate(&self, state: PyObjectRef, vm: &VirtualMachine) -> PyResult<()> {
122 self.internal.lock().set_state(state, |_, pos| pos, vm)
123 }
124
125 #[pymethod(magic)]
126 fn reduce(&self, vm: &VirtualMachine) -> PyTupleRef {
127 self.internal
128 .lock()
129 .builtins_reversed_reduce(|x| x.clone(), vm)
130 }
131}
132
133impl SelfIter for PyReverseSequenceIterator {}
134impl IterNext for PyReverseSequenceIterator {
135 fn next(zelf: &Py<Self>, vm: &VirtualMachine) -> PyResult<PyIterReturn> {
136 zelf.internal
137 .lock()
138 .rev_next(|obj, pos| PyIterReturn::from_getitem_result(obj.get_item(&pos, vm), vm))
139 }
140}
141
142pub fn init(context: &Context) {
143 PyEnumerate::extend_class(context, context.types.enumerate_type);
144 PyReverseSequenceIterator::extend_class(context, context.types.reverse_iter_type);
145}