rustpython_vm/builtins/
dict.rs

1use super::{
2    set::PySetInner, IterStatus, PositionIterInternal, PyBaseExceptionRef, PyGenericAlias,
3    PyMappingProxy, PySet, PyStr, PyStrRef, PyTupleRef, PyType, PyTypeRef,
4};
5use crate::{
6    atomic_func,
7    builtins::{
8        iter::{builtins_iter, builtins_reversed},
9        type_::PyAttributes,
10        PyTuple,
11    },
12    class::{PyClassDef, PyClassImpl},
13    common::ascii,
14    dictdatatype::{self, DictKey},
15    function::{ArgIterable, KwArgs, OptionalArg, PyArithmeticValue::*, PyComparisonValue},
16    iter::PyExactSizeIterator,
17    protocol::{PyIterIter, PyIterReturn, PyMappingMethods, PyNumberMethods, PySequenceMethods},
18    recursion::ReprGuard,
19    types::{
20        AsMapping, AsNumber, AsSequence, Callable, Comparable, Constructor, DefaultConstructor,
21        Initializer, IterNext, Iterable, PyComparisonOp, Representable, SelfIter, Unconstructible,
22    },
23    vm::VirtualMachine,
24    AsObject, Context, Py, PyObject, PyObjectRef, PyPayload, PyRef, PyRefExact, PyResult,
25    TryFromObject,
26};
27use once_cell::sync::Lazy;
28use rustpython_common::lock::PyMutex;
29use std::fmt;
30
31pub type DictContentType = dictdatatype::Dict;
32
33#[pyclass(module = false, name = "dict", unhashable = true, traverse)]
34#[derive(Default)]
35pub struct PyDict {
36    entries: DictContentType,
37}
38pub type PyDictRef = PyRef<PyDict>;
39
40impl fmt::Debug for PyDict {
41    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
42        // TODO: implement more detailed, non-recursive Debug formatter
43        f.write_str("dict")
44    }
45}
46
47impl PyPayload for PyDict {
48    fn class(ctx: &Context) -> &'static Py<PyType> {
49        ctx.types.dict_type
50    }
51}
52
53impl PyDict {
54    pub fn new_ref(ctx: &Context) -> PyRef<Self> {
55        PyRef::new_ref(Self::default(), ctx.types.dict_type.to_owned(), None)
56    }
57
58    /// escape hatch to access the underlying data structure directly. prefer adding a method on
59    /// PyDict instead of using this
60    pub(crate) fn _as_dict_inner(&self) -> &DictContentType {
61        &self.entries
62    }
63
64    // Used in update and ior.
65    pub(crate) fn merge_object(&self, other: PyObjectRef, vm: &VirtualMachine) -> PyResult<()> {
66        let casted: Result<PyRefExact<PyDict>, _> = other.downcast_exact(vm);
67        let other = match casted {
68            Ok(dict_other) => return self.merge_dict(dict_other.into_pyref(), vm),
69            Err(other) => other,
70        };
71        let dict = &self.entries;
72        if let Some(keys) = vm.get_method(other.clone(), vm.ctx.intern_str("keys")) {
73            let keys = keys?.call((), vm)?.get_iter(vm)?;
74            while let PyIterReturn::Return(key) = keys.next(vm)? {
75                let val = other.get_item(&*key, vm)?;
76                dict.insert(vm, &*key, val)?;
77            }
78        } else {
79            let iter = other.get_iter(vm)?;
80            loop {
81                fn err(vm: &VirtualMachine) -> PyBaseExceptionRef {
82                    vm.new_value_error("Iterator must have exactly two elements".to_owned())
83                }
84                let element = match iter.next(vm)? {
85                    PyIterReturn::Return(obj) => obj,
86                    PyIterReturn::StopIteration(_) => break,
87                };
88                let elem_iter = element.get_iter(vm)?;
89                let key = elem_iter.next(vm)?.into_result().map_err(|_| err(vm))?;
90                let value = elem_iter.next(vm)?.into_result().map_err(|_| err(vm))?;
91                if matches!(elem_iter.next(vm)?, PyIterReturn::Return(_)) {
92                    return Err(err(vm));
93                }
94                dict.insert(vm, &*key, value)?;
95            }
96        }
97        Ok(())
98    }
99
100    fn merge_dict(&self, dict_other: PyDictRef, vm: &VirtualMachine) -> PyResult<()> {
101        let dict = &self.entries;
102        let dict_size = &dict_other.size();
103        for (key, value) in &dict_other {
104            dict.insert(vm, &*key, value)?;
105        }
106        if dict_other.entries.has_changed_size(dict_size) {
107            return Err(vm.new_runtime_error("dict mutated during update".to_owned()));
108        }
109        Ok(())
110    }
111
112    pub fn is_empty(&self) -> bool {
113        self.entries.len() == 0
114    }
115
116    /// Set item variant which can be called with multiple
117    /// key types, such as str to name a notable one.
118    pub(crate) fn inner_setitem<K: DictKey + ?Sized>(
119        &self,
120        key: &K,
121        value: PyObjectRef,
122        vm: &VirtualMachine,
123    ) -> PyResult<()> {
124        self.entries.insert(vm, key, value)
125    }
126
127    pub(crate) fn inner_delitem<K: DictKey + ?Sized>(
128        &self,
129        key: &K,
130        vm: &VirtualMachine,
131    ) -> PyResult<()> {
132        self.entries.delete(vm, key)
133    }
134
135    pub fn get_or_insert(
136        &self,
137        vm: &VirtualMachine,
138        key: PyObjectRef,
139        default: impl FnOnce() -> PyObjectRef,
140    ) -> PyResult {
141        self.entries.setdefault(vm, &*key, default)
142    }
143
144    pub fn from_attributes(attrs: PyAttributes, vm: &VirtualMachine) -> PyResult<Self> {
145        let entries = DictContentType::default();
146
147        for (key, value) in attrs {
148            entries.insert(vm, key, value)?;
149        }
150
151        Ok(Self { entries })
152    }
153
154    pub fn contains_key<K: DictKey + ?Sized>(&self, key: &K, vm: &VirtualMachine) -> bool {
155        self.entries.contains(vm, key).unwrap()
156    }
157
158    pub fn size(&self) -> dictdatatype::DictSize {
159        self.entries.size()
160    }
161}
162
163// Python dict methods:
164#[allow(clippy::len_without_is_empty)]
165#[pyclass(
166    with(
167        Py,
168        PyRef,
169        Constructor,
170        Initializer,
171        Comparable,
172        Iterable,
173        AsSequence,
174        AsNumber,
175        AsMapping,
176        Representable
177    ),
178    flags(BASETYPE)
179)]
180impl PyDict {
181    #[pyclassmethod]
182    fn fromkeys(
183        class: PyTypeRef,
184        iterable: ArgIterable,
185        value: OptionalArg<PyObjectRef>,
186        vm: &VirtualMachine,
187    ) -> PyResult {
188        let value = value.unwrap_or_none(vm);
189        let d = PyType::call(&class, ().into(), vm)?;
190        match d.downcast_exact::<PyDict>(vm) {
191            Ok(pydict) => {
192                for key in iterable.iter(vm)? {
193                    pydict.setitem(key?, value.clone(), vm)?;
194                }
195                Ok(pydict.into_pyref().into())
196            }
197            Err(pyobj) => {
198                for key in iterable.iter(vm)? {
199                    pyobj.set_item(&*key?, value.clone(), vm)?;
200                }
201                Ok(pyobj)
202            }
203        }
204    }
205
206    #[pymethod(magic)]
207    pub fn len(&self) -> usize {
208        self.entries.len()
209    }
210
211    #[pymethod(magic)]
212    fn sizeof(&self) -> usize {
213        std::mem::size_of::<Self>() + self.entries.sizeof()
214    }
215
216    #[pymethod(magic)]
217    fn contains(&self, key: PyObjectRef, vm: &VirtualMachine) -> PyResult<bool> {
218        self.entries.contains(vm, &*key)
219    }
220
221    #[pymethod(magic)]
222    fn delitem(&self, key: PyObjectRef, vm: &VirtualMachine) -> PyResult<()> {
223        self.inner_delitem(&*key, vm)
224    }
225
226    #[pymethod]
227    pub fn clear(&self) {
228        self.entries.clear()
229    }
230
231    #[pymethod(magic)]
232    fn setitem(&self, key: PyObjectRef, value: PyObjectRef, vm: &VirtualMachine) -> PyResult<()> {
233        self.inner_setitem(&*key, value, vm)
234    }
235
236    #[pymethod]
237    fn get(
238        &self,
239        key: PyObjectRef,
240        default: OptionalArg<PyObjectRef>,
241        vm: &VirtualMachine,
242    ) -> PyResult {
243        match self.entries.get(vm, &*key)? {
244            Some(value) => Ok(value),
245            None => Ok(default.unwrap_or_none(vm)),
246        }
247    }
248
249    #[pymethod]
250    fn setdefault(
251        &self,
252        key: PyObjectRef,
253        default: OptionalArg<PyObjectRef>,
254        vm: &VirtualMachine,
255    ) -> PyResult {
256        self.entries
257            .setdefault(vm, &*key, || default.unwrap_or_none(vm))
258    }
259
260    #[pymethod]
261    pub fn copy(&self) -> PyDict {
262        PyDict {
263            entries: self.entries.clone(),
264        }
265    }
266
267    #[pymethod]
268    fn update(
269        &self,
270        dict_obj: OptionalArg<PyObjectRef>,
271        kwargs: KwArgs,
272        vm: &VirtualMachine,
273    ) -> PyResult<()> {
274        if let OptionalArg::Present(dict_obj) = dict_obj {
275            self.merge_object(dict_obj, vm)?;
276        }
277        for (key, value) in kwargs.into_iter() {
278            self.entries.insert(vm, &key, value)?;
279        }
280        Ok(())
281    }
282
283    #[pymethod(magic)]
284    fn or(&self, other: PyObjectRef, vm: &VirtualMachine) -> PyResult {
285        let dicted: Result<PyDictRef, _> = other.downcast();
286        if let Ok(other) = dicted {
287            let self_cp = self.copy();
288            self_cp.merge_dict(other, vm)?;
289            return Ok(self_cp.into_pyobject(vm));
290        }
291        Ok(vm.ctx.not_implemented())
292    }
293
294    #[pymethod]
295    fn pop(
296        &self,
297        key: PyObjectRef,
298        default: OptionalArg<PyObjectRef>,
299        vm: &VirtualMachine,
300    ) -> PyResult {
301        match self.entries.pop(vm, &*key)? {
302            Some(value) => Ok(value),
303            None => default.ok_or_else(|| vm.new_key_error(key)),
304        }
305    }
306
307    #[pymethod]
308    fn popitem(&self, vm: &VirtualMachine) -> PyResult<(PyObjectRef, PyObjectRef)> {
309        let (key, value) = self.entries.pop_back().ok_or_else(|| {
310            let err_msg = vm
311                .ctx
312                .new_str(ascii!("popitem(): dictionary is empty"))
313                .into();
314            vm.new_key_error(err_msg)
315        })?;
316        Ok((key, value))
317    }
318
319    #[pyclassmethod(magic)]
320    fn class_getitem(cls: PyTypeRef, args: PyObjectRef, vm: &VirtualMachine) -> PyGenericAlias {
321        PyGenericAlias::new(cls, args, vm)
322    }
323}
324
325#[pyclass]
326impl Py<PyDict> {
327    fn inner_cmp(
328        &self,
329        other: &Py<PyDict>,
330        op: PyComparisonOp,
331        item: bool,
332        vm: &VirtualMachine,
333    ) -> PyResult<PyComparisonValue> {
334        if op == PyComparisonOp::Ne {
335            return Self::inner_cmp(self, other, PyComparisonOp::Eq, item, vm)
336                .map(|x| x.map(|eq| !eq));
337        }
338        if !op.eval_ord(self.len().cmp(&other.len())) {
339            return Ok(Implemented(false));
340        }
341        let (superset, subset) = if self.len() < other.len() {
342            (other, self)
343        } else {
344            (self, other)
345        };
346        for (k, v1) in subset {
347            match superset.get_item_opt(&*k, vm)? {
348                Some(v2) => {
349                    if v1.is(&v2) {
350                        continue;
351                    }
352                    if item && !vm.bool_eq(&v1, &v2)? {
353                        return Ok(Implemented(false));
354                    }
355                }
356                None => {
357                    return Ok(Implemented(false));
358                }
359            }
360        }
361        Ok(Implemented(true))
362    }
363
364    #[pymethod(magic)]
365    #[cfg_attr(feature = "flame-it", flame("PyDictRef"))]
366    fn getitem(&self, key: PyObjectRef, vm: &VirtualMachine) -> PyResult {
367        self.inner_getitem(&*key, vm)
368    }
369}
370
371#[pyclass]
372impl PyRef<PyDict> {
373    #[pymethod]
374    fn keys(self) -> PyDictKeys {
375        PyDictKeys::new(self)
376    }
377
378    #[pymethod]
379    fn values(self) -> PyDictValues {
380        PyDictValues::new(self)
381    }
382
383    #[pymethod]
384    fn items(self) -> PyDictItems {
385        PyDictItems::new(self)
386    }
387
388    #[pymethod(magic)]
389    fn reversed(self) -> PyDictReverseKeyIterator {
390        PyDictReverseKeyIterator::new(self)
391    }
392
393    #[pymethod(magic)]
394    fn ior(self, other: PyObjectRef, vm: &VirtualMachine) -> PyResult<Self> {
395        self.merge_object(other, vm)?;
396        Ok(self)
397    }
398
399    #[pymethod(magic)]
400    fn ror(self, other: PyObjectRef, vm: &VirtualMachine) -> PyResult {
401        let dicted: Result<PyDictRef, _> = other.downcast();
402        if let Ok(other) = dicted {
403            let other_cp = other.copy();
404            other_cp.merge_dict(self, vm)?;
405            return Ok(other_cp.into_pyobject(vm));
406        }
407        Ok(vm.ctx.not_implemented())
408    }
409}
410
411impl DefaultConstructor for PyDict {}
412
413impl Initializer for PyDict {
414    type Args = (OptionalArg<PyObjectRef>, KwArgs);
415
416    fn init(
417        zelf: PyRef<Self>,
418        (dict_obj, kwargs): Self::Args,
419        vm: &VirtualMachine,
420    ) -> PyResult<()> {
421        zelf.update(dict_obj, kwargs, vm)
422    }
423}
424
425impl AsMapping for PyDict {
426    fn as_mapping() -> &'static PyMappingMethods {
427        static AS_MAPPING: PyMappingMethods = PyMappingMethods {
428            length: atomic_func!(|mapping, _vm| Ok(PyDict::mapping_downcast(mapping).len())),
429            subscript: atomic_func!(|mapping, needle, vm| {
430                PyDict::mapping_downcast(mapping).inner_getitem(needle, vm)
431            }),
432            ass_subscript: atomic_func!(|mapping, needle, value, vm| {
433                let zelf = PyDict::mapping_downcast(mapping);
434                if let Some(value) = value {
435                    zelf.inner_setitem(needle, value, vm)
436                } else {
437                    zelf.inner_delitem(needle, vm)
438                }
439            }),
440        };
441        &AS_MAPPING
442    }
443}
444
445impl AsSequence for PyDict {
446    fn as_sequence() -> &'static PySequenceMethods {
447        static AS_SEQUENCE: Lazy<PySequenceMethods> = Lazy::new(|| PySequenceMethods {
448            contains: atomic_func!(|seq, target, vm| PyDict::sequence_downcast(seq)
449                .entries
450                .contains(vm, target)),
451            ..PySequenceMethods::NOT_IMPLEMENTED
452        });
453        &AS_SEQUENCE
454    }
455}
456
457impl AsNumber for PyDict {
458    fn as_number() -> &'static PyNumberMethods {
459        static AS_NUMBER: PyNumberMethods = PyNumberMethods {
460            or: Some(|a, b, vm| {
461                if let Some(a) = a.downcast_ref::<PyDict>() {
462                    PyDict::or(a, b.to_pyobject(vm), vm)
463                } else {
464                    Ok(vm.ctx.not_implemented())
465                }
466            }),
467            inplace_or: Some(|a, b, vm| {
468                if let Some(a) = a.downcast_ref::<PyDict>() {
469                    a.to_owned().ior(b.to_pyobject(vm), vm).map(|d| d.into())
470                } else {
471                    Ok(vm.ctx.not_implemented())
472                }
473            }),
474            ..PyNumberMethods::NOT_IMPLEMENTED
475        };
476        &AS_NUMBER
477    }
478}
479
480impl Comparable for PyDict {
481    fn cmp(
482        zelf: &Py<Self>,
483        other: &PyObject,
484        op: PyComparisonOp,
485        vm: &VirtualMachine,
486    ) -> PyResult<PyComparisonValue> {
487        op.eq_only(|| {
488            let other = class_or_notimplemented!(Self, other);
489            zelf.inner_cmp(other, PyComparisonOp::Eq, true, vm)
490        })
491    }
492}
493
494impl Iterable for PyDict {
495    fn iter(zelf: PyRef<Self>, vm: &VirtualMachine) -> PyResult {
496        Ok(PyDictKeyIterator::new(zelf).into_pyobject(vm))
497    }
498}
499
500impl Representable for PyDict {
501    #[inline]
502    fn repr(zelf: &Py<Self>, vm: &VirtualMachine) -> PyResult<PyStrRef> {
503        let s = if let Some(_guard) = ReprGuard::enter(vm, zelf.as_object()) {
504            let mut str_parts = Vec::with_capacity(zelf.len());
505            for (key, value) in zelf {
506                let key_repr = &key.repr(vm)?;
507                let value_repr = value.repr(vm)?;
508                str_parts.push(format!("{key_repr}: {value_repr}"));
509            }
510
511            vm.ctx.new_str(format!("{{{}}}", str_parts.join(", ")))
512        } else {
513            vm.ctx.intern_str("{...}").to_owned()
514        };
515        Ok(s)
516    }
517
518    #[cold]
519    fn repr_str(_zelf: &Py<Self>, _vm: &VirtualMachine) -> PyResult<String> {
520        unreachable!("use repr instead")
521    }
522}
523
524impl Py<PyDict> {
525    #[inline]
526    fn exact_dict(&self, vm: &VirtualMachine) -> bool {
527        self.class().is(vm.ctx.types.dict_type)
528    }
529
530    fn missing_opt<K: DictKey + ?Sized>(
531        &self,
532        key: &K,
533        vm: &VirtualMachine,
534    ) -> PyResult<Option<PyObjectRef>> {
535        vm.get_method(self.to_owned().into(), identifier!(vm, __missing__))
536            .map(|methods| methods?.call((key.to_pyobject(vm),), vm))
537            .transpose()
538    }
539
540    #[inline]
541    fn inner_getitem<K: DictKey + ?Sized>(
542        &self,
543        key: &K,
544        vm: &VirtualMachine,
545    ) -> PyResult<PyObjectRef> {
546        if let Some(value) = self.entries.get(vm, key)? {
547            Ok(value)
548        } else if let Some(value) = self.missing_opt(key, vm)? {
549            Ok(value)
550        } else {
551            Err(vm.new_key_error(key.to_pyobject(vm)))
552        }
553    }
554
555    /// Take a python dictionary and convert it to attributes.
556    pub fn to_attributes(&self, vm: &VirtualMachine) -> PyAttributes {
557        let mut attrs = PyAttributes::default();
558        for (key, value) in self {
559            let key: PyRefExact<PyStr> = key.downcast_exact(vm).expect("dict has non-string keys");
560            attrs.insert(vm.ctx.intern_str(key), value);
561        }
562        attrs
563    }
564
565    pub fn get_item_opt<K: DictKey + ?Sized>(
566        &self,
567        key: &K,
568        vm: &VirtualMachine,
569    ) -> PyResult<Option<PyObjectRef>> {
570        if self.exact_dict(vm) {
571            self.entries.get(vm, key)
572            // FIXME: check __missing__?
573        } else {
574            match self.as_object().get_item(key, vm) {
575                Ok(value) => Ok(Some(value)),
576                Err(e) if e.fast_isinstance(vm.ctx.exceptions.key_error) => {
577                    self.missing_opt(key, vm)
578                }
579                Err(e) => Err(e),
580            }
581        }
582    }
583
584    pub fn get_item<K: DictKey + ?Sized>(&self, key: &K, vm: &VirtualMachine) -> PyResult {
585        if self.exact_dict(vm) {
586            self.inner_getitem(key, vm)
587        } else {
588            self.as_object().get_item(key, vm)
589        }
590    }
591
592    pub fn set_item<K: DictKey + ?Sized>(
593        &self,
594        key: &K,
595        value: PyObjectRef,
596        vm: &VirtualMachine,
597    ) -> PyResult<()> {
598        if self.exact_dict(vm) {
599            self.inner_setitem(key, value, vm)
600        } else {
601            self.as_object().set_item(key, value, vm)
602        }
603    }
604
605    pub fn del_item<K: DictKey + ?Sized>(&self, key: &K, vm: &VirtualMachine) -> PyResult<()> {
606        if self.exact_dict(vm) {
607            self.inner_delitem(key, vm)
608        } else {
609            self.as_object().del_item(key, vm)
610        }
611    }
612
613    pub fn get_chain<K: DictKey + ?Sized>(
614        &self,
615        other: &Self,
616        key: &K,
617        vm: &VirtualMachine,
618    ) -> PyResult<Option<PyObjectRef>> {
619        let self_exact = self.exact_dict(vm);
620        let other_exact = other.exact_dict(vm);
621        if self_exact && other_exact {
622            self.entries.get_chain(&other.entries, vm, key)
623        } else if let Some(value) = self.get_item_opt(key, vm)? {
624            Ok(Some(value))
625        } else {
626            other.get_item_opt(key, vm)
627        }
628    }
629}
630
631// Implement IntoIterator so that we can easily iterate dictionaries from rust code.
632impl IntoIterator for PyDictRef {
633    type Item = (PyObjectRef, PyObjectRef);
634    type IntoIter = DictIntoIter;
635
636    fn into_iter(self) -> Self::IntoIter {
637        DictIntoIter::new(self)
638    }
639}
640
641impl<'a> IntoIterator for &'a PyDictRef {
642    type Item = (PyObjectRef, PyObjectRef);
643    type IntoIter = DictIter<'a>;
644
645    fn into_iter(self) -> Self::IntoIter {
646        DictIter::new(self)
647    }
648}
649
650impl<'a> IntoIterator for &'a Py<PyDict> {
651    type Item = (PyObjectRef, PyObjectRef);
652    type IntoIter = DictIter<'a>;
653
654    fn into_iter(self) -> Self::IntoIter {
655        DictIter::new(self)
656    }
657}
658
659impl<'a> IntoIterator for &'a PyDict {
660    type Item = (PyObjectRef, PyObjectRef);
661    type IntoIter = DictIter<'a>;
662
663    fn into_iter(self) -> Self::IntoIter {
664        DictIter::new(self)
665    }
666}
667
668pub struct DictIntoIter {
669    dict: PyDictRef,
670    position: usize,
671}
672
673impl DictIntoIter {
674    pub fn new(dict: PyDictRef) -> DictIntoIter {
675        DictIntoIter { dict, position: 0 }
676    }
677}
678
679impl Iterator for DictIntoIter {
680    type Item = (PyObjectRef, PyObjectRef);
681
682    fn next(&mut self) -> Option<Self::Item> {
683        let (position, key, value) = self.dict.entries.next_entry(self.position)?;
684        self.position = position;
685        Some((key, value))
686    }
687
688    fn size_hint(&self) -> (usize, Option<usize>) {
689        let l = self.len();
690        (l, Some(l))
691    }
692}
693impl ExactSizeIterator for DictIntoIter {
694    fn len(&self) -> usize {
695        self.dict.entries.len_from_entry_index(self.position)
696    }
697}
698
699pub struct DictIter<'a> {
700    dict: &'a PyDict,
701    position: usize,
702}
703
704impl<'a> DictIter<'a> {
705    pub fn new(dict: &'a PyDict) -> Self {
706        DictIter { dict, position: 0 }
707    }
708}
709
710impl Iterator for DictIter<'_> {
711    type Item = (PyObjectRef, PyObjectRef);
712
713    fn next(&mut self) -> Option<Self::Item> {
714        let (position, key, value) = self.dict.entries.next_entry(self.position)?;
715        self.position = position;
716        Some((key, value))
717    }
718
719    fn size_hint(&self) -> (usize, Option<usize>) {
720        let l = self.len();
721        (l, Some(l))
722    }
723}
724impl ExactSizeIterator for DictIter<'_> {
725    fn len(&self) -> usize {
726        self.dict.entries.len_from_entry_index(self.position)
727    }
728}
729
730#[pyclass]
731trait DictView: PyPayload + PyClassDef + Iterable + Representable
732where
733    Self::ReverseIter: PyPayload,
734{
735    type ReverseIter;
736
737    fn dict(&self) -> &PyDictRef;
738    fn item(vm: &VirtualMachine, key: PyObjectRef, value: PyObjectRef) -> PyObjectRef;
739
740    #[pymethod(magic)]
741    fn len(&self) -> usize {
742        self.dict().len()
743    }
744
745    #[pymethod(magic)]
746    fn reversed(&self) -> Self::ReverseIter;
747}
748
// Generates one family of dict view types: the view itself, its forward iterator and
// its reverse iterator. `$result_fn` maps `(vm, key, value)` of an entry to the object
// the family exposes for that entry.
macro_rules! dict_view {
    ( $name: ident, $iter_name: ident, $reverse_iter_name: ident,
      $class: ident, $iter_class: ident, $reverse_iter_class: ident,
      $class_name: literal, $iter_class_name: literal, $reverse_iter_class_name: literal,
      $result_fn: expr) => {
        // ---- view object -------------------------------------------------------------
        #[pyclass(module = false, name = $class_name)]
        #[derive(Debug)]
        pub(crate) struct $name {
            pub dict: PyDictRef,
        }

        impl $name {
            pub fn new(dict: PyDictRef) -> Self {
                Self { dict }
            }
        }

        impl DictView for $name {
            type ReverseIter = $reverse_iter_name;

            fn dict(&self) -> &PyDictRef {
                &self.dict
            }

            #[allow(clippy::redundant_closure_call)]
            fn item(vm: &VirtualMachine, key: PyObjectRef, value: PyObjectRef) -> PyObjectRef {
                ($result_fn)(vm, key, value)
            }

            fn reversed(&self) -> Self::ReverseIter {
                let dict = self.dict.clone();
                $reverse_iter_name::new(dict)
            }
        }

        impl Iterable for $name {
            fn iter(zelf: PyRef<Self>, vm: &VirtualMachine) -> PyResult {
                let iterator = $iter_name::new(zelf.dict.clone());
                Ok(iterator.into_pyobject(vm))
            }
        }

        impl PyPayload for $name {
            fn class(ctx: &Context) -> &'static Py<PyType> {
                ctx.types.$class
            }
        }

        impl Representable for $name {
            // Renders `<NAME>([item, ...])`. When `ReprGuard::enter` yields no guard,
            // the placeholder `{...}` is returned instead.
            #[inline]
            fn repr(zelf: &Py<Self>, vm: &VirtualMachine) -> PyResult<PyStrRef> {
                match ReprGuard::enter(vm, zelf.as_object()) {
                    None => Ok(vm.ctx.intern_str("{...}").to_owned()),
                    Some(_guard) => {
                        let mut rendered = Vec::with_capacity(zelf.len());
                        for (key, value) in zelf.dict().clone() {
                            let item_repr = Self::item(vm, key, value).repr(vm)?;
                            rendered.push(item_repr.as_str().to_owned());
                        }
                        let body = rendered.join(", ");
                        Ok(vm.ctx.new_str(format!("{}([{}])", Self::NAME, body)))
                    }
                }
            }

            #[cold]
            fn repr_str(_zelf: &Py<Self>, _vm: &VirtualMachine) -> PyResult<String> {
                unreachable!("use repr instead")
            }
        }

        // ---- forward iterator ----------------------------------------------------------
        #[pyclass(module = false, name = $iter_class_name)]
        #[derive(Debug)]
        pub(crate) struct $iter_name {
            // Size of the dict sampled at construction; used to detect resizing.
            pub size: dictdatatype::DictSize,
            pub internal: PyMutex<PositionIterInternal<PyDictRef>>,
        }

        impl PyPayload for $iter_name {
            fn class(ctx: &Context) -> &'static Py<PyType> {
                ctx.types.$iter_class
            }
        }

        #[pyclass(with(Unconstructible, IterNext, Iterable))]
        impl $iter_name {
            fn new(dict: PyDictRef) -> Self {
                let size = dict.size();
                let internal = PyMutex::new(PositionIterInternal::new(dict, 0));
                Self { size, internal }
            }

            #[pymethod(magic)]
            fn length_hint(&self) -> usize {
                let internal = self.internal.lock();
                internal.length_hint(|_| self.size.entries_size)
            }

            // Returns `(iter, ([entries...],))`; the list is empty once exhausted.
            #[allow(clippy::redundant_closure_call)]
            #[pymethod(magic)]
            fn reduce(&self, vm: &VirtualMachine) -> PyTupleRef {
                let iter = builtins_iter(vm).to_owned();
                let internal = self.internal.lock();
                let entries: Vec<_> = match &internal.status {
                    IterStatus::Exhausted => Vec::new(),
                    IterStatus::Active(dict) => dict
                        .into_iter()
                        .map(|(key, value)| ($result_fn)(vm, key, value))
                        .collect(),
                };
                vm.new_tuple((iter, (vm.ctx.new_list(entries),)))
            }
        }
        impl Unconstructible for $iter_name {}

        impl SelfIter for $iter_name {}
        impl IterNext for $iter_name {
            #[allow(clippy::redundant_closure_call)]
            fn next(zelf: &Py<Self>, vm: &VirtualMachine) -> PyResult<PyIterReturn> {
                let mut internal = zelf.internal.lock();
                let dict = match &internal.status {
                    IterStatus::Active(dict) => dict,
                    IterStatus::Exhausted => return Ok(PyIterReturn::StopIteration(None)),
                };
                // A resized dict ends the iteration for good and raises.
                if dict.entries.has_changed_size(&zelf.size) {
                    internal.status = IterStatus::Exhausted;
                    return Err(vm.new_runtime_error(
                        "dictionary changed size during iteration".to_owned(),
                    ));
                }
                let step = dict.entries.next_entry(internal.position);
                let next = match step {
                    Some((position, key, value)) => {
                        internal.position = position;
                        PyIterReturn::Return(($result_fn)(vm, key, value))
                    }
                    None => {
                        internal.status = IterStatus::Exhausted;
                        PyIterReturn::StopIteration(None)
                    }
                };
                Ok(next)
            }
        }

        // ---- reverse iterator ----------------------------------------------------------
        #[pyclass(module = false, name = $reverse_iter_class_name)]
        #[derive(Debug)]
        pub(crate) struct $reverse_iter_name {
            // Size of the dict sampled at construction; used to detect resizing.
            pub size: dictdatatype::DictSize,
            internal: PyMutex<PositionIterInternal<PyDictRef>>,
        }

        impl PyPayload for $reverse_iter_name {
            fn class(ctx: &Context) -> &'static Py<PyType> {
                ctx.types.$reverse_iter_class
            }
        }

        #[pyclass(with(Unconstructible, IterNext, Iterable))]
        impl $reverse_iter_name {
            fn new(dict: PyDictRef) -> Self {
                let size = dict.size();
                // Start at `entries_size - 1`, clamped to 0 when `entries_size` is 0.
                let start = size.entries_size.saturating_sub(1);
                let internal = PyMutex::new(PositionIterInternal::new(dict, start));
                Self { size, internal }
            }

            // Returns `(reversed, ([entries...],))`; the list is empty once exhausted.
            #[allow(clippy::redundant_closure_call)]
            #[pymethod(magic)]
            fn reduce(&self, vm: &VirtualMachine) -> PyTupleRef {
                let iter = builtins_reversed(vm).to_owned();
                let internal = self.internal.lock();
                // TODO: entries must be reversed too
                let entries: Vec<_> = match &internal.status {
                    IterStatus::Exhausted => Vec::new(),
                    IterStatus::Active(dict) => dict
                        .into_iter()
                        .map(|(key, value)| ($result_fn)(vm, key, value))
                        .collect(),
                };
                vm.new_tuple((iter, (vm.ctx.new_list(entries),)))
            }

            #[pymethod(magic)]
            fn length_hint(&self) -> usize {
                let internal = self.internal.lock();
                internal.rev_length_hint(|_| self.size.entries_size)
            }
        }
        impl Unconstructible for $reverse_iter_name {}

        impl SelfIter for $reverse_iter_name {}
        impl IterNext for $reverse_iter_name {
            #[allow(clippy::redundant_closure_call)]
            fn next(zelf: &Py<Self>, vm: &VirtualMachine) -> PyResult<PyIterReturn> {
                let mut internal = zelf.internal.lock();
                let dict = match &internal.status {
                    IterStatus::Active(dict) => dict,
                    IterStatus::Exhausted => return Ok(PyIterReturn::StopIteration(None)),
                };
                // A resized dict ends the iteration for good and raises.
                if dict.entries.has_changed_size(&zelf.size) {
                    internal.status = IterStatus::Exhausted;
                    return Err(vm.new_runtime_error(
                        "dictionary changed size during iteration".to_owned(),
                    ));
                }
                let step = dict.entries.prev_entry(internal.position);
                let next = match step {
                    Some((position, key, value)) => {
                        // An unchanged position means this entry is the last one to yield.
                        if internal.position == position {
                            internal.status = IterStatus::Exhausted;
                        } else {
                            internal.position = position;
                        }
                        PyIterReturn::Return(($result_fn)(vm, key, value))
                    }
                    None => {
                        internal.status = IterStatus::Exhausted;
                        PyIterReturn::StopIteration(None)
                    }
                };
                Ok(next)
            }
        }
    };
}
971
972dict_view! {
973    PyDictKeys,
974    PyDictKeyIterator,
975    PyDictReverseKeyIterator,
976    dict_keys_type,
977    dict_keyiterator_type,
978    dict_reversekeyiterator_type,
979    "dict_keys",
980    "dict_keyiterator",
981    "dict_reversekeyiterator",
982    |_vm: &VirtualMachine, key: PyObjectRef, _value: PyObjectRef| key
983}
984
985dict_view! {
986    PyDictValues,
987    PyDictValueIterator,
988    PyDictReverseValueIterator,
989    dict_values_type,
990    dict_valueiterator_type,
991    dict_reversevalueiterator_type,
992    "dict_values",
993    "dict_valueiterator",
994    "dict_reversevalueiterator",
995    |_vm: &VirtualMachine, _key: PyObjectRef, value: PyObjectRef| value
996}
997
998dict_view! {
999    PyDictItems,
1000    PyDictItemIterator,
1001    PyDictReverseItemIterator,
1002    dict_items_type,
1003    dict_itemiterator_type,
1004    dict_reverseitemiterator_type,
1005    "dict_items",
1006    "dict_itemiterator",
1007    "dict_reverseitemiterator",
1008    |vm: &VirtualMachine, key: PyObjectRef, value: PyObjectRef|
1009        vm.new_tuple((key, value)).into()
1010}
1011
1012// Set operations defined on set-like views of the dictionary.
1013#[pyclass]
1014trait ViewSetOps: DictView {
1015    fn to_set(zelf: PyRef<Self>, vm: &VirtualMachine) -> PyResult<PySetInner> {
1016        let len = zelf.dict().len();
1017        let zelf: PyObjectRef = Self::iter(zelf, vm)?;
1018        let iter = PyIterIter::new(vm, zelf, Some(len));
1019        PySetInner::from_iter(iter, vm)
1020    }
1021
1022    #[pymethod(name = "__rxor__")]
1023    #[pymethod(magic)]
1024    fn xor(zelf: PyRef<Self>, other: ArgIterable, vm: &VirtualMachine) -> PyResult<PySet> {
1025        let zelf = Self::to_set(zelf, vm)?;
1026        let inner = zelf.symmetric_difference(other, vm)?;
1027        Ok(PySet { inner })
1028    }
1029
1030    #[pymethod(name = "__rand__")]
1031    #[pymethod(magic)]
1032    fn and(zelf: PyRef<Self>, other: ArgIterable, vm: &VirtualMachine) -> PyResult<PySet> {
1033        let zelf = Self::to_set(zelf, vm)?;
1034        let inner = zelf.intersection(other, vm)?;
1035        Ok(PySet { inner })
1036    }
1037
1038    #[pymethod(name = "__ror__")]
1039    #[pymethod(magic)]
1040    fn or(zelf: PyRef<Self>, other: ArgIterable, vm: &VirtualMachine) -> PyResult<PySet> {
1041        let zelf = Self::to_set(zelf, vm)?;
1042        let inner = zelf.union(other, vm)?;
1043        Ok(PySet { inner })
1044    }
1045
1046    #[pymethod(magic)]
1047    fn sub(zelf: PyRef<Self>, other: ArgIterable, vm: &VirtualMachine) -> PyResult<PySet> {
1048        let zelf = Self::to_set(zelf, vm)?;
1049        let inner = zelf.difference(other, vm)?;
1050        Ok(PySet { inner })
1051    }
1052
1053    #[pymethod(magic)]
1054    fn rsub(zelf: PyRef<Self>, other: ArgIterable, vm: &VirtualMachine) -> PyResult<PySet> {
1055        let left = PySetInner::from_iter(other.iter(vm)?, vm)?;
1056        let right = ArgIterable::try_from_object(vm, Self::iter(zelf, vm)?)?;
1057        let inner = left.difference(right, vm)?;
1058        Ok(PySet { inner })
1059    }
1060
1061    fn cmp(
1062        zelf: &Py<Self>,
1063        other: &PyObject,
1064        op: PyComparisonOp,
1065        vm: &VirtualMachine,
1066    ) -> PyResult<PyComparisonValue> {
1067        match_class!(match other {
1068            ref dictview @ Self => {
1069                return zelf.dict().inner_cmp(
1070                    dictview.dict(),
1071                    op,
1072                    !zelf.class().is(vm.ctx.types.dict_keys_type),
1073                    vm,
1074                );
1075            }
1076            ref _set @ PySet => {
1077                let inner = Self::to_set(zelf.to_owned(), vm)?;
1078                let zelf_set = PySet { inner }.into_pyobject(vm);
1079                return PySet::cmp(zelf_set.downcast_ref().unwrap(), other, op, vm);
1080            }
1081            ref _dictitems @ PyDictItems => {}
1082            ref _dictkeys @ PyDictKeys => {}
1083            _ => {
1084                return Ok(NotImplemented);
1085            }
1086        });
1087        let lhs: Vec<PyObjectRef> = zelf.as_object().to_owned().try_into_value(vm)?;
1088        let rhs: Vec<PyObjectRef> = other.to_owned().try_into_value(vm)?;
1089        lhs.iter()
1090            .richcompare(rhs.iter(), op, vm)
1091            .map(PyComparisonValue::Implemented)
1092    }
1093
1094    #[pymethod]
1095    fn isdisjoint(zelf: PyRef<Self>, other: ArgIterable, vm: &VirtualMachine) -> PyResult<bool> {
1096        // TODO: to_set is an expensive operation. After merging #3316 rewrite implementation using PySequence_Contains.
1097        let zelf = Self::to_set(zelf, vm)?;
1098        let result = zelf.isdisjoint(other, vm)?;
1099        Ok(result)
1100    }
1101}
1102
1103impl ViewSetOps for PyDictKeys {}
1104#[pyclass(with(
1105    DictView,
1106    Unconstructible,
1107    Comparable,
1108    Iterable,
1109    ViewSetOps,
1110    AsSequence,
1111    AsNumber,
1112    Representable
1113))]
1114impl PyDictKeys {
1115    #[pymethod(magic)]
1116    fn contains(zelf: PyObjectRef, key: PyObjectRef, vm: &VirtualMachine) -> PyResult<bool> {
1117        zelf.to_sequence().contains(&key, vm)
1118    }
1119
1120    #[pygetset]
1121    fn mapping(zelf: PyRef<Self>) -> PyMappingProxy {
1122        PyMappingProxy::from(zelf.dict().clone())
1123    }
1124}
1125impl Unconstructible for PyDictKeys {}
1126
1127impl Comparable for PyDictKeys {
1128    fn cmp(
1129        zelf: &Py<Self>,
1130        other: &PyObject,
1131        op: PyComparisonOp,
1132        vm: &VirtualMachine,
1133    ) -> PyResult<PyComparisonValue> {
1134        ViewSetOps::cmp(zelf, other, op, vm)
1135    }
1136}
1137
1138impl AsSequence for PyDictKeys {
1139    fn as_sequence() -> &'static PySequenceMethods {
1140        static AS_SEQUENCE: Lazy<PySequenceMethods> = Lazy::new(|| PySequenceMethods {
1141            length: atomic_func!(|seq, _vm| Ok(PyDictKeys::sequence_downcast(seq).len())),
1142            contains: atomic_func!(|seq, target, vm| {
1143                PyDictKeys::sequence_downcast(seq)
1144                    .dict
1145                    .entries
1146                    .contains(vm, target)
1147            }),
1148            ..PySequenceMethods::NOT_IMPLEMENTED
1149        });
1150        &AS_SEQUENCE
1151    }
1152}
1153
1154impl AsNumber for PyDictKeys {
1155    fn as_number() -> &'static PyNumberMethods {
1156        static AS_NUMBER: PyNumberMethods = PyNumberMethods {
1157            subtract: Some(set_inner_number_subtract),
1158            and: Some(set_inner_number_and),
1159            xor: Some(set_inner_number_xor),
1160            or: Some(set_inner_number_or),
1161            ..PyNumberMethods::NOT_IMPLEMENTED
1162        };
1163        &AS_NUMBER
1164    }
1165}
1166
1167impl ViewSetOps for PyDictItems {}
1168#[pyclass(with(
1169    DictView,
1170    Unconstructible,
1171    Comparable,
1172    Iterable,
1173    ViewSetOps,
1174    AsSequence,
1175    AsNumber,
1176    Representable
1177))]
1178impl PyDictItems {
1179    #[pymethod(magic)]
1180    fn contains(zelf: PyObjectRef, needle: PyObjectRef, vm: &VirtualMachine) -> PyResult<bool> {
1181        zelf.to_sequence().contains(&needle, vm)
1182    }
1183    #[pygetset]
1184    fn mapping(zelf: PyRef<Self>) -> PyMappingProxy {
1185        PyMappingProxy::from(zelf.dict().clone())
1186    }
1187}
1188impl Unconstructible for PyDictItems {}
1189
1190impl Comparable for PyDictItems {
1191    fn cmp(
1192        zelf: &Py<Self>,
1193        other: &PyObject,
1194        op: PyComparisonOp,
1195        vm: &VirtualMachine,
1196    ) -> PyResult<PyComparisonValue> {
1197        ViewSetOps::cmp(zelf, other, op, vm)
1198    }
1199}
1200
1201impl AsSequence for PyDictItems {
1202    fn as_sequence() -> &'static PySequenceMethods {
1203        static AS_SEQUENCE: Lazy<PySequenceMethods> = Lazy::new(|| PySequenceMethods {
1204            length: atomic_func!(|seq, _vm| Ok(PyDictItems::sequence_downcast(seq).len())),
1205            contains: atomic_func!(|seq, target, vm| {
1206                let needle: &Py<PyTuple> = match target.downcast_ref() {
1207                    Some(needle) => needle,
1208                    None => return Ok(false),
1209                };
1210                if needle.len() != 2 {
1211                    return Ok(false);
1212                }
1213
1214                let zelf = PyDictItems::sequence_downcast(seq);
1215                let key = needle.fast_getitem(0);
1216                if !zelf.dict.contains(key.clone(), vm)? {
1217                    return Ok(false);
1218                }
1219                let value = needle.fast_getitem(1);
1220                let found = zelf.dict().getitem(key, vm)?;
1221                vm.identical_or_equal(&found, &value)
1222            }),
1223            ..PySequenceMethods::NOT_IMPLEMENTED
1224        });
1225        &AS_SEQUENCE
1226    }
1227}
1228
1229impl AsNumber for PyDictItems {
1230    fn as_number() -> &'static PyNumberMethods {
1231        static AS_NUMBER: PyNumberMethods = PyNumberMethods {
1232            subtract: Some(set_inner_number_subtract),
1233            and: Some(set_inner_number_and),
1234            xor: Some(set_inner_number_xor),
1235            or: Some(set_inner_number_or),
1236            ..PyNumberMethods::NOT_IMPLEMENTED
1237        };
1238        &AS_NUMBER
1239    }
1240}
1241
1242#[pyclass(with(DictView, Unconstructible, Iterable, AsSequence, Representable))]
1243impl PyDictValues {
1244    #[pygetset]
1245    fn mapping(zelf: PyRef<Self>) -> PyMappingProxy {
1246        PyMappingProxy::from(zelf.dict().clone())
1247    }
1248}
1249impl Unconstructible for PyDictValues {}
1250
1251impl AsSequence for PyDictValues {
1252    fn as_sequence() -> &'static PySequenceMethods {
1253        static AS_SEQUENCE: Lazy<PySequenceMethods> = Lazy::new(|| PySequenceMethods {
1254            length: atomic_func!(|seq, _vm| Ok(PyDictValues::sequence_downcast(seq).len())),
1255            ..PySequenceMethods::NOT_IMPLEMENTED
1256        });
1257        &AS_SEQUENCE
1258    }
1259}
1260
1261fn set_inner_number_op<F>(a: &PyObject, b: &PyObject, f: F, vm: &VirtualMachine) -> PyResult
1262where
1263    F: FnOnce(PySetInner, ArgIterable) -> PyResult<PySetInner>,
1264{
1265    let a = PySetInner::from_iter(
1266        ArgIterable::try_from_object(vm, a.to_owned())?.iter(vm)?,
1267        vm,
1268    )?;
1269    let b = ArgIterable::try_from_object(vm, b.to_owned())?;
1270    Ok(PySet { inner: f(a, b)? }.into_pyobject(vm))
1271}
1272
1273fn set_inner_number_subtract(a: &PyObject, b: &PyObject, vm: &VirtualMachine) -> PyResult {
1274    set_inner_number_op(a, b, |a, b| a.difference(b, vm), vm)
1275}
1276
1277fn set_inner_number_and(a: &PyObject, b: &PyObject, vm: &VirtualMachine) -> PyResult {
1278    set_inner_number_op(a, b, |a, b| a.intersection(b, vm), vm)
1279}
1280
1281fn set_inner_number_xor(a: &PyObject, b: &PyObject, vm: &VirtualMachine) -> PyResult {
1282    set_inner_number_op(a, b, |a, b| a.symmetric_difference(b, vm), vm)
1283}
1284
1285fn set_inner_number_or(a: &PyObject, b: &PyObject, vm: &VirtualMachine) -> PyResult {
1286    set_inner_number_op(a, b, |a, b| a.union(b, vm), vm)
1287}
1288
1289pub(crate) fn init(context: &Context) {
1290    PyDict::extend_class(context, context.types.dict_type);
1291    PyDictKeys::extend_class(context, context.types.dict_keys_type);
1292    PyDictKeyIterator::extend_class(context, context.types.dict_keyiterator_type);
1293    PyDictReverseKeyIterator::extend_class(context, context.types.dict_reversekeyiterator_type);
1294    PyDictValues::extend_class(context, context.types.dict_values_type);
1295    PyDictValueIterator::extend_class(context, context.types.dict_valueiterator_type);
1296    PyDictReverseValueIterator::extend_class(context, context.types.dict_reversevalueiterator_type);
1297    PyDictItems::extend_class(context, context.types.dict_items_type);
1298    PyDictItemIterator::extend_class(context, context.types.dict_itemiterator_type);
1299    PyDictReverseItemIterator::extend_class(context, context.types.dict_reverseitemiterator_type);
1300}
Morty Proxy This is a proxified and sanitized view of the page, visit original site.