rustpython_vm/builtins/asyncgenerator.rs

1use super::{PyCode, PyGenericAlias, PyStrRef, PyType, PyTypeRef};
2use crate::{
3    builtins::PyBaseExceptionRef,
4    class::PyClassImpl,
5    coroutine::Coro,
6    frame::FrameRef,
7    function::OptionalArg,
8    protocol::PyIterReturn,
9    types::{IterNext, Iterable, Representable, SelfIter, Unconstructible},
10    AsObject, Context, Py, PyObjectRef, PyPayload, PyRef, PyResult, VirtualMachine,
11};
12
13use crossbeam_utils::atomic::AtomicCell;
14
15#[pyclass(name = "async_generator", module = false)]
16#[derive(Debug)]
17pub struct PyAsyncGen {
18    inner: Coro,
19    running_async: AtomicCell<bool>,
20}
21type PyAsyncGenRef = PyRef<PyAsyncGen>;
22
23impl PyPayload for PyAsyncGen {
24    fn class(ctx: &Context) -> &'static Py<PyType> {
25        ctx.types.async_generator
26    }
27}
28
29#[pyclass(with(PyRef, Unconstructible, Representable))]
30impl PyAsyncGen {
31    pub fn as_coro(&self) -> &Coro {
32        &self.inner
33    }
34
35    pub fn new(frame: FrameRef, name: PyStrRef) -> Self {
36        PyAsyncGen {
37            inner: Coro::new(frame, name),
38            running_async: AtomicCell::new(false),
39        }
40    }
41
42    #[pygetset(magic)]
43    fn name(&self) -> PyStrRef {
44        self.inner.name()
45    }
46
47    #[pygetset(magic, setter)]
48    fn set_name(&self, name: PyStrRef) {
49        self.inner.set_name(name)
50    }
51
52    #[pygetset]
53    fn ag_await(&self, _vm: &VirtualMachine) -> Option<PyObjectRef> {
54        self.inner.frame().yield_from_target()
55    }
56    #[pygetset]
57    fn ag_frame(&self, _vm: &VirtualMachine) -> FrameRef {
58        self.inner.frame()
59    }
60    #[pygetset]
61    fn ag_running(&self, _vm: &VirtualMachine) -> bool {
62        self.inner.running()
63    }
64    #[pygetset]
65    fn ag_code(&self, _vm: &VirtualMachine) -> PyRef<PyCode> {
66        self.inner.frame().code.clone()
67    }
68
69    #[pyclassmethod(magic)]
70    fn class_getitem(cls: PyTypeRef, args: PyObjectRef, vm: &VirtualMachine) -> PyGenericAlias {
71        PyGenericAlias::new(cls, args, vm)
72    }
73}
74
75#[pyclass]
76impl PyRef<PyAsyncGen> {
77    #[pymethod(magic)]
78    fn aiter(self, _vm: &VirtualMachine) -> PyRef<PyAsyncGen> {
79        self
80    }
81
82    #[pymethod(magic)]
83    fn anext(self, vm: &VirtualMachine) -> PyAsyncGenASend {
84        Self::asend(self, vm.ctx.none(), vm)
85    }
86
87    #[pymethod]
88    fn asend(self, value: PyObjectRef, _vm: &VirtualMachine) -> PyAsyncGenASend {
89        PyAsyncGenASend {
90            ag: self,
91            state: AtomicCell::new(AwaitableState::Init),
92            value,
93        }
94    }
95
96    #[pymethod]
97    fn athrow(
98        self,
99        exc_type: PyObjectRef,
100        exc_val: OptionalArg,
101        exc_tb: OptionalArg,
102        vm: &VirtualMachine,
103    ) -> PyAsyncGenAThrow {
104        PyAsyncGenAThrow {
105            ag: self,
106            aclose: false,
107            state: AtomicCell::new(AwaitableState::Init),
108            value: (
109                exc_type,
110                exc_val.unwrap_or_none(vm),
111                exc_tb.unwrap_or_none(vm),
112            ),
113        }
114    }
115
116    #[pymethod]
117    fn aclose(self, vm: &VirtualMachine) -> PyAsyncGenAThrow {
118        PyAsyncGenAThrow {
119            ag: self,
120            aclose: true,
121            state: AtomicCell::new(AwaitableState::Init),
122            value: (
123                vm.ctx.exceptions.generator_exit.to_owned().into(),
124                vm.ctx.none(),
125                vm.ctx.none(),
126            ),
127        }
128    }
129}
130
131impl Representable for PyAsyncGen {
132    #[inline]
133    fn repr_str(zelf: &Py<Self>, vm: &VirtualMachine) -> PyResult<String> {
134        Ok(zelf.inner.repr(zelf.as_object(), zelf.get_id(), vm))
135    }
136}
137
138impl Unconstructible for PyAsyncGen {}
139
140#[pyclass(module = false, name = "async_generator_wrapped_value")]
141#[derive(Debug)]
142pub(crate) struct PyAsyncGenWrappedValue(pub PyObjectRef);
143impl PyPayload for PyAsyncGenWrappedValue {
144    fn class(ctx: &Context) -> &'static Py<PyType> {
145        ctx.types.async_generator_wrapped_value
146    }
147}
148
149#[pyclass]
150impl PyAsyncGenWrappedValue {}
151
152impl PyAsyncGenWrappedValue {
153    fn unbox(ag: &PyAsyncGen, val: PyResult<PyIterReturn>, vm: &VirtualMachine) -> PyResult {
154        let (closed, async_done) = match &val {
155            Ok(PyIterReturn::StopIteration(_)) => (true, true),
156            Err(e) if e.fast_isinstance(vm.ctx.exceptions.generator_exit) => (true, true),
157            Err(_) => (false, true),
158            _ => (false, false),
159        };
160        if closed {
161            ag.inner.closed.store(true);
162        }
163        if async_done {
164            ag.running_async.store(false);
165        }
166        let val = val?.into_async_pyresult(vm)?;
167        match_class!(match val {
168            val @ Self => {
169                ag.running_async.store(false);
170                Err(vm.new_stop_iteration(Some(val.0.clone())))
171            }
172            val => Ok(val),
173        })
174    }
175}
176
/// Lifecycle of an `asend()`/`athrow()`/`aclose()` awaitable.
#[derive(Clone, Copy, Debug)]
enum AwaitableState {
    /// Created, not yet stepped.
    Init,
    /// Started; subsequent steps are forwarded to the generator.
    Iter,
    /// Finished or explicitly closed; reusing it is an error.
    Closed,
}
183
184#[pyclass(module = false, name = "async_generator_asend")]
185#[derive(Debug)]
186pub(crate) struct PyAsyncGenASend {
187    ag: PyAsyncGenRef,
188    state: AtomicCell<AwaitableState>,
189    value: PyObjectRef,
190}
191
192impl PyPayload for PyAsyncGenASend {
193    fn class(ctx: &Context) -> &'static Py<PyType> {
194        ctx.types.async_generator_asend
195    }
196}
197
198#[pyclass(with(IterNext, Iterable))]
199impl PyAsyncGenASend {
200    #[pymethod(name = "__await__")]
201    fn r#await(zelf: PyRef<Self>, _vm: &VirtualMachine) -> PyRef<Self> {
202        zelf
203    }
204
205    #[pymethod]
206    fn send(&self, val: PyObjectRef, vm: &VirtualMachine) -> PyResult {
207        let val = match self.state.load() {
208            AwaitableState::Closed => {
209                return Err(vm.new_runtime_error(
210                    "cannot reuse already awaited __anext__()/asend()".to_owned(),
211                ))
212            }
213            AwaitableState::Iter => val, // already running, all good
214            AwaitableState::Init => {
215                if self.ag.running_async.load() {
216                    return Err(vm.new_runtime_error(
217                        "anext(): asynchronous generator is already running".to_owned(),
218                    ));
219                }
220                self.ag.running_async.store(true);
221                self.state.store(AwaitableState::Iter);
222                if vm.is_none(&val) {
223                    self.value.clone()
224                } else {
225                    val
226                }
227            }
228        };
229        let res = self.ag.inner.send(self.ag.as_object(), val, vm);
230        let res = PyAsyncGenWrappedValue::unbox(&self.ag, res, vm);
231        if res.is_err() {
232            self.close();
233        }
234        res
235    }
236
237    #[pymethod]
238    fn throw(
239        &self,
240        exc_type: PyObjectRef,
241        exc_val: OptionalArg,
242        exc_tb: OptionalArg,
243        vm: &VirtualMachine,
244    ) -> PyResult {
245        if let AwaitableState::Closed = self.state.load() {
246            return Err(
247                vm.new_runtime_error("cannot reuse already awaited __anext__()/asend()".to_owned())
248            );
249        }
250
251        let res = self.ag.inner.throw(
252            self.ag.as_object(),
253            exc_type,
254            exc_val.unwrap_or_none(vm),
255            exc_tb.unwrap_or_none(vm),
256            vm,
257        );
258        let res = PyAsyncGenWrappedValue::unbox(&self.ag, res, vm);
259        if res.is_err() {
260            self.close();
261        }
262        res
263    }
264
265    #[pymethod]
266    fn close(&self) {
267        self.state.store(AwaitableState::Closed);
268    }
269}
270
271impl SelfIter for PyAsyncGenASend {}
272impl IterNext for PyAsyncGenASend {
273    fn next(zelf: &Py<Self>, vm: &VirtualMachine) -> PyResult<PyIterReturn> {
274        PyIterReturn::from_pyresult(zelf.send(vm.ctx.none(), vm), vm)
275    }
276}
277
278#[pyclass(module = false, name = "async_generator_athrow")]
279#[derive(Debug)]
280pub(crate) struct PyAsyncGenAThrow {
281    ag: PyAsyncGenRef,
282    aclose: bool,
283    state: AtomicCell<AwaitableState>,
284    value: (PyObjectRef, PyObjectRef, PyObjectRef),
285}
286
287impl PyPayload for PyAsyncGenAThrow {
288    fn class(ctx: &Context) -> &'static Py<PyType> {
289        ctx.types.async_generator_athrow
290    }
291}
292
293#[pyclass(with(IterNext, Iterable))]
294impl PyAsyncGenAThrow {
295    #[pymethod(name = "__await__")]
296    fn r#await(zelf: PyRef<Self>, _vm: &VirtualMachine) -> PyRef<Self> {
297        zelf
298    }
299
300    #[pymethod]
301    fn send(&self, val: PyObjectRef, vm: &VirtualMachine) -> PyResult {
302        match self.state.load() {
303            AwaitableState::Closed => {
304                Err(vm
305                    .new_runtime_error("cannot reuse already awaited aclose()/athrow()".to_owned()))
306            }
307            AwaitableState::Init => {
308                if self.ag.running_async.load() {
309                    self.state.store(AwaitableState::Closed);
310                    let msg = if self.aclose {
311                        "aclose(): asynchronous generator is already running"
312                    } else {
313                        "athrow(): asynchronous generator is already running"
314                    };
315                    return Err(vm.new_runtime_error(msg.to_owned()));
316                }
317                if self.ag.inner.closed() {
318                    self.state.store(AwaitableState::Closed);
319                    return Err(vm.new_stop_iteration(None));
320                }
321                if !vm.is_none(&val) {
322                    return Err(vm.new_runtime_error(
323                        "can't send non-None value to a just-started async generator".to_owned(),
324                    ));
325                }
326                self.state.store(AwaitableState::Iter);
327                self.ag.running_async.store(true);
328
329                let (ty, val, tb) = self.value.clone();
330                let ret = self.ag.inner.throw(self.ag.as_object(), ty, val, tb, vm);
331                let ret = if self.aclose {
332                    if self.ignored_close(&ret) {
333                        Err(self.yield_close(vm))
334                    } else {
335                        ret.and_then(|o| o.into_async_pyresult(vm))
336                    }
337                } else {
338                    PyAsyncGenWrappedValue::unbox(&self.ag, ret, vm)
339                };
340                ret.map_err(|e| self.check_error(e, vm))
341            }
342            AwaitableState::Iter => {
343                let ret = self.ag.inner.send(self.ag.as_object(), val, vm);
344                if self.aclose {
345                    match ret {
346                        Ok(PyIterReturn::Return(v)) if v.payload_is::<PyAsyncGenWrappedValue>() => {
347                            Err(self.yield_close(vm))
348                        }
349                        other => other
350                            .and_then(|o| o.into_async_pyresult(vm))
351                            .map_err(|e| self.check_error(e, vm)),
352                    }
353                } else {
354                    PyAsyncGenWrappedValue::unbox(&self.ag, ret, vm)
355                }
356            }
357        }
358    }
359
360    #[pymethod]
361    fn throw(
362        &self,
363        exc_type: PyObjectRef,
364        exc_val: OptionalArg,
365        exc_tb: OptionalArg,
366        vm: &VirtualMachine,
367    ) -> PyResult {
368        let ret = self.ag.inner.throw(
369            self.ag.as_object(),
370            exc_type,
371            exc_val.unwrap_or_none(vm),
372            exc_tb.unwrap_or_none(vm),
373            vm,
374        );
375        let res = if self.aclose {
376            if self.ignored_close(&ret) {
377                Err(self.yield_close(vm))
378            } else {
379                ret.and_then(|o| o.into_async_pyresult(vm))
380            }
381        } else {
382            PyAsyncGenWrappedValue::unbox(&self.ag, ret, vm)
383        };
384        res.map_err(|e| self.check_error(e, vm))
385    }
386
387    #[pymethod]
388    fn close(&self) {
389        self.state.store(AwaitableState::Closed);
390    }
391
392    fn ignored_close(&self, res: &PyResult<PyIterReturn>) -> bool {
393        res.as_ref().map_or(false, |v| match v {
394            PyIterReturn::Return(obj) => obj.payload_is::<PyAsyncGenWrappedValue>(),
395            PyIterReturn::StopIteration(_) => false,
396        })
397    }
398    fn yield_close(&self, vm: &VirtualMachine) -> PyBaseExceptionRef {
399        self.ag.running_async.store(false);
400        self.state.store(AwaitableState::Closed);
401        vm.new_runtime_error("async generator ignored GeneratorExit".to_owned())
402    }
403    fn check_error(&self, exc: PyBaseExceptionRef, vm: &VirtualMachine) -> PyBaseExceptionRef {
404        self.ag.running_async.store(false);
405        self.state.store(AwaitableState::Closed);
406        if self.aclose
407            && (exc.fast_isinstance(vm.ctx.exceptions.stop_async_iteration)
408                || exc.fast_isinstance(vm.ctx.exceptions.generator_exit))
409        {
410            vm.new_stop_iteration(None)
411        } else {
412            exc
413        }
414    }
415}
416
417impl SelfIter for PyAsyncGenAThrow {}
418impl IterNext for PyAsyncGenAThrow {
419    fn next(zelf: &Py<Self>, vm: &VirtualMachine) -> PyResult<PyIterReturn> {
420        PyIterReturn::from_pyresult(zelf.send(vm.ctx.none(), vm), vm)
421    }
422}
423
424pub fn init(ctx: &Context) {
425    PyAsyncGen::extend_class(ctx, ctx.types.async_generator);
426    PyAsyncGenASend::extend_class(ctx, ctx.types.async_generator_asend);
427    PyAsyncGenAThrow::extend_class(ctx, ctx.types.async_generator_athrow);
428}
Morty Proxy This is a proxified and sanitized view of the page, visit original site.