diff --git a/Lib/dataclasses.py b/Lib/dataclasses.py index e1687a117d..2bfeea515b 100644 --- a/Lib/dataclasses.py +++ b/Lib/dataclasses.py @@ -6,6 +6,7 @@ import keyword import builtins import functools +import itertools import abc import _thread from types import FunctionType, GenericAlias @@ -222,6 +223,26 @@ def __repr__(self): # https://bugs.python.org/issue33453 for details. _MODULE_IDENTIFIER_RE = re.compile(r'^(?:\s*(\w+)\s*\.)?\s*(\w+)') +# This function's logic is copied from "recursive_repr" function in +# reprlib module to avoid dependency. +def _recursive_repr(user_function): + # Decorator to make a repr function return "..." for a recursive + # call. + repr_running = set() + + @functools.wraps(user_function) + def wrapper(self): + key = id(self), _thread.get_ident() + if key in repr_running: + return '...' + repr_running.add(key) + try: + result = user_function(self) + finally: + repr_running.discard(key) + return result + return wrapper + class InitVar: __slots__ = ('type', ) @@ -229,7 +250,7 @@ def __init__(self, type): self.type = type def __repr__(self): - if isinstance(self.type, type) and not isinstance(self.type, GenericAlias): + if isinstance(self.type, type): type_name = self.type.__name__ else: # typing objects, e.g. List[int] @@ -279,6 +300,7 @@ def __init__(self, default, default_factory, init, repr, hash, compare, self.kw_only = kw_only self._field_type = None + @_recursive_repr def __repr__(self): return ('Field(' f'name={self.name!r},' @@ -297,7 +319,7 @@ def __repr__(self): # This is used to support the PEP 487 __set_name__ protocol in the # case where we're using a field that contains a descriptor as a # default value. For details on __set_name__, see - # https://www.python.org/dev/peps/pep-0487/#implementation-details. + # https://peps.python.org/pep-0487/#implementation-details. 
# # Note that in _process_class, this Field object is overwritten # with the default value, so the end result is a descriptor that @@ -388,27 +410,6 @@ def _tuple_str(obj_name, fields): return f'({",".join([f"{obj_name}.{f.name}" for f in fields])},)' -# This function's logic is copied from "recursive_repr" function in -# reprlib module to avoid dependency. -def _recursive_repr(user_function): - # Decorator to make a repr function return "..." for a recursive - # call. - repr_running = set() - - @functools.wraps(user_function) - def wrapper(self): - key = id(self), _thread.get_ident() - if key in repr_running: - return '...' - repr_running.add(key) - try: - result = user_function(self) - finally: - repr_running.discard(key) - return result - return wrapper - - def _create_fn(name, args, body, *, globals=None, locals=None, return_type=MISSING): # Note that we may mutate locals. Callers beware! @@ -807,8 +808,10 @@ def _get_field(cls, a_name, a_type, default_kw_only): raise TypeError(f'field {f.name} is a ClassVar but specifies ' 'kw_only') - # For real fields, disallow mutable defaults for known types. - if f._field_type is _FIELD and isinstance(f.default, (list, dict, set)): + # For real fields, disallow mutable defaults. Use unhashable as a proxy + # indicator for mutability. Read the __hash__ attribute from the class, + # not the instance. + if f._field_type is _FIELD and f.default.__class__.__hash__ is None: raise ValueError(f'mutable default {type(f.default)} for field ' f'{f.name} is not allowed: use default_factory') @@ -879,7 +882,7 @@ def _hash_exception(cls, fields, globals): def _process_class(cls, init, repr, eq, order, unsafe_hash, frozen, - match_args, kw_only, slots): + match_args, kw_only, slots, weakref_slot): # Now that dicts retain insertion order, there's no reason to use # an ordered dict. 
I am leveraging that ordering here, because # derived class fields overwrite base class fields, but the order @@ -1097,8 +1100,11 @@ def _process_class(cls, init, repr, eq, order, unsafe_hash, frozen, _set_new_attribute(cls, '__match_args__', tuple(f.name for f in std_init_fields)) + # It's an error to specify weakref_slot if slots is False. + if weakref_slot and not slots: + raise TypeError('weakref_slot is True but slots is False') if slots: - cls = _add_slots(cls, frozen) + cls = _add_slots(cls, frozen, weakref_slot) abc.update_abstractmethods(cls) @@ -1106,7 +1112,7 @@ def _process_class(cls, init, repr, eq, order, unsafe_hash, frozen, # _dataclass_getstate and _dataclass_setstate are needed for pickling frozen -# classes with slots. These could be slighly more performant if we generated +# classes with slots. These could be slightly more performant if we generated # the code instead of iterating over fields. But that can be a project for # another day, if performance becomes an issue. def _dataclass_getstate(self): @@ -1119,7 +1125,21 @@ def _dataclass_setstate(self, state): object.__setattr__(self, field.name, value) -def _add_slots(cls, is_frozen): +def _get_slots(cls): + match cls.__dict__.get('__slots__'): + case None: + return + case str(slot): + yield slot + # Slots may be any iterable, but we cannot handle an iterator + # because it will already be (partially) consumed. + case iterable if not hasattr(iterable, '__next__'): + yield from iterable + case _: + raise TypeError(f"Slots of '{cls.__name__}' cannot be determined") + + +def _add_slots(cls, is_frozen, weakref_slot): # Need to create a new class, since we can't set __slots__ # after a class has been created. @@ -1130,7 +1150,23 @@ def _add_slots(cls, is_frozen): # Create a new dict for our new class. cls_dict = dict(cls.__dict__) field_names = tuple(f.name for f in fields(cls)) - cls_dict['__slots__'] = field_names + # Make sure slots don't overlap with those in base classes. 
+ inherited_slots = set( + itertools.chain.from_iterable(map(_get_slots, cls.__mro__[1:-1])) + ) + # The slots for our class. Remove slots from our base classes. Add + # '__weakref__' if weakref_slot was given, unless it is already present. + cls_dict["__slots__"] = tuple( + itertools.filterfalse( + inherited_slots.__contains__, + itertools.chain( + # gh-93521: '__weakref__' also needs to be filtered out if + # already present in inherited_slots + field_names, ('__weakref__',) if weakref_slot else () + ) + ), + ) + for field_name in field_names: # Remove our attributes, if present. They'll still be # available in _MARKER. @@ -1155,25 +1191,25 @@ def _add_slots(cls, is_frozen): def dataclass(cls=None, /, *, init=True, repr=True, eq=True, order=False, unsafe_hash=False, frozen=False, match_args=True, - kw_only=False, slots=False): - """Returns the same class as was passed in, with dunder methods - added based on the fields defined in the class. + kw_only=False, slots=False, weakref_slot=False): + """Add dunder methods based on the fields defined in the class. Examines PEP 526 __annotations__ to determine fields. - If init is true, an __init__() method is added to the class. If - repr is true, a __repr__() method is added. If order is true, rich + If init is true, an __init__() method is added to the class. If repr + is true, a __repr__() method is added. If order is true, rich comparison dunder methods are added. If unsafe_hash is true, a - __hash__() method function is added. If frozen is true, fields may - not be assigned to after instance creation. If match_args is true, - the __match_args__ tuple is added. If kw_only is true, then by - default all fields are keyword-only. If slots is true, an - __slots__ attribute is added. + __hash__() method is added. If frozen is true, fields may not be + assigned to after instance creation. If match_args is true, the + __match_args__ tuple is added. If kw_only is true, then by default + all fields are keyword-only. 
If slots is true, a new class with a + __slots__ attribute is returned. """ def wrap(cls): return _process_class(cls, init, repr, eq, order, unsafe_hash, - frozen, match_args, kw_only, slots) + frozen, match_args, kw_only, slots, + weakref_slot) # See if we're being called as @dataclass or @dataclass(). if cls is None: @@ -1210,7 +1246,7 @@ def _is_dataclass_instance(obj): def is_dataclass(obj): """Returns True if obj is a dataclass or an instance of a dataclass.""" - cls = obj if isinstance(obj, type) and not isinstance(obj, GenericAlias) else type(obj) + cls = obj if isinstance(obj, type) else type(obj) return hasattr(cls, _FIELDS) @@ -1218,7 +1254,7 @@ def asdict(obj, *, dict_factory=dict): """Return the fields of a dataclass instance as a new dictionary mapping field names to field values. - Example usage: + Example usage:: @dataclass class C: @@ -1289,8 +1325,8 @@ class C: x: int y: int - c = C(1, 2) - assert astuple(c) == (1, 2) + c = C(1, 2) + assert astuple(c) == (1, 2) If given, 'tuple_factory' will be used instead of built-in tuple. The function applies recursively to field values that are @@ -1332,17 +1368,18 @@ def _astuple_inner(obj, tuple_factory): def make_dataclass(cls_name, fields, *, bases=(), namespace=None, init=True, repr=True, eq=True, order=False, unsafe_hash=False, - frozen=False, match_args=True, kw_only=False, slots=False): + frozen=False, match_args=True, kw_only=False, slots=False, + weakref_slot=False): """Return a new dynamically created dataclass. The dataclass name will be 'cls_name'. 'fields' is an iterable of either (name), (name, type) or (name, type, Field) objects. If type is omitted, use the string 'typing.Any'. Field objects are created by - the equivalent of calling 'field(name, type [, Field-info])'. 
+ the equivalent of calling 'field(name, type [, Field-info])'.:: C = make_dataclass('C', ['x', ('y', int), ('z', int, field(init=False))], bases=(Base,)) - is equivalent to: + is equivalent to:: @dataclass class C(Base): @@ -1399,13 +1436,14 @@ def exec_body_callback(ns): # Apply the normal decorator. return dataclass(cls, init=init, repr=repr, eq=eq, order=order, unsafe_hash=unsafe_hash, frozen=frozen, - match_args=match_args, kw_only=kw_only, slots=slots) + match_args=match_args, kw_only=kw_only, slots=slots, + weakref_slot=weakref_slot) def replace(obj, /, **changes): """Return a new object replacing specified fields with new values. - This is especially useful for frozen classes. Example usage: + This is especially useful for frozen classes. Example usage:: @dataclass(frozen=True) class C: @@ -1415,7 +1453,7 @@ class C: c = C(1, 2) c1 = replace(c, x=3) assert c1.x == 3 and c1.y == 2 - """ + """ # We're going to mutate 'changes', but that's okay because it's a # new dict, even if called with 'replace(obj, **my_changes)'. 
diff --git a/Lib/test/test_dataclasses.py b/Lib/test/test_dataclasses.py index 6bd9d37fad..b94a071c4e 100644 --- a/Lib/test/test_dataclasses.py +++ b/Lib/test/test_dataclasses.py @@ -9,6 +9,7 @@ import inspect import builtins import types +import weakref import unittest from unittest.mock import Mock from typing import ClassVar, Any, List, Union, Tuple, Dict, Generic, TypeVar, Optional, Protocol @@ -67,6 +68,24 @@ def test_field_repr(self): self.assertEqual(repr_output, expected_output) + def test_field_recursive_repr(self): + rec_field = field() + rec_field.type = rec_field + rec_field.name = "id" + repr_output = repr(rec_field) + + self.assertIn(",type=...,", repr_output) + + def test_recursive_annotation(self): + class C: + pass + + @dataclass + class D: + C: C = field() + + self.assertIn(",type=...,", repr(D.__dataclass_fields__["C"])) + def test_named_init_params(self): @dataclass class C: @@ -230,6 +249,14 @@ class C: c = C('foo') self.assertEqual(c.object, 'foo') + def test_field_named_BUILTINS_frozen(self): + # gh-96151 + @dataclass(frozen=True) + class C: + BUILTINS: int + c = C(5) + self.assertEqual(c.BUILTINS, 5) + def test_field_named_like_builtin(self): # Attribute names can shadow built-in names # since code generation is used. @@ -501,6 +528,32 @@ class C: self.assertNotEqual(C(3), C(4, 10)) self.assertNotEqual(C(3, 10), C(4, 10)) + def test_no_unhashable_default(self): + # See bpo-44674. + class Unhashable: + __hash__ = None + + unhashable_re = 'mutable default .* for field a is not allowed' + with self.assertRaisesRegex(ValueError, unhashable_re): + @dataclass + class A: + a: dict = {} + + with self.assertRaisesRegex(ValueError, unhashable_re): + @dataclass + class A: + a: Any = Unhashable() + + # Make sure that the machinery looking for hashability is using the + # class's __hash__, not the instance's __hash__. + with self.assertRaisesRegex(ValueError, unhashable_re): + unhashable = Unhashable() + # This shouldn't make the variable hashable. 
+ unhashable.__hash__ = lambda: 0 + @dataclass + class A: + a: Any = unhashable + def test_hash_field_rules(self): # Test all 6 cases of: # hash=True/False/None @@ -990,6 +1043,65 @@ def __post_init__(cls): self.assertEqual((c.x, c.y), (3, 4)) self.assertTrue(C.flag) + def test_post_init_not_auto_added(self): + # See bpo-46757, which had proposed always adding __post_init__. As + # Raymond Hettinger pointed out, that would be a breaking change. So, + # add a test to make sure that the current behavior doesn't change. + + @dataclass + class A0: + pass + + @dataclass + class B0: + b_called: bool = False + def __post_init__(self): + self.b_called = True + + @dataclass + class C0(A0, B0): + c_called: bool = False + def __post_init__(self): + super().__post_init__() + self.c_called = True + + # Since A0 has no __post_init__, and one wasn't automatically added + # (because that's the rule: it's never added by @dataclass, it's only + # the class author that can add it), then B0.__post_init__ is called. + # Verify that. + c = C0() + self.assertTrue(c.b_called) + self.assertTrue(c.c_called) + + ###################################### + # Now, the same thing, except A1 defines __post_init__. + @dataclass + class A1: + def __post_init__(self): + pass + + @dataclass + class B1: + b_called: bool = False + def __post_init__(self): + self.b_called = True + + @dataclass + class C1(A1, B1): + c_called: bool = False + def __post_init__(self): + super().__post_init__() + self.c_called = True + + # This time, B1.__post_init__ isn't being called. This mimics what + # would happen if A1.__post_init__ had been automatically added, + # instead of manually added as we see here. This test isn't really + # needed, but I'm including it just to demonstrate the changed + # behavior when A1 does define __post_init__. + c = C1() + self.assertFalse(c.b_called) + self.assertTrue(c.c_called) + def test_class_var(self): # Make sure ClassVars are ignored in __init__, __repr__, etc. 
@dataclass @@ -2135,12 +2247,12 @@ class C(B): self.assertEqual(c.z, 100) def test_no_init(self): - dataclass(init=False) + @dataclass(init=False) class C: i: int = 0 self.assertEqual(C().i, 0) - dataclass(init=False) + @dataclass(init=False) class C: i: int = 2 def __init__(self): @@ -2851,23 +2963,58 @@ class C: x: int def test_generated_slots_value(self): - @dataclass(slots=True) - class Base: - x: int - self.assertEqual(Base.__slots__, ('x',)) + class Root: + __slots__ = {'x'} + + class Root2(Root): + __slots__ = {'k': '...', 'j': ''} + + class Root3(Root2): + __slots__ = ['h'] + + class Root4(Root3): + __slots__ = 'aa' @dataclass(slots=True) - class Delivered(Base): + class Base(Root4): y: int + j: str + h: str - self.assertEqual(Delivered.__slots__, ('x', 'y')) + self.assertEqual(Base.__slots__, ('y', )) + + @dataclass(slots=True) + class Derived(Base): + aa: float + x: str + z: int + k: str + h: str + + self.assertEqual(Derived.__slots__, ('z', )) @dataclass - class AnotherDelivered(Base): + class AnotherDerived(Base): z: int - self.assertTrue('__slots__' not in AnotherDelivered.__dict__) + self.assertNotIn('__slots__', AnotherDerived.__dict__) + + def test_cant_inherit_from_iterator_slots(self): + + class Root: + __slots__ = iter(['a']) + + class Root2(Root): + __slots__ = ('b', ) + + with self.assertRaisesRegex( + TypeError, + "^Slots of 'Root' cannot be determined" + ): + @dataclass(slots=True) + class C(Root2): + x: int def test_returns_new_class(self): class A: @@ -2928,6 +3075,125 @@ class A: self.assertEqual(obj.a, 'a') self.assertEqual(obj.b, 'b') + def test_slots_no_weakref(self): + @dataclass(slots=True) + class A: + # No weakref. 
+ pass + + self.assertNotIn("__weakref__", A.__slots__) + a = A() + with self.assertRaisesRegex(TypeError, + "cannot create weak reference"): + weakref.ref(a) + + def test_slots_weakref(self): + @dataclass(slots=True, weakref_slot=True) + class A: + a: int + + self.assertIn("__weakref__", A.__slots__) + a = A(1) + weakref.ref(a) + + def test_slots_weakref_base_str(self): + class Base: + __slots__ = '__weakref__' + + @dataclass(slots=True) + class A(Base): + a: int + + # __weakref__ is in the base class, not A. But an A is still weakref-able. + self.assertIn("__weakref__", Base.__slots__) + self.assertNotIn("__weakref__", A.__slots__) + a = A(1) + weakref.ref(a) + + def test_slots_weakref_base_tuple(self): + # Same as test_slots_weakref_base_str, but use a tuple instead of a + # string in the base class. + class Base: + __slots__ = ('__weakref__',) + + @dataclass(slots=True) + class A(Base): + a: int + + # __weakref__ is in the base class, not A. But an A is still + # weakref-able. + self.assertIn("__weakref__", Base.__slots__) + self.assertNotIn("__weakref__", A.__slots__) + a = A(1) + weakref.ref(a) + + def test_weakref_slot_without_slot(self): + with self.assertRaisesRegex(TypeError, + "weakref_slot is True but slots is False"): + @dataclass(weakref_slot=True) + class A: + a: int + + def test_weakref_slot_make_dataclass(self): + A = make_dataclass('A', [('a', int),], slots=True, weakref_slot=True) + self.assertIn("__weakref__", A.__slots__) + a = A(1) + weakref.ref(a) + + # And make sure it raises if slots=True is not given. + with self.assertRaisesRegex(TypeError, + "weakref_slot is True but slots is False"): + B = make_dataclass('B', [('a', int),], weakref_slot=True) + + def test_weakref_slot_subclass_weakref_slot(self): + @dataclass(slots=True, weakref_slot=True) + class Base: + field: int + + # A *can* also specify weakref_slot=True if it wants to (gh-93521) + @dataclass(slots=True, weakref_slot=True) + class A(Base): + ...
+ + # __weakref__ is in the base class, not A. But an instance of A + # is still weakref-able. + self.assertIn("__weakref__", Base.__slots__) + self.assertNotIn("__weakref__", A.__slots__) + a = A(1) + weakref.ref(a) + + def test_weakref_slot_subclass_no_weakref_slot(self): + @dataclass(slots=True, weakref_slot=True) + class Base: + field: int + + @dataclass(slots=True) + class A(Base): + ... + + # __weakref__ is in the base class, not A. Even though A doesn't + # specify weakref_slot, it should still be weakref-able. + self.assertIn("__weakref__", Base.__slots__) + self.assertNotIn("__weakref__", A.__slots__) + a = A(1) + weakref.ref(a) + + def test_weakref_slot_normal_base_weakref_slot(self): + class Base: + __slots__ = ('__weakref__',) + + @dataclass(slots=True, weakref_slot=True) + class A(Base): + field: int + + # __weakref__ is in the base class, not A. But an instance of + # A is still weakref-able. + self.assertIn("__weakref__", Base.__slots__) + self.assertNotIn("__weakref__", A.__slots__) + a = A(1) + weakref.ref(a) + + class TestDescriptors(unittest.TestCase): def test_set_name(self): # See bpo-33141.