1use super::{PyCode, PyGenericAlias, PyStrRef, PyType, PyTypeRef};
2use crate::{
3 builtins::PyBaseExceptionRef,
4 class::PyClassImpl,
5 coroutine::Coro,
6 frame::FrameRef,
7 function::OptionalArg,
8 protocol::PyIterReturn,
9 types::{IterNext, Iterable, Representable, SelfIter, Unconstructible},
10 AsObject, Context, Py, PyObjectRef, PyPayload, PyRef, PyResult, VirtualMachine,
11};
12
13use crossbeam_utils::atomic::AtomicCell;
14
/// Payload of Python's `async_generator` objects.
#[pyclass(name = "async_generator", module = false)]
#[derive(Debug)]
pub struct PyAsyncGen {
    // Shared coroutine machinery: frame, `__name__`, and the `closed`/`running` flags.
    inner: Coro,
    // Set while an `asend()`/`athrow()`/`aclose()` awaitable has claimed this
    // generator; a second awaitable is rejected with "already running" while it is set.
    running_async: AtomicCell<bool>,
}
type PyAsyncGenRef = PyRef<PyAsyncGen>;

impl PyPayload for PyAsyncGen {
    fn class(ctx: &Context) -> &'static Py<PyType> {
        ctx.types.async_generator
    }
}
28
29#[pyclass(with(PyRef, Unconstructible, Representable))]
30impl PyAsyncGen {
31 pub fn as_coro(&self) -> &Coro {
32 &self.inner
33 }
34
35 pub fn new(frame: FrameRef, name: PyStrRef) -> Self {
36 PyAsyncGen {
37 inner: Coro::new(frame, name),
38 running_async: AtomicCell::new(false),
39 }
40 }
41
42 #[pygetset(magic)]
43 fn name(&self) -> PyStrRef {
44 self.inner.name()
45 }
46
47 #[pygetset(magic, setter)]
48 fn set_name(&self, name: PyStrRef) {
49 self.inner.set_name(name)
50 }
51
52 #[pygetset]
53 fn ag_await(&self, _vm: &VirtualMachine) -> Option<PyObjectRef> {
54 self.inner.frame().yield_from_target()
55 }
56 #[pygetset]
57 fn ag_frame(&self, _vm: &VirtualMachine) -> FrameRef {
58 self.inner.frame()
59 }
60 #[pygetset]
61 fn ag_running(&self, _vm: &VirtualMachine) -> bool {
62 self.inner.running()
63 }
64 #[pygetset]
65 fn ag_code(&self, _vm: &VirtualMachine) -> PyRef<PyCode> {
66 self.inner.frame().code.clone()
67 }
68
69 #[pyclassmethod(magic)]
70 fn class_getitem(cls: PyTypeRef, args: PyObjectRef, vm: &VirtualMachine) -> PyGenericAlias {
71 PyGenericAlias::new(cls, args, vm)
72 }
73}
74
75#[pyclass]
76impl PyRef<PyAsyncGen> {
77 #[pymethod(magic)]
78 fn aiter(self, _vm: &VirtualMachine) -> PyRef<PyAsyncGen> {
79 self
80 }
81
82 #[pymethod(magic)]
83 fn anext(self, vm: &VirtualMachine) -> PyAsyncGenASend {
84 Self::asend(self, vm.ctx.none(), vm)
85 }
86
87 #[pymethod]
88 fn asend(self, value: PyObjectRef, _vm: &VirtualMachine) -> PyAsyncGenASend {
89 PyAsyncGenASend {
90 ag: self,
91 state: AtomicCell::new(AwaitableState::Init),
92 value,
93 }
94 }
95
96 #[pymethod]
97 fn athrow(
98 self,
99 exc_type: PyObjectRef,
100 exc_val: OptionalArg,
101 exc_tb: OptionalArg,
102 vm: &VirtualMachine,
103 ) -> PyAsyncGenAThrow {
104 PyAsyncGenAThrow {
105 ag: self,
106 aclose: false,
107 state: AtomicCell::new(AwaitableState::Init),
108 value: (
109 exc_type,
110 exc_val.unwrap_or_none(vm),
111 exc_tb.unwrap_or_none(vm),
112 ),
113 }
114 }
115
116 #[pymethod]
117 fn aclose(self, vm: &VirtualMachine) -> PyAsyncGenAThrow {
118 PyAsyncGenAThrow {
119 ag: self,
120 aclose: true,
121 state: AtomicCell::new(AwaitableState::Init),
122 value: (
123 vm.ctx.exceptions.generator_exit.to_owned().into(),
124 vm.ctx.none(),
125 vm.ctx.none(),
126 ),
127 }
128 }
129}
130
131impl Representable for PyAsyncGen {
132 #[inline]
133 fn repr_str(zelf: &Py<Self>, vm: &VirtualMachine) -> PyResult<String> {
134 Ok(zelf.inner.repr(zelf.as_object(), zelf.get_id(), vm))
135 }
136}
137
138impl Unconstructible for PyAsyncGen {}
139
/// Internal marker wrapping a value that comes out of an async generator's frame.
/// It lets that value be told apart from objects the frame is merely awaiting on.
/// `unbox` turns it into `StopIteration(value)`. Presumably the frame emits one
/// per `yield` (TODO confirm against the frame code).
#[pyclass(module = false, name = "async_generator_wrapped_value")]
#[derive(Debug)]
pub(crate) struct PyAsyncGenWrappedValue(pub PyObjectRef);
impl PyPayload for PyAsyncGenWrappedValue {
    fn class(ctx: &Context) -> &'static Py<PyType> {
        ctx.types.async_generator_wrapped_value
    }
}

// Exposes no Python-visible methods.
#[pyclass]
impl PyAsyncGenWrappedValue {}
151
152impl PyAsyncGenWrappedValue {
153 fn unbox(ag: &PyAsyncGen, val: PyResult<PyIterReturn>, vm: &VirtualMachine) -> PyResult {
154 let (closed, async_done) = match &val {
155 Ok(PyIterReturn::StopIteration(_)) => (true, true),
156 Err(e) if e.fast_isinstance(vm.ctx.exceptions.generator_exit) => (true, true),
157 Err(_) => (false, true),
158 _ => (false, false),
159 };
160 if closed {
161 ag.inner.closed.store(true);
162 }
163 if async_done {
164 ag.running_async.store(false);
165 }
166 let val = val?.into_async_pyresult(vm)?;
167 match_class!(match val {
168 val @ Self => {
169 ag.running_async.store(false);
170 Err(vm.new_stop_iteration(Some(val.0.clone())))
171 }
172 val => Ok(val),
173 })
174 }
175}
176
/// Lifecycle of an `asend()`/`athrow()`/`aclose()` awaitable.
#[derive(Debug, Clone, Copy)]
enum AwaitableState {
    /// Created; the generator has not been resumed through this awaitable yet.
    Init,
    /// Started; subsequent sends are forwarded to the generator.
    Iter,
    /// Finished or failed; further `send()` calls raise "cannot reuse already awaited ...".
    Closed,
}
183
/// Awaitable returned by `async_generator.asend()` and `__anext__()`.
#[pyclass(module = false, name = "async_generator_asend")]
#[derive(Debug)]
pub(crate) struct PyAsyncGenASend {
    // Generator driven by this awaitable.
    ag: PyAsyncGenRef,
    state: AtomicCell<AwaitableState>,
    // Value given to `asend()`; it replaces a `None` passed to the first `send`.
    value: PyObjectRef,
}

impl PyPayload for PyAsyncGenASend {
    fn class(ctx: &Context) -> &'static Py<PyType> {
        ctx.types.async_generator_asend
    }
}
197
198#[pyclass(with(IterNext, Iterable))]
199impl PyAsyncGenASend {
200 #[pymethod(name = "__await__")]
201 fn r#await(zelf: PyRef<Self>, _vm: &VirtualMachine) -> PyRef<Self> {
202 zelf
203 }
204
205 #[pymethod]
206 fn send(&self, val: PyObjectRef, vm: &VirtualMachine) -> PyResult {
207 let val = match self.state.load() {
208 AwaitableState::Closed => {
209 return Err(vm.new_runtime_error(
210 "cannot reuse already awaited __anext__()/asend()".to_owned(),
211 ))
212 }
213 AwaitableState::Iter => val, AwaitableState::Init => {
215 if self.ag.running_async.load() {
216 return Err(vm.new_runtime_error(
217 "anext(): asynchronous generator is already running".to_owned(),
218 ));
219 }
220 self.ag.running_async.store(true);
221 self.state.store(AwaitableState::Iter);
222 if vm.is_none(&val) {
223 self.value.clone()
224 } else {
225 val
226 }
227 }
228 };
229 let res = self.ag.inner.send(self.ag.as_object(), val, vm);
230 let res = PyAsyncGenWrappedValue::unbox(&self.ag, res, vm);
231 if res.is_err() {
232 self.close();
233 }
234 res
235 }
236
237 #[pymethod]
238 fn throw(
239 &self,
240 exc_type: PyObjectRef,
241 exc_val: OptionalArg,
242 exc_tb: OptionalArg,
243 vm: &VirtualMachine,
244 ) -> PyResult {
245 if let AwaitableState::Closed = self.state.load() {
246 return Err(
247 vm.new_runtime_error("cannot reuse already awaited __anext__()/asend()".to_owned())
248 );
249 }
250
251 let res = self.ag.inner.throw(
252 self.ag.as_object(),
253 exc_type,
254 exc_val.unwrap_or_none(vm),
255 exc_tb.unwrap_or_none(vm),
256 vm,
257 );
258 let res = PyAsyncGenWrappedValue::unbox(&self.ag, res, vm);
259 if res.is_err() {
260 self.close();
261 }
262 res
263 }
264
265 #[pymethod]
266 fn close(&self) {
267 self.state.store(AwaitableState::Closed);
268 }
269}
270
271impl SelfIter for PyAsyncGenASend {}
272impl IterNext for PyAsyncGenASend {
273 fn next(zelf: &Py<Self>, vm: &VirtualMachine) -> PyResult<PyIterReturn> {
274 PyIterReturn::from_pyresult(zelf.send(vm.ctx.none(), vm), vm)
275 }
276}
277
/// Awaitable returned by `async_generator.athrow()` and `aclose()`.
#[pyclass(module = false, name = "async_generator_athrow")]
#[derive(Debug)]
pub(crate) struct PyAsyncGenAThrow {
    // Generator driven by this awaitable.
    ag: PyAsyncGenRef,
    // true for `aclose()`. In that mode `GeneratorExit`/`StopAsyncIteration`
    // become a plain `StopIteration`, and a value coming out of the generator
    // is reported as "ignored GeneratorExit".
    aclose: bool,
    state: AtomicCell<AwaitableState>,
    // (type, value, traceback) thrown into the generator on the first send.
    value: (PyObjectRef, PyObjectRef, PyObjectRef),
}

impl PyPayload for PyAsyncGenAThrow {
    fn class(ctx: &Context) -> &'static Py<PyType> {
        ctx.types.async_generator_athrow
    }
}
292
293#[pyclass(with(IterNext, Iterable))]
294impl PyAsyncGenAThrow {
295 #[pymethod(name = "__await__")]
296 fn r#await(zelf: PyRef<Self>, _vm: &VirtualMachine) -> PyRef<Self> {
297 zelf
298 }
299
300 #[pymethod]
301 fn send(&self, val: PyObjectRef, vm: &VirtualMachine) -> PyResult {
302 match self.state.load() {
303 AwaitableState::Closed => {
304 Err(vm
305 .new_runtime_error("cannot reuse already awaited aclose()/athrow()".to_owned()))
306 }
307 AwaitableState::Init => {
308 if self.ag.running_async.load() {
309 self.state.store(AwaitableState::Closed);
310 let msg = if self.aclose {
311 "aclose(): asynchronous generator is already running"
312 } else {
313 "athrow(): asynchronous generator is already running"
314 };
315 return Err(vm.new_runtime_error(msg.to_owned()));
316 }
317 if self.ag.inner.closed() {
318 self.state.store(AwaitableState::Closed);
319 return Err(vm.new_stop_iteration(None));
320 }
321 if !vm.is_none(&val) {
322 return Err(vm.new_runtime_error(
323 "can't send non-None value to a just-started async generator".to_owned(),
324 ));
325 }
326 self.state.store(AwaitableState::Iter);
327 self.ag.running_async.store(true);
328
329 let (ty, val, tb) = self.value.clone();
330 let ret = self.ag.inner.throw(self.ag.as_object(), ty, val, tb, vm);
331 let ret = if self.aclose {
332 if self.ignored_close(&ret) {
333 Err(self.yield_close(vm))
334 } else {
335 ret.and_then(|o| o.into_async_pyresult(vm))
336 }
337 } else {
338 PyAsyncGenWrappedValue::unbox(&self.ag, ret, vm)
339 };
340 ret.map_err(|e| self.check_error(e, vm))
341 }
342 AwaitableState::Iter => {
343 let ret = self.ag.inner.send(self.ag.as_object(), val, vm);
344 if self.aclose {
345 match ret {
346 Ok(PyIterReturn::Return(v)) if v.payload_is::<PyAsyncGenWrappedValue>() => {
347 Err(self.yield_close(vm))
348 }
349 other => other
350 .and_then(|o| o.into_async_pyresult(vm))
351 .map_err(|e| self.check_error(e, vm)),
352 }
353 } else {
354 PyAsyncGenWrappedValue::unbox(&self.ag, ret, vm)
355 }
356 }
357 }
358 }
359
360 #[pymethod]
361 fn throw(
362 &self,
363 exc_type: PyObjectRef,
364 exc_val: OptionalArg,
365 exc_tb: OptionalArg,
366 vm: &VirtualMachine,
367 ) -> PyResult {
368 let ret = self.ag.inner.throw(
369 self.ag.as_object(),
370 exc_type,
371 exc_val.unwrap_or_none(vm),
372 exc_tb.unwrap_or_none(vm),
373 vm,
374 );
375 let res = if self.aclose {
376 if self.ignored_close(&ret) {
377 Err(self.yield_close(vm))
378 } else {
379 ret.and_then(|o| o.into_async_pyresult(vm))
380 }
381 } else {
382 PyAsyncGenWrappedValue::unbox(&self.ag, ret, vm)
383 };
384 res.map_err(|e| self.check_error(e, vm))
385 }
386
387 #[pymethod]
388 fn close(&self) {
389 self.state.store(AwaitableState::Closed);
390 }
391
392 fn ignored_close(&self, res: &PyResult<PyIterReturn>) -> bool {
393 res.as_ref().map_or(false, |v| match v {
394 PyIterReturn::Return(obj) => obj.payload_is::<PyAsyncGenWrappedValue>(),
395 PyIterReturn::StopIteration(_) => false,
396 })
397 }
398 fn yield_close(&self, vm: &VirtualMachine) -> PyBaseExceptionRef {
399 self.ag.running_async.store(false);
400 self.state.store(AwaitableState::Closed);
401 vm.new_runtime_error("async generator ignored GeneratorExit".to_owned())
402 }
403 fn check_error(&self, exc: PyBaseExceptionRef, vm: &VirtualMachine) -> PyBaseExceptionRef {
404 self.ag.running_async.store(false);
405 self.state.store(AwaitableState::Closed);
406 if self.aclose
407 && (exc.fast_isinstance(vm.ctx.exceptions.stop_async_iteration)
408 || exc.fast_isinstance(vm.ctx.exceptions.generator_exit))
409 {
410 vm.new_stop_iteration(None)
411 } else {
412 exc
413 }
414 }
415}
416
417impl SelfIter for PyAsyncGenAThrow {}
418impl IterNext for PyAsyncGenAThrow {
419 fn next(zelf: &Py<Self>, vm: &VirtualMachine) -> PyResult<PyIterReturn> {
420 PyIterReturn::from_pyresult(zelf.send(vm.ctx.none(), vm), vm)
421 }
422}
423
424pub fn init(ctx: &Context) {
425 PyAsyncGen::extend_class(ctx, ctx.types.async_generator);
426 PyAsyncGenASend::extend_class(ctx, ctx.types.async_generator_asend);
427 PyAsyncGenAThrow::extend_class(ctx, ctx.types.async_generator_athrow);
428}