1use super::{float, PyStr, PyType, PyTypeRef};
2use crate::{
3 class::PyClassImpl,
4 convert::{ToPyObject, ToPyResult},
5 function::{
6 OptionalArg, OptionalOption,
7 PyArithmeticValue::{self, *},
8 PyComparisonValue,
9 },
10 identifier,
11 protocol::PyNumberMethods,
12 stdlib::warnings,
13 types::{AsNumber, Comparable, Constructor, Hashable, PyComparisonOp, Representable},
14 AsObject, Context, Py, PyObject, PyObjectRef, PyPayload, PyRef, PyResult, VirtualMachine,
15};
16use num_complex::Complex64;
17use num_traits::Zero;
18use rustpython_common::hash;
19use std::num::Wrapping;
20
#[pyclass(module = false, name = "complex")]
// Payload of Python's builtin `complex` type: a single `Complex64` value.
// NOTE: plain `//` comments on purpose — `///` docs on a `#[pyclass]` item may be
// turned into the Python-visible `__doc__` by the macro (assumed; confirm).
#[derive(Debug, Copy, Clone, PartialEq)]
pub struct PyComplex {
    // The number itself: `value.re` is the real part, `value.im` the imaginary part.
    value: Complex64,
}
29
30impl PyComplex {
31 pub fn to_complex64(self) -> Complex64 {
32 self.value
33 }
34}
35
36impl PyPayload for PyComplex {
37 fn class(ctx: &Context) -> &'static Py<PyType> {
38 ctx.types.complex_type
39 }
40}
41
42impl ToPyObject for Complex64 {
43 fn to_pyobject(self, vm: &VirtualMachine) -> PyObjectRef {
44 PyComplex::new_ref(self, &vm.ctx).into()
45 }
46}
47
48impl From<Complex64> for PyComplex {
49 fn from(value: Complex64) -> Self {
50 PyComplex { value }
51 }
52}
53
54impl PyObjectRef {
55 pub fn try_complex(&self, vm: &VirtualMachine) -> PyResult<Option<(Complex64, bool)>> {
58 if let Some(complex) = self.payload_if_exact::<PyComplex>(vm) {
59 return Ok(Some((complex.value, true)));
60 }
61 if let Some(method) = vm.get_method(self.clone(), identifier!(vm, __complex__)) {
62 let result = method?.call((), vm)?;
63
64 let ret_class = result.class().to_owned();
65 if let Some(ret) = result.downcast_ref::<PyComplex>() {
66 warnings::warn(
67 vm.ctx.exceptions.deprecation_warning,
68 format!(
69 "__complex__ returned non-complex (type {}). \
70 The ability to return an instance of a strict subclass of complex \
71 is deprecated, and may be removed in a future version of Python.",
72 ret_class
73 ),
74 1,
75 vm,
76 )?;
77
78 return Ok(Some((ret.value, true)));
79 } else {
80 return match result.payload::<PyComplex>() {
81 Some(complex_obj) => Ok(Some((complex_obj.value, true))),
82 None => Err(vm.new_type_error(format!(
83 "__complex__ returned non-complex (type '{}')",
84 result.class().name()
85 ))),
86 };
87 }
88 }
89 if let Some(complex) = self.payload_if_subclass::<PyComplex>(vm) {
92 return Ok(Some((complex.value, true)));
93 }
94 if let Some(float) = self.try_float_opt(vm) {
95 return Ok(Some((Complex64::new(float?.to_f64(), 0.0), false)));
96 }
97 Ok(None)
98 }
99}
100
101pub fn init(context: &Context) {
102 PyComplex::extend_class(context, context.types.complex_type);
103}
104
105fn to_op_complex(value: &PyObject, vm: &VirtualMachine) -> PyResult<Option<Complex64>> {
106 let r = if let Some(complex) = value.payload_if_subclass::<PyComplex>(vm) {
107 Some(complex.value)
108 } else {
109 float::to_op_float(value, vm)?.map(|float| Complex64::new(float, 0.0))
110 };
111 Ok(r)
112}
113
114fn inner_div(v1: Complex64, v2: Complex64, vm: &VirtualMachine) -> PyResult<Complex64> {
115 if v2.is_zero() {
116 return Err(vm.new_zero_division_error("complex division by zero".to_owned()));
117 }
118
119 Ok(v1.fdiv(v2))
120}
121
122fn inner_pow(v1: Complex64, v2: Complex64, vm: &VirtualMachine) -> PyResult<Complex64> {
123 if v1.is_zero() {
124 return if v2.re < 0.0 || v2.im != 0.0 {
125 let msg = format!("{v1} cannot be raised to a negative or complex power");
126 Err(vm.new_zero_division_error(msg))
127 } else if v2.is_zero() {
128 Ok(Complex64::new(1.0, 0.0))
129 } else {
130 Ok(Complex64::new(0.0, 0.0))
131 };
132 }
133
134 let ans = v1.powc(v2);
135 if ans.is_infinite() && !(v1.is_infinite() || v2.is_infinite()) {
136 Err(vm.new_overflow_error("complex exponentiation overflow".to_owned()))
137 } else {
138 Ok(ans)
139 }
140}
141
impl Constructor for PyComplex {
    type Args = ComplexArgs;

    /// Implements `complex(real=0, imag=0)` and `complex(string)`.
    ///
    /// Each numeric argument is converted with `try_complex`. The result is then
    /// `real + imag * 1j`, computed component-wise so that complex-valued arguments
    /// combine correctly.
    fn py_new(cls: PyTypeRef, args: Self::Args, vm: &VirtualMachine) -> PyResult {
        let imag_missing = args.imag.is_missing();
        // `real_was_complex` records whether `real` came from a complex source. If it
        // did, its imaginary part has to be folded into the final result.
        let (real, real_was_complex) = match args.real {
            // `complex()` with no arguments gives 0j.
            OptionalArg::Missing => (Complex64::new(0.0, 0.0), false),
            OptionalArg::Present(val) => {
                // Fast path: `complex(c)` for an exact complex `c`, when the target
                // type is exactly `complex`, returns `c` itself without copying.
                let val = if cls.is(vm.ctx.types.complex_type) && imag_missing {
                    match val.downcast_exact::<PyComplex>(vm) {
                        Ok(c) => {
                            return Ok(c.into_pyref().into());
                        }
                        Err(val) => val,
                    }
                } else {
                    val
                };

                if let Some(c) = val.try_complex(vm)? {
                    c
                } else if let Some(s) = val.payload_if_subclass::<PyStr>(vm) {
                    // String form: parsed on its own; a second argument is rejected.
                    if args.imag.is_present() {
                        return Err(vm.new_type_error(
                            "complex() can't take second arg if first is a string".to_owned(),
                        ));
                    }
                    let value = parse_str(s.as_str().trim()).ok_or_else(|| {
                        vm.new_value_error("complex() arg is a malformed string".to_owned())
                    })?;
                    return Self::from(value)
                        .into_ref_with_type(vm, cls)
                        .map(Into::into);
                } else {
                    return Err(vm.new_type_error(format!(
                        "complex() first argument must be a string or a number, not '{}'",
                        val.class().name()
                    )));
                }
            }
        };

        // When `imag` is omitted, the real argument's own imaginary part takes its
        // place. It is stored in `.re` so the combination below stays uniform.
        let (imag, imag_was_complex) = match args.imag {
            OptionalArg::Missing => (Complex64::new(real.im, 0.0), false),
            OptionalArg::Present(obj) => {
                if let Some(c) = obj.try_complex(vm)? {
                    c
                } else if obj.class().fast_issubclass(vm.ctx.types.str_type) {
                    return Err(
                        vm.new_type_error("complex() second arg can't be a string".to_owned())
                    );
                } else {
                    return Err(vm.new_type_error(format!(
                        "complex() second argument must be a number, not '{}'",
                        obj.class().name()
                    )));
                }
            }
        };

        // Expand `real + imag * 1j` component-wise:
        //   re = real.re - imag.im   (only when imag came from a complex source)
        let final_real = if imag_was_complex {
            real.re - imag.im
        } else {
            real.re
        };

        //   im = imag.re + real.im   (only when imag was supplied; when it was
        //                             missing, real.im is already in imag.re)
        let final_imag = if real_was_complex && !imag_missing {
            imag.re + real.im
        } else {
            imag.re
        };
        let value = Complex64::new(final_real, final_imag);
        Self::from(value)
            .into_ref_with_type(vm, cls)
            .map(Into::into)
    }
}
221
222impl PyComplex {
223 pub fn new_ref(value: Complex64, ctx: &Context) -> PyRef<Self> {
224 PyRef::new_ref(Self::from(value), ctx.types.complex_type.to_owned(), None)
225 }
226
227 pub fn to_complex(&self) -> Complex64 {
228 self.value
229 }
230}
231
232#[pyclass(
233 flags(BASETYPE),
234 with(PyRef, Comparable, Hashable, Constructor, AsNumber, Representable)
235)]
236impl PyComplex {
237 #[pygetset]
238 fn real(&self) -> f64 {
239 self.value.re
240 }
241
242 #[pygetset]
243 fn imag(&self) -> f64 {
244 self.value.im
245 }
246
247 #[pymethod(magic)]
248 fn abs(&self, vm: &VirtualMachine) -> PyResult<f64> {
249 let Complex64 { im, re } = self.value;
250 let is_finite = im.is_finite() && re.is_finite();
251 let abs_result = re.hypot(im);
252 if is_finite && abs_result.is_infinite() {
253 Err(vm.new_overflow_error("absolute value too large".to_string()))
254 } else {
255 Ok(abs_result)
256 }
257 }
258
259 #[inline]
260 fn op<F>(
261 &self,
262 other: PyObjectRef,
263 op: F,
264 vm: &VirtualMachine,
265 ) -> PyResult<PyArithmeticValue<Complex64>>
266 where
267 F: Fn(Complex64, Complex64) -> PyResult<Complex64>,
268 {
269 to_op_complex(&other, vm)?.map_or_else(
270 || Ok(NotImplemented),
271 |other| Ok(Implemented(op(self.value, other)?)),
272 )
273 }
274
275 #[pymethod(name = "__radd__")]
276 #[pymethod(magic)]
277 fn add(
278 &self,
279 other: PyObjectRef,
280 vm: &VirtualMachine,
281 ) -> PyResult<PyArithmeticValue<Complex64>> {
282 self.op(other, |a, b| Ok(a + b), vm)
283 }
284
285 #[pymethod(magic)]
286 fn sub(
287 &self,
288 other: PyObjectRef,
289 vm: &VirtualMachine,
290 ) -> PyResult<PyArithmeticValue<Complex64>> {
291 self.op(other, |a, b| Ok(a - b), vm)
292 }
293
294 #[pymethod(magic)]
295 fn rsub(
296 &self,
297 other: PyObjectRef,
298 vm: &VirtualMachine,
299 ) -> PyResult<PyArithmeticValue<Complex64>> {
300 self.op(other, |a, b| Ok(b - a), vm)
301 }
302
303 #[pymethod]
304 fn conjugate(&self) -> Complex64 {
305 self.value.conj()
306 }
307
308 #[pymethod(name = "__rmul__")]
309 #[pymethod(magic)]
310 fn mul(
311 &self,
312 other: PyObjectRef,
313 vm: &VirtualMachine,
314 ) -> PyResult<PyArithmeticValue<Complex64>> {
315 self.op(other, |a, b| Ok(a * b), vm)
316 }
317
318 #[pymethod(magic)]
319 fn truediv(
320 &self,
321 other: PyObjectRef,
322 vm: &VirtualMachine,
323 ) -> PyResult<PyArithmeticValue<Complex64>> {
324 self.op(other, |a, b| inner_div(a, b, vm), vm)
325 }
326
327 #[pymethod(magic)]
328 fn rtruediv(
329 &self,
330 other: PyObjectRef,
331 vm: &VirtualMachine,
332 ) -> PyResult<PyArithmeticValue<Complex64>> {
333 self.op(other, |a, b| inner_div(b, a, vm), vm)
334 }
335
336 #[pymethod(magic)]
337 fn pos(&self) -> Complex64 {
338 self.value
339 }
340
341 #[pymethod(magic)]
342 fn neg(&self) -> Complex64 {
343 -self.value
344 }
345
346 #[pymethod(magic)]
347 fn pow(
348 &self,
349 other: PyObjectRef,
350 mod_val: OptionalOption<PyObjectRef>,
351 vm: &VirtualMachine,
352 ) -> PyResult<PyArithmeticValue<Complex64>> {
353 if mod_val.flatten().is_some() {
354 Err(vm.new_value_error("complex modulo not allowed".to_owned()))
355 } else {
356 self.op(other, |a, b| inner_pow(a, b, vm), vm)
357 }
358 }
359
360 #[pymethod(magic)]
361 fn rpow(
362 &self,
363 other: PyObjectRef,
364 vm: &VirtualMachine,
365 ) -> PyResult<PyArithmeticValue<Complex64>> {
366 self.op(other, |a, b| inner_pow(b, a, vm), vm)
367 }
368
369 #[pymethod(magic)]
370 fn bool(&self) -> bool {
371 !Complex64::is_zero(&self.value)
372 }
373
374 #[pymethod(magic)]
375 fn getnewargs(&self) -> (f64, f64) {
376 let Complex64 { re, im } = self.value;
377 (re, im)
378 }
379}
380
381#[pyclass]
382impl PyRef<PyComplex> {
383 #[pymethod(magic)]
384 fn complex(self, vm: &VirtualMachine) -> PyRef<PyComplex> {
385 if self.is(vm.ctx.types.complex_type) {
386 self
387 } else {
388 PyComplex::from(self.value).into_ref(&vm.ctx)
389 }
390 }
391}
392
393impl Comparable for PyComplex {
394 fn cmp(
395 zelf: &Py<Self>,
396 other: &PyObject,
397 op: PyComparisonOp,
398 vm: &VirtualMachine,
399 ) -> PyResult<PyComparisonValue> {
400 op.eq_only(|| {
401 let result = if let Some(other) = other.payload_if_subclass::<PyComplex>(vm) {
402 if zelf.value.re.is_nan()
403 && zelf.value.im.is_nan()
404 && other.value.re.is_nan()
405 && other.value.im.is_nan()
406 {
407 true
408 } else {
409 zelf.value == other.value
410 }
411 } else {
412 match float::to_op_float(other, vm) {
413 Ok(Some(other)) => zelf.value == other.into(),
414 Err(_) => false,
415 Ok(None) => return Ok(PyComparisonValue::NotImplemented),
416 }
417 };
418 Ok(PyComparisonValue::Implemented(result))
419 })
420 }
421}
422
423impl Hashable for PyComplex {
424 #[inline]
425 fn hash(zelf: &Py<Self>, _vm: &VirtualMachine) -> PyResult<hash::PyHash> {
426 let value = zelf.value;
427
428 let re_hash =
429 hash::hash_float(value.re).unwrap_or_else(|| hash::hash_object_id(zelf.get_id()));
430
431 let im_hash =
432 hash::hash_float(value.im).unwrap_or_else(|| hash::hash_object_id(zelf.get_id()));
433
434 let Wrapping(ret) = Wrapping(re_hash) + Wrapping(im_hash) * Wrapping(hash::IMAG);
435 Ok(hash::fix_sentinel(ret))
436 }
437}
438
439impl AsNumber for PyComplex {
440 fn as_number() -> &'static PyNumberMethods {
441 static AS_NUMBER: PyNumberMethods = PyNumberMethods {
442 add: Some(|a, b, vm| PyComplex::number_op(a, b, |a, b, _vm| a + b, vm)),
443 subtract: Some(|a, b, vm| PyComplex::number_op(a, b, |a, b, _vm| a - b, vm)),
444 multiply: Some(|a, b, vm| PyComplex::number_op(a, b, |a, b, _vm| a * b, vm)),
445 power: Some(|a, b, c, vm| {
446 if vm.is_none(c) {
447 PyComplex::number_op(a, b, inner_pow, vm)
448 } else {
449 Err(vm.new_value_error(String::from("complex modulo")))
450 }
451 }),
452 negative: Some(|number, vm| {
453 let value = PyComplex::number_downcast(number).value;
454 (-value).to_pyresult(vm)
455 }),
456 positive: Some(|number, vm| {
457 PyComplex::number_downcast_exact(number, vm).to_pyresult(vm)
458 }),
459 absolute: Some(|number, vm| {
460 let value = PyComplex::number_downcast(number).value;
461 value.norm().to_pyresult(vm)
462 }),
463 boolean: Some(|number, _vm| Ok(PyComplex::number_downcast(number).value.is_zero())),
464 true_divide: Some(|a, b, vm| PyComplex::number_op(a, b, inner_div, vm)),
465 ..PyNumberMethods::NOT_IMPLEMENTED
466 };
467 &AS_NUMBER
468 }
469
470 fn clone_exact(zelf: &Py<Self>, vm: &VirtualMachine) -> PyRef<Self> {
471 vm.ctx.new_complex(zelf.value)
472 }
473}
474
475impl Representable for PyComplex {
476 #[inline]
477 fn repr_str(zelf: &Py<Self>, _vm: &VirtualMachine) -> PyResult<String> {
478 let Complex64 { re, im } = zelf.value;
481 let mut im_part = if im.fract() == 0.0 {
483 im.to_string()
484 } else {
485 crate::literal::float::to_string(im)
486 };
487 im_part.push('j');
488
489 let re_part = if re == 0.0 {
491 if re.is_sign_positive() {
492 return Ok(im_part);
493 } else {
494 re.to_string()
495 }
496 } else if re.fract() == 0.0 {
497 re.to_string()
498 } else {
499 crate::literal::float::to_string(re)
500 };
501 let mut result = String::with_capacity(
502 re_part.len() + im_part.len() + 2 + im.is_sign_positive() as usize,
503 );
504 result.push('(');
505 result.push_str(&re_part);
506 if im.is_sign_positive() || im.is_nan() {
507 result.push('+');
508 }
509 result.push_str(&im_part);
510 result.push(')');
511 Ok(result)
512 }
513}
514
515impl PyComplex {
516 fn number_op<F, R>(a: &PyObject, b: &PyObject, op: F, vm: &VirtualMachine) -> PyResult
517 where
518 F: FnOnce(Complex64, Complex64, &VirtualMachine) -> R,
519 R: ToPyResult,
520 {
521 if let (Some(a), Some(b)) = (to_op_complex(a, vm)?, to_op_complex(b, vm)?) {
522 op(a, b, vm).to_pyresult(vm)
523 } else {
524 Ok(vm.ctx.not_implemented())
525 }
526 }
527}
528
/// Arguments of the `complex(...)` constructor. Both are optional and declared with
/// `pyarg(any)`, i.e. they may be passed positionally or by keyword. Field order is
/// the positional order.
#[derive(FromArgs)]
pub struct ComplexArgs {
    // First argument: a number, a complex, or a string.
    #[pyarg(any, optional)]
    real: OptionalArg<PyObjectRef>,
    // Second argument: a number. `py_new` rejects it when `real` is a string.
    #[pyarg(any, optional)]
    imag: OptionalArg<PyObjectRef>,
}
536
537fn parse_str(s: &str) -> Option<Complex64> {
538 let s = match s.strip_prefix('(') {
540 None => s,
541 Some(s) => match s.strip_suffix(')') {
542 None => return None,
543 Some(s) => s.trim(),
544 },
545 };
546
547 let value = match s.strip_suffix(|c| c == 'j' || c == 'J') {
548 None => Complex64::new(crate::literal::float::parse_str(s)?, 0.0),
549 Some(mut s) => {
550 let mut real = 0.0;
551 for (i, w) in s.as_bytes().windows(2).enumerate() {
553 if (w[1] == b'+' || w[1] == b'-') && !(w[0] == b'e' || w[0] == b'E') {
554 real = crate::literal::float::parse_str(&s[..=i])?;
555 s = &s[i + 1..];
556 break;
557 }
558 }
559
560 let imag = match s {
561 "" | "+" => 1.0,
563 "-" => -1.0,
565 s => crate::literal::float::parse_str(s)?,
566 };
567
568 Complex64::new(real, imag)
569 }
570 };
571 Some(value)
572}