rustpython_vm/builtins/complex.rs

1use super::{float, PyStr, PyType, PyTypeRef};
2use crate::{
3    class::PyClassImpl,
4    convert::{ToPyObject, ToPyResult},
5    function::{
6        OptionalArg, OptionalOption,
7        PyArithmeticValue::{self, *},
8        PyComparisonValue,
9    },
10    identifier,
11    protocol::PyNumberMethods,
12    stdlib::warnings,
13    types::{AsNumber, Comparable, Constructor, Hashable, PyComparisonOp, Representable},
14    AsObject, Context, Py, PyObject, PyObjectRef, PyPayload, PyRef, PyResult, VirtualMachine,
15};
16use num_complex::Complex64;
17use num_traits::Zero;
18use rustpython_common::hash;
19use std::num::Wrapping;
20
21/// Create a complex number from a real part and an optional imaginary part.
22///
23/// This is equivalent to (real + imag*1j) where imag defaults to 0.
24#[pyclass(module = false, name = "complex")]
25#[derive(Debug, Copy, Clone, PartialEq)]
26pub struct PyComplex {
27    value: Complex64,
28}
29
30impl PyComplex {
31    pub fn to_complex64(self) -> Complex64 {
32        self.value
33    }
34}
35
36impl PyPayload for PyComplex {
37    fn class(ctx: &Context) -> &'static Py<PyType> {
38        ctx.types.complex_type
39    }
40}
41
42impl ToPyObject for Complex64 {
43    fn to_pyobject(self, vm: &VirtualMachine) -> PyObjectRef {
44        PyComplex::new_ref(self, &vm.ctx).into()
45    }
46}
47
48impl From<Complex64> for PyComplex {
49    fn from(value: Complex64) -> Self {
50        PyComplex { value }
51    }
52}
53
54impl PyObjectRef {
55    /// Tries converting a python object into a complex, returns an option of whether the complex
56    /// and whether the  object was a complex originally or coereced into one
57    pub fn try_complex(&self, vm: &VirtualMachine) -> PyResult<Option<(Complex64, bool)>> {
58        if let Some(complex) = self.payload_if_exact::<PyComplex>(vm) {
59            return Ok(Some((complex.value, true)));
60        }
61        if let Some(method) = vm.get_method(self.clone(), identifier!(vm, __complex__)) {
62            let result = method?.call((), vm)?;
63
64            let ret_class = result.class().to_owned();
65            if let Some(ret) = result.downcast_ref::<PyComplex>() {
66                warnings::warn(
67                    vm.ctx.exceptions.deprecation_warning,
68                    format!(
69                        "__complex__ returned non-complex (type {}).  \
70                    The ability to return an instance of a strict subclass of complex \
71                    is deprecated, and may be removed in a future version of Python.",
72                        ret_class
73                    ),
74                    1,
75                    vm,
76                )?;
77
78                return Ok(Some((ret.value, true)));
79            } else {
80                return match result.payload::<PyComplex>() {
81                    Some(complex_obj) => Ok(Some((complex_obj.value, true))),
82                    None => Err(vm.new_type_error(format!(
83                        "__complex__ returned non-complex (type '{}')",
84                        result.class().name()
85                    ))),
86                };
87            }
88        }
89        // `complex` does not have a `__complex__` by default, so subclasses might not either,
90        // use the actual stored value in this case
91        if let Some(complex) = self.payload_if_subclass::<PyComplex>(vm) {
92            return Ok(Some((complex.value, true)));
93        }
94        if let Some(float) = self.try_float_opt(vm) {
95            return Ok(Some((Complex64::new(float?.to_f64(), 0.0), false)));
96        }
97        Ok(None)
98    }
99}
100
101pub fn init(context: &Context) {
102    PyComplex::extend_class(context, context.types.complex_type);
103}
104
105fn to_op_complex(value: &PyObject, vm: &VirtualMachine) -> PyResult<Option<Complex64>> {
106    let r = if let Some(complex) = value.payload_if_subclass::<PyComplex>(vm) {
107        Some(complex.value)
108    } else {
109        float::to_op_float(value, vm)?.map(|float| Complex64::new(float, 0.0))
110    };
111    Ok(r)
112}
113
114fn inner_div(v1: Complex64, v2: Complex64, vm: &VirtualMachine) -> PyResult<Complex64> {
115    if v2.is_zero() {
116        return Err(vm.new_zero_division_error("complex division by zero".to_owned()));
117    }
118
119    Ok(v1.fdiv(v2))
120}
121
122fn inner_pow(v1: Complex64, v2: Complex64, vm: &VirtualMachine) -> PyResult<Complex64> {
123    if v1.is_zero() {
124        return if v2.re < 0.0 || v2.im != 0.0 {
125            let msg = format!("{v1} cannot be raised to a negative or complex power");
126            Err(vm.new_zero_division_error(msg))
127        } else if v2.is_zero() {
128            Ok(Complex64::new(1.0, 0.0))
129        } else {
130            Ok(Complex64::new(0.0, 0.0))
131        };
132    }
133
134    let ans = v1.powc(v2);
135    if ans.is_infinite() && !(v1.is_infinite() || v2.is_infinite()) {
136        Err(vm.new_overflow_error("complex exponentiation overflow".to_owned()))
137    } else {
138        Ok(ans)
139    }
140}
141
142impl Constructor for PyComplex {
143    type Args = ComplexArgs;
144
145    fn py_new(cls: PyTypeRef, args: Self::Args, vm: &VirtualMachine) -> PyResult {
146        let imag_missing = args.imag.is_missing();
147        let (real, real_was_complex) = match args.real {
148            OptionalArg::Missing => (Complex64::new(0.0, 0.0), false),
149            OptionalArg::Present(val) => {
150                let val = if cls.is(vm.ctx.types.complex_type) && imag_missing {
151                    match val.downcast_exact::<PyComplex>(vm) {
152                        Ok(c) => {
153                            return Ok(c.into_pyref().into());
154                        }
155                        Err(val) => val,
156                    }
157                } else {
158                    val
159                };
160
161                if let Some(c) = val.try_complex(vm)? {
162                    c
163                } else if let Some(s) = val.payload_if_subclass::<PyStr>(vm) {
164                    if args.imag.is_present() {
165                        return Err(vm.new_type_error(
166                            "complex() can't take second arg if first is a string".to_owned(),
167                        ));
168                    }
169                    let value = parse_str(s.as_str().trim()).ok_or_else(|| {
170                        vm.new_value_error("complex() arg is a malformed string".to_owned())
171                    })?;
172                    return Self::from(value)
173                        .into_ref_with_type(vm, cls)
174                        .map(Into::into);
175                } else {
176                    return Err(vm.new_type_error(format!(
177                        "complex() first argument must be a string or a number, not '{}'",
178                        val.class().name()
179                    )));
180                }
181            }
182        };
183
184        let (imag, imag_was_complex) = match args.imag {
185            // Copy the imaginary from the real to the real of the imaginary
186            // if an  imaginary argument is not passed in
187            OptionalArg::Missing => (Complex64::new(real.im, 0.0), false),
188            OptionalArg::Present(obj) => {
189                if let Some(c) = obj.try_complex(vm)? {
190                    c
191                } else if obj.class().fast_issubclass(vm.ctx.types.str_type) {
192                    return Err(
193                        vm.new_type_error("complex() second arg can't be a string".to_owned())
194                    );
195                } else {
196                    return Err(vm.new_type_error(format!(
197                        "complex() second argument must be a number, not '{}'",
198                        obj.class().name()
199                    )));
200                }
201            }
202        };
203
204        let final_real = if imag_was_complex {
205            real.re - imag.im
206        } else {
207            real.re
208        };
209
210        let final_imag = if real_was_complex && !imag_missing {
211            imag.re + real.im
212        } else {
213            imag.re
214        };
215        let value = Complex64::new(final_real, final_imag);
216        Self::from(value)
217            .into_ref_with_type(vm, cls)
218            .map(Into::into)
219    }
220}
221
222impl PyComplex {
223    pub fn new_ref(value: Complex64, ctx: &Context) -> PyRef<Self> {
224        PyRef::new_ref(Self::from(value), ctx.types.complex_type.to_owned(), None)
225    }
226
227    pub fn to_complex(&self) -> Complex64 {
228        self.value
229    }
230}
231
232#[pyclass(
233    flags(BASETYPE),
234    with(PyRef, Comparable, Hashable, Constructor, AsNumber, Representable)
235)]
236impl PyComplex {
237    #[pygetset]
238    fn real(&self) -> f64 {
239        self.value.re
240    }
241
242    #[pygetset]
243    fn imag(&self) -> f64 {
244        self.value.im
245    }
246
247    #[pymethod(magic)]
248    fn abs(&self, vm: &VirtualMachine) -> PyResult<f64> {
249        let Complex64 { im, re } = self.value;
250        let is_finite = im.is_finite() && re.is_finite();
251        let abs_result = re.hypot(im);
252        if is_finite && abs_result.is_infinite() {
253            Err(vm.new_overflow_error("absolute value too large".to_string()))
254        } else {
255            Ok(abs_result)
256        }
257    }
258
259    #[inline]
260    fn op<F>(
261        &self,
262        other: PyObjectRef,
263        op: F,
264        vm: &VirtualMachine,
265    ) -> PyResult<PyArithmeticValue<Complex64>>
266    where
267        F: Fn(Complex64, Complex64) -> PyResult<Complex64>,
268    {
269        to_op_complex(&other, vm)?.map_or_else(
270            || Ok(NotImplemented),
271            |other| Ok(Implemented(op(self.value, other)?)),
272        )
273    }
274
275    #[pymethod(name = "__radd__")]
276    #[pymethod(magic)]
277    fn add(
278        &self,
279        other: PyObjectRef,
280        vm: &VirtualMachine,
281    ) -> PyResult<PyArithmeticValue<Complex64>> {
282        self.op(other, |a, b| Ok(a + b), vm)
283    }
284
285    #[pymethod(magic)]
286    fn sub(
287        &self,
288        other: PyObjectRef,
289        vm: &VirtualMachine,
290    ) -> PyResult<PyArithmeticValue<Complex64>> {
291        self.op(other, |a, b| Ok(a - b), vm)
292    }
293
294    #[pymethod(magic)]
295    fn rsub(
296        &self,
297        other: PyObjectRef,
298        vm: &VirtualMachine,
299    ) -> PyResult<PyArithmeticValue<Complex64>> {
300        self.op(other, |a, b| Ok(b - a), vm)
301    }
302
303    #[pymethod]
304    fn conjugate(&self) -> Complex64 {
305        self.value.conj()
306    }
307
308    #[pymethod(name = "__rmul__")]
309    #[pymethod(magic)]
310    fn mul(
311        &self,
312        other: PyObjectRef,
313        vm: &VirtualMachine,
314    ) -> PyResult<PyArithmeticValue<Complex64>> {
315        self.op(other, |a, b| Ok(a * b), vm)
316    }
317
318    #[pymethod(magic)]
319    fn truediv(
320        &self,
321        other: PyObjectRef,
322        vm: &VirtualMachine,
323    ) -> PyResult<PyArithmeticValue<Complex64>> {
324        self.op(other, |a, b| inner_div(a, b, vm), vm)
325    }
326
327    #[pymethod(magic)]
328    fn rtruediv(
329        &self,
330        other: PyObjectRef,
331        vm: &VirtualMachine,
332    ) -> PyResult<PyArithmeticValue<Complex64>> {
333        self.op(other, |a, b| inner_div(b, a, vm), vm)
334    }
335
336    #[pymethod(magic)]
337    fn pos(&self) -> Complex64 {
338        self.value
339    }
340
341    #[pymethod(magic)]
342    fn neg(&self) -> Complex64 {
343        -self.value
344    }
345
346    #[pymethod(magic)]
347    fn pow(
348        &self,
349        other: PyObjectRef,
350        mod_val: OptionalOption<PyObjectRef>,
351        vm: &VirtualMachine,
352    ) -> PyResult<PyArithmeticValue<Complex64>> {
353        if mod_val.flatten().is_some() {
354            Err(vm.new_value_error("complex modulo not allowed".to_owned()))
355        } else {
356            self.op(other, |a, b| inner_pow(a, b, vm), vm)
357        }
358    }
359
360    #[pymethod(magic)]
361    fn rpow(
362        &self,
363        other: PyObjectRef,
364        vm: &VirtualMachine,
365    ) -> PyResult<PyArithmeticValue<Complex64>> {
366        self.op(other, |a, b| inner_pow(b, a, vm), vm)
367    }
368
369    #[pymethod(magic)]
370    fn bool(&self) -> bool {
371        !Complex64::is_zero(&self.value)
372    }
373
374    #[pymethod(magic)]
375    fn getnewargs(&self) -> (f64, f64) {
376        let Complex64 { re, im } = self.value;
377        (re, im)
378    }
379}
380
381#[pyclass]
382impl PyRef<PyComplex> {
383    #[pymethod(magic)]
384    fn complex(self, vm: &VirtualMachine) -> PyRef<PyComplex> {
385        if self.is(vm.ctx.types.complex_type) {
386            self
387        } else {
388            PyComplex::from(self.value).into_ref(&vm.ctx)
389        }
390    }
391}
392
393impl Comparable for PyComplex {
394    fn cmp(
395        zelf: &Py<Self>,
396        other: &PyObject,
397        op: PyComparisonOp,
398        vm: &VirtualMachine,
399    ) -> PyResult<PyComparisonValue> {
400        op.eq_only(|| {
401            let result = if let Some(other) = other.payload_if_subclass::<PyComplex>(vm) {
402                if zelf.value.re.is_nan()
403                    && zelf.value.im.is_nan()
404                    && other.value.re.is_nan()
405                    && other.value.im.is_nan()
406                {
407                    true
408                } else {
409                    zelf.value == other.value
410                }
411            } else {
412                match float::to_op_float(other, vm) {
413                    Ok(Some(other)) => zelf.value == other.into(),
414                    Err(_) => false,
415                    Ok(None) => return Ok(PyComparisonValue::NotImplemented),
416                }
417            };
418            Ok(PyComparisonValue::Implemented(result))
419        })
420    }
421}
422
423impl Hashable for PyComplex {
424    #[inline]
425    fn hash(zelf: &Py<Self>, _vm: &VirtualMachine) -> PyResult<hash::PyHash> {
426        let value = zelf.value;
427
428        let re_hash =
429            hash::hash_float(value.re).unwrap_or_else(|| hash::hash_object_id(zelf.get_id()));
430
431        let im_hash =
432            hash::hash_float(value.im).unwrap_or_else(|| hash::hash_object_id(zelf.get_id()));
433
434        let Wrapping(ret) = Wrapping(re_hash) + Wrapping(im_hash) * Wrapping(hash::IMAG);
435        Ok(hash::fix_sentinel(ret))
436    }
437}
438
439impl AsNumber for PyComplex {
440    fn as_number() -> &'static PyNumberMethods {
441        static AS_NUMBER: PyNumberMethods = PyNumberMethods {
442            add: Some(|a, b, vm| PyComplex::number_op(a, b, |a, b, _vm| a + b, vm)),
443            subtract: Some(|a, b, vm| PyComplex::number_op(a, b, |a, b, _vm| a - b, vm)),
444            multiply: Some(|a, b, vm| PyComplex::number_op(a, b, |a, b, _vm| a * b, vm)),
445            power: Some(|a, b, c, vm| {
446                if vm.is_none(c) {
447                    PyComplex::number_op(a, b, inner_pow, vm)
448                } else {
449                    Err(vm.new_value_error(String::from("complex modulo")))
450                }
451            }),
452            negative: Some(|number, vm| {
453                let value = PyComplex::number_downcast(number).value;
454                (-value).to_pyresult(vm)
455            }),
456            positive: Some(|number, vm| {
457                PyComplex::number_downcast_exact(number, vm).to_pyresult(vm)
458            }),
459            absolute: Some(|number, vm| {
460                let value = PyComplex::number_downcast(number).value;
461                value.norm().to_pyresult(vm)
462            }),
463            boolean: Some(|number, _vm| Ok(PyComplex::number_downcast(number).value.is_zero())),
464            true_divide: Some(|a, b, vm| PyComplex::number_op(a, b, inner_div, vm)),
465            ..PyNumberMethods::NOT_IMPLEMENTED
466        };
467        &AS_NUMBER
468    }
469
470    fn clone_exact(zelf: &Py<Self>, vm: &VirtualMachine) -> PyRef<Self> {
471        vm.ctx.new_complex(zelf.value)
472    }
473}
474
475impl Representable for PyComplex {
476    #[inline]
477    fn repr_str(zelf: &Py<Self>, _vm: &VirtualMachine) -> PyResult<String> {
478        // TODO: when you fix this, move it to rustpython_common::complex::repr and update
479        //       ast/src/unparse.rs + impl Display for Constant in ast/src/constant.rs
480        let Complex64 { re, im } = zelf.value;
481        // integer => drop ., fractional => float_ops
482        let mut im_part = if im.fract() == 0.0 {
483            im.to_string()
484        } else {
485            crate::literal::float::to_string(im)
486        };
487        im_part.push('j');
488
489        // positive empty => return im_part, integer => drop ., fractional => float_ops
490        let re_part = if re == 0.0 {
491            if re.is_sign_positive() {
492                return Ok(im_part);
493            } else {
494                re.to_string()
495            }
496        } else if re.fract() == 0.0 {
497            re.to_string()
498        } else {
499            crate::literal::float::to_string(re)
500        };
501        let mut result = String::with_capacity(
502            re_part.len() + im_part.len() + 2 + im.is_sign_positive() as usize,
503        );
504        result.push('(');
505        result.push_str(&re_part);
506        if im.is_sign_positive() || im.is_nan() {
507            result.push('+');
508        }
509        result.push_str(&im_part);
510        result.push(')');
511        Ok(result)
512    }
513}
514
515impl PyComplex {
516    fn number_op<F, R>(a: &PyObject, b: &PyObject, op: F, vm: &VirtualMachine) -> PyResult
517    where
518        F: FnOnce(Complex64, Complex64, &VirtualMachine) -> R,
519        R: ToPyResult,
520    {
521        if let (Some(a), Some(b)) = (to_op_complex(a, vm)?, to_op_complex(b, vm)?) {
522            op(a, b, vm).to_pyresult(vm)
523        } else {
524            Ok(vm.ctx.not_implemented())
525        }
526    }
527}
528
529#[derive(FromArgs)]
530pub struct ComplexArgs {
531    #[pyarg(any, optional)]
532    real: OptionalArg<PyObjectRef>,
533    #[pyarg(any, optional)]
534    imag: OptionalArg<PyObjectRef>,
535}
536
537fn parse_str(s: &str) -> Option<Complex64> {
538    // Handle parentheses
539    let s = match s.strip_prefix('(') {
540        None => s,
541        Some(s) => match s.strip_suffix(')') {
542            None => return None,
543            Some(s) => s.trim(),
544        },
545    };
546
547    let value = match s.strip_suffix(|c| c == 'j' || c == 'J') {
548        None => Complex64::new(crate::literal::float::parse_str(s)?, 0.0),
549        Some(mut s) => {
550            let mut real = 0.0;
551            // Find the central +/- operator. If it exists, parse the real part.
552            for (i, w) in s.as_bytes().windows(2).enumerate() {
553                if (w[1] == b'+' || w[1] == b'-') && !(w[0] == b'e' || w[0] == b'E') {
554                    real = crate::literal::float::parse_str(&s[..=i])?;
555                    s = &s[i + 1..];
556                    break;
557                }
558            }
559
560            let imag = match s {
561                // "j", "+j"
562                "" | "+" => 1.0,
563                // "-j"
564                "-" => -1.0,
565                s => crate::literal::float::parse_str(s)?,
566            };
567
568            Complex64::new(real, imag)
569        }
570    };
571    Some(value)
572}
Morty Proxy This is a proxified and sanitized view of the page, visit original site.