1use super::{PyInt, PyTupleRef, PyType};
6use crate::{
7 class::PyClassImpl,
8 function::ArgCallable,
9 object::{Traverse, TraverseFn},
10 protocol::{PyIterReturn, PySequence, PySequenceMethods},
11 types::{IterNext, Iterable, SelfIter},
12 Context, Py, PyObject, PyObjectRef, PyPayload, PyResult, VirtualMachine,
13};
14use rustpython_common::{
15 lock::{PyMutex, PyRwLock, PyRwLockUpgradableReadGuard},
16 static_cell,
17};
18
/// Lifecycle state of an iterator: either still holding the object it iterates
/// over, or finished.
#[derive(Debug, Clone)]
pub enum IterStatus<T> {
    /// The iterator may still produce items; holds the object being iterated.
    Active(T),
    /// The iterator has finished; the object is no longer held here.
    Exhausted,
}
27
28unsafe impl<T: Traverse> Traverse for IterStatus<T> {
29 fn traverse(&self, tracer_fn: &mut TraverseFn) {
30 match self {
31 IterStatus::Active(ref r) => r.traverse(tracer_fn),
32 IterStatus::Exhausted => (),
33 }
34 }
35}
36
/// Shared cursor state for index-based iterators: the object being iterated
/// plus the index of the next item to visit.
#[derive(Debug)]
pub struct PositionIterInternal<T> {
    /// Whether the iterator is still active (and, if so, the object it wraps).
    pub status: IterStatus<T>,
    /// Index of the next item; incremented by `next`, decremented by `rev_next`.
    pub position: usize,
}
42
43unsafe impl<T: Traverse> Traverse for PositionIterInternal<T> {
44 fn traverse(&self, tracer_fn: &mut TraverseFn) {
45 self.status.traverse(tracer_fn)
46 }
47}
48
49impl<T> PositionIterInternal<T> {
50 pub fn new(obj: T, position: usize) -> Self {
51 Self {
52 status: IterStatus::Active(obj),
53 position,
54 }
55 }
56
57 pub fn set_state<F>(&mut self, state: PyObjectRef, f: F, vm: &VirtualMachine) -> PyResult<()>
58 where
59 F: FnOnce(&T, usize) -> usize,
60 {
61 if let IterStatus::Active(obj) = &self.status {
62 if let Some(i) = state.payload::<PyInt>() {
63 let i = i.try_to_primitive(vm).unwrap_or(0);
64 self.position = f(obj, i);
65 Ok(())
66 } else {
67 Err(vm.new_type_error("an integer is required.".to_owned()))
68 }
69 } else {
70 Ok(())
71 }
72 }
73
74 fn _reduce<F>(&self, func: PyObjectRef, f: F, vm: &VirtualMachine) -> PyTupleRef
75 where
76 F: FnOnce(&T) -> PyObjectRef,
77 {
78 if let IterStatus::Active(obj) = &self.status {
79 vm.new_tuple((func, (f(obj),), self.position))
80 } else {
81 vm.new_tuple((func, (vm.ctx.new_list(Vec::new()),)))
82 }
83 }
84
85 pub fn builtins_iter_reduce<F>(&self, f: F, vm: &VirtualMachine) -> PyTupleRef
86 where
87 F: FnOnce(&T) -> PyObjectRef,
88 {
89 let iter = builtins_iter(vm).to_owned();
90 self._reduce(iter, f, vm)
91 }
92
93 pub fn builtins_reversed_reduce<F>(&self, f: F, vm: &VirtualMachine) -> PyTupleRef
94 where
95 F: FnOnce(&T) -> PyObjectRef,
96 {
97 let reversed = builtins_reversed(vm).to_owned();
98 self._reduce(reversed, f, vm)
99 }
100
101 fn _next<F, OP>(&mut self, f: F, op: OP) -> PyResult<PyIterReturn>
102 where
103 F: FnOnce(&T, usize) -> PyResult<PyIterReturn>,
104 OP: FnOnce(&mut Self),
105 {
106 if let IterStatus::Active(obj) = &self.status {
107 let ret = f(obj, self.position);
108 if let Ok(PyIterReturn::Return(_)) = ret {
109 op(self);
110 } else {
111 self.status = IterStatus::Exhausted;
112 }
113 ret
114 } else {
115 Ok(PyIterReturn::StopIteration(None))
116 }
117 }
118
119 pub fn next<F>(&mut self, f: F) -> PyResult<PyIterReturn>
120 where
121 F: FnOnce(&T, usize) -> PyResult<PyIterReturn>,
122 {
123 self._next(f, |zelf| zelf.position += 1)
124 }
125
126 pub fn rev_next<F>(&mut self, f: F) -> PyResult<PyIterReturn>
127 where
128 F: FnOnce(&T, usize) -> PyResult<PyIterReturn>,
129 {
130 self._next(f, |zelf| {
131 if zelf.position == 0 {
132 zelf.status = IterStatus::Exhausted;
133 } else {
134 zelf.position -= 1;
135 }
136 })
137 }
138
139 pub fn length_hint<F>(&self, f: F) -> usize
140 where
141 F: FnOnce(&T) -> usize,
142 {
143 if let IterStatus::Active(obj) = &self.status {
144 f(obj).saturating_sub(self.position)
145 } else {
146 0
147 }
148 }
149
150 pub fn rev_length_hint<F>(&self, f: F) -> usize
151 where
152 F: FnOnce(&T) -> usize,
153 {
154 if let IterStatus::Active(obj) = &self.status {
155 if self.position <= f(obj) {
156 return self.position + 1;
157 }
158 }
159 0
160 }
161}
162
163pub fn builtins_iter(vm: &VirtualMachine) -> &PyObject {
164 static_cell! {
165 static INSTANCE: PyObjectRef;
166 }
167 INSTANCE.get_or_init(|| vm.builtins.get_attr("iter", vm).unwrap())
168}
169
170pub fn builtins_reversed(vm: &VirtualMachine) -> &PyObject {
171 static_cell! {
172 static INSTANCE: PyObjectRef;
173 }
174 INSTANCE.get_or_init(|| vm.builtins.get_attr("reversed", vm).unwrap())
175}
176
// Python `iterator` type: walks any object that supports the sequence protocol
// by fetching items at increasing indices until the lookup signals the end.
// (Plain `//` comments on purpose: `///` on a `#[pyclass]` item would
// presumably become the Python-visible `__doc__`.)
#[pyclass(module = false, name = "iterator", traverse)]
#[derive(Debug)]
pub struct PySequenceIterator {
    // Sequence method table captured at construction time; a `&'static` table
    // holds no Python references, so GC traversal skips it.
    #[pytraverse(skip)]
    seq_methods: &'static PySequenceMethods,
    // The iterated object plus the index of the next item to fetch.
    internal: PyMutex<PositionIterInternal<PyObjectRef>>,
}
185
186impl PyPayload for PySequenceIterator {
187 fn class(ctx: &Context) -> &'static Py<PyType> {
188 ctx.types.iter_type
189 }
190}
191
192#[pyclass(with(IterNext, Iterable))]
193impl PySequenceIterator {
194 pub fn new(obj: PyObjectRef, vm: &VirtualMachine) -> PyResult<Self> {
195 let seq = PySequence::try_protocol(&obj, vm)?;
196 Ok(Self {
197 seq_methods: seq.methods,
198 internal: PyMutex::new(PositionIterInternal::new(obj, 0)),
199 })
200 }
201
202 #[pymethod(magic)]
203 fn length_hint(&self, vm: &VirtualMachine) -> PyObjectRef {
204 let internal = self.internal.lock();
205 if let IterStatus::Active(obj) = &internal.status {
206 let seq = PySequence {
207 obj,
208 methods: self.seq_methods,
209 };
210 seq.length(vm)
211 .map(|x| PyInt::from(x).into_pyobject(vm))
212 .unwrap_or_else(|_| vm.ctx.not_implemented())
213 } else {
214 PyInt::from(0).into_pyobject(vm)
215 }
216 }
217
218 #[pymethod(magic)]
219 fn reduce(&self, vm: &VirtualMachine) -> PyTupleRef {
220 self.internal.lock().builtins_iter_reduce(|x| x.clone(), vm)
221 }
222
223 #[pymethod(magic)]
224 fn setstate(&self, state: PyObjectRef, vm: &VirtualMachine) -> PyResult<()> {
225 self.internal.lock().set_state(state, |_, pos| pos, vm)
226 }
227}
228
// Marker impl: going by its name, `SelfIter` presumably makes `__iter__` return
// the iterator itself (see `types::SelfIter`). Confirm there.
impl SelfIter for PySequenceIterator {}
230impl IterNext for PySequenceIterator {
231 fn next(zelf: &Py<Self>, vm: &VirtualMachine) -> PyResult<PyIterReturn> {
232 zelf.internal.lock().next(|obj, pos| {
233 let seq = PySequence {
234 obj,
235 methods: zelf.seq_methods,
236 };
237 PyIterReturn::from_getitem_result(seq.get_item(pos as isize, vm), vm)
238 })
239 }
240}
241
// Python `callable_iterator` type: on each `__next__` it calls `callable` with
// no arguments and yields the result, until that result equals `sentinel`.
// (Plain `//` comments on purpose: `///` on a `#[pyclass]` item would
// presumably become the Python-visible `__doc__`.)
#[pyclass(module = false, name = "callable_iterator", traverse)]
#[derive(Debug)]
pub struct PyCallableIterator {
    // Value whose appearance as a call result ends the iteration.
    sentinel: PyObjectRef,
    // Holds the callable while active; replaced by `Exhausted` when iteration ends.
    status: PyRwLock<IterStatus<ArgCallable>>,
}
248
249impl PyPayload for PyCallableIterator {
250 fn class(ctx: &Context) -> &'static Py<PyType> {
251 ctx.types.callable_iterator
252 }
253}
254
255#[pyclass(with(IterNext, Iterable))]
256impl PyCallableIterator {
257 pub fn new(callable: ArgCallable, sentinel: PyObjectRef) -> Self {
258 Self {
259 sentinel,
260 status: PyRwLock::new(IterStatus::Active(callable)),
261 }
262 }
263}
264
// Marker impl: going by its name, `SelfIter` presumably makes `__iter__` return
// the iterator itself (see `types::SelfIter`). Confirm there.
impl SelfIter for PyCallableIterator {}
266impl IterNext for PyCallableIterator {
267 fn next(zelf: &Py<Self>, vm: &VirtualMachine) -> PyResult<PyIterReturn> {
268 let status = zelf.status.upgradable_read();
269 let next = if let IterStatus::Active(callable) = &*status {
270 let ret = callable.invoke((), vm)?;
271 if vm.bool_eq(&ret, &zelf.sentinel)? {
272 *PyRwLockUpgradableReadGuard::upgrade(status) = IterStatus::Exhausted;
273 PyIterReturn::StopIteration(None)
274 } else {
275 PyIterReturn::Return(ret)
276 }
277 } else {
278 PyIterReturn::StopIteration(None)
279 };
280 Ok(next)
281 }
282}
283
284pub fn init(context: &Context) {
285 PySequenceIterator::extend_class(context, context.types.iter_type);
286 PyCallableIterator::extend_class(context, context.types.callable_iterator);
287}