diff --git a/Lib/dataclasses.py b/Lib/dataclasses.py index bdda7cc6c00f5d..9aa84d0fa15718 100644 --- a/Lib/dataclasses.py +++ b/Lib/dataclasses.py @@ -1223,10 +1223,6 @@ def _get_slots(cls): def _update_func_cell_for__class__(f, oldcls, newcls): # Returns True if we update a cell, else False. - if f is None: - # f will be None in the case of a property where not all of - # fget, fset, and fdel are used. Nothing to do in that case. - return False try: idx = f.__code__.co_freevars.index("__class__") except ValueError: @@ -1235,13 +1231,57 @@ def _update_func_cell_for__class__(f, oldcls, newcls): # Fix the cell to point to the new class, if it's already pointing # at the old class. I'm not convinced that the "is oldcls" test # is needed, but other than performance can't hurt. - closure = f.__closure__[idx] - if closure.cell_contents is oldcls: - closure.cell_contents = newcls + cell = f.__closure__[idx] + if cell.cell_contents is oldcls: + cell.cell_contents = newcls return True return False +_object_members_values = { + value for name, value in + ( + *inspect.getmembers_static(object), + *inspect.getmembers_static(object()) + ) +} + + +def _is_not_object_member(v): + try: + return v not in _object_members_values + except TypeError: + return True + + +def _find_inner_functions(obj, seen=None, depth=0): + if seen is None: + seen = set() + if id(obj) in seen: + return None + seen.add(id(obj)) + + depth += 1 + # Normally just an inspection of a descriptor object itself should be enough, + # and we should encounter the function as its attribute, + # but in case function was wrapped (e.g. functools.partial was used), + # we want to dive at least one level deeper. + if depth > 2: + return None + + obj_is_type_instance = type in inspect._static_getmro(type(obj)) + for _, value in inspect.getmembers_static(obj, _is_not_object_member): + value_type = type(value) + if value_type is types.MemberDescriptorType and not obj_is_type_instance: + value = value.__get__(obj) + value_type = type(value) + + if value_type is types.FunctionType: + yield inspect.unwrap(value) + else: + yield from _find_inner_functions(value, seen, depth) + + def _create_slots(defined_fields, inherited_slots, field_names, weakref_slot): # The slots for our class. Remove slots from our base classes. Add # '__weakref__' if weakref_slot was given, unless it is already present. @@ -1317,7 +1357,11 @@ def _add_slots(cls, is_frozen, weakref_slot, defined_fields): # (the newly created one, which we're returning) and not the # original class. We can break out of this loop as soon as we # make an update, since all closures for a class will share a - # given cell. + # given cell. First we try to find a pure function or a property, + # and then fallback to inspecting custom descriptors + # if no pure function or property is found. + + custom_descriptors_to_check = [] for member in newcls.__dict__.values(): # If this is a wrapped function, unwrap it. member = inspect.unwrap(member) @@ -1325,11 +1369,29 @@ def _add_slots(cls, is_frozen, weakref_slot, defined_fields): if isinstance(member, types.FunctionType): if _update_func_cell_for__class__(member, cls, newcls): break - elif isinstance(member, property): - if (_update_func_cell_for__class__(member.fget, cls, newcls) - or _update_func_cell_for__class__(member.fset, cls, newcls) - or _update_func_cell_for__class__(member.fdel, cls, newcls)): - break + elif isinstance(member, property) and ( + any( + # Unwrap once more in case function + # was wrapped before it became property. + _update_func_cell_for__class__(inspect.unwrap(f), cls, newcls) + for f in (member.fget, member.fset, member.fdel) + if f is not None + ) + ): + break + elif hasattr(member, "__get__") and not inspect.ismemberdescriptor( + member + ): + # We don't want to inspect custom descriptors just yet + # there's still a chance we'll encounter a pure function + # or a property and won't have to use slower recursive search. + custom_descriptors_to_check.append(member) + else: + # Now let's ensure custom descriptors won't be left out. + for descriptor in custom_descriptors_to_check: + for f in _find_inner_functions(descriptor): + if _update_func_cell_for__class__(f, cls, newcls): + break return newcls diff --git a/Lib/test/test_dataclasses/__init__.py b/Lib/test/test_dataclasses/__init__.py index 2984f4261bd2c4..3cb2dc8ce5e0f5 100644 --- a/Lib/test/test_dataclasses/__init__.py +++ b/Lib/test/test_dataclasses/__init__.py @@ -13,6 +13,7 @@ import weakref import traceback import unittest +from functools import partial, update_wrapper from unittest.mock import Mock from typing import ClassVar, Any, List, Union, Tuple, Dict, Generic, TypeVar, Optional, Protocol, DefaultDict from typing import get_type_hints @@ -5031,6 +5032,194 @@ def foo(self): A().foo() + def test_wrapped_property(self): + def mydecorator(f): + @wraps(f) + def wrapper(*args, **kwargs): + return f(*args, **kwargs) + return wrapper + + class B: + @property + def foo(self): + return "bar" + + @dataclass(slots=True) + class A(B): + @property + @mydecorator + def foo(self): + return super().foo + + self.assertEqual(A().foo, "bar") + + def test_custom_descriptor(self): + class CustomDescriptor: + def __init__(self, f): + self._f = f + + def __get__(self, instance, owner): + return self._f(instance) + + class B: + def foo(self): + return "bar" + + @dataclass(slots=True) + class A(B): + @CustomDescriptor + def foo(cls): + return super().foo() + + self.assertEqual(A().foo, "bar") + + def test_custom_descriptor_wrapped(self): + class CustomDescriptor: + def __init__(self, f): + self._f = update_wrapper(lambda *args, **kwargs: f(*args, **kwargs), f) + + def __get__(self, instance, owner): + return self._f(instance) + + class B: + def foo(self): + return "bar" + + @dataclass(slots=True) + class A(B): + @CustomDescriptor + def foo(cls): + return super().foo() + + self.assertEqual(A().foo, "bar") + + def test_custom_nested_descriptor(self): + class CustomFunctionWrapper: + def __init__(self, f): + self._f = f + + def __call__(self, *args, **kwargs): + return self._f(*args, **kwargs) + + class CustomDescriptor: + def __init__(self, f): + self._wrapper = CustomFunctionWrapper(f) + + def __get__(self, instance, owner): + return self._wrapper(instance) + + class B: + def foo(self): + return "bar" + + @dataclass(slots=True) + class A(B): + @CustomDescriptor + def foo(cls): + return super().foo() + + self.assertEqual(A().foo, "bar") + + def test_custom_nested_descriptor_with_partial(self): + class CustomDescriptor: + def __init__(self, f): + self._wrapper = partial(f, value="bar") + + def __get__(self, instance, owner): + return self._wrapper(instance) + + class B: + def foo(self, value): + return value + + @dataclass(slots=True) + class A(B): + @CustomDescriptor + def foo(self, value): + return super().foo(value) + + self.assertEqual(A().foo, "bar") + + def test_custom_too_nested_descriptor(self): + class UnnecessaryNestedWrapper: + def __init__(self, wrapper): + self._wrapper = wrapper + + def __call__(self, *args, **kwargs): + return self._wrapper(*args, **kwargs) + + class CustomFunctionWrapper: + def __init__(self, f): + self._f = f + + def __call__(self, *args, **kwargs): + return self._f(*args, **kwargs) + + class CustomDescriptor: + def __init__(self, f): + self._wrapper = UnnecessaryNestedWrapper(CustomFunctionWrapper(f)) + + def __get__(self, instance, owner): + return self._wrapper(instance) + + class B: + def foo(self): + return "bar" + + @dataclass(slots=True) + class A(B): + @CustomDescriptor + def foo(cls): + return super().foo() + + with self.assertRaises(TypeError) as context: + A().foo + + expected_error_message = ( + 'super(type, obj): obj (instance of A) is not ' + 'an instance or subtype of type (A).' + ) + self.assertEqual(context.exception.args, (expected_error_message,)) + + def test_user_defined_code_execution(self): + class CustomDescriptor: + def __init__(self, f): + self._wrapper = partial(f, value="bar") + + def __get__(self, instance, owner): + return object.__getattribute__(self, "_wrapper")(instance) + + def __getattribute__(self, name): + if name in { + # these are the bare minimum for the feature to work + "__class__", # accessed on `isinstance(value, Field)` + "__wrapped__", # accessed by unwrap + "__get__", # is required for the descriptor protocol + "__dict__", # is accessed by dir() to work + }: + return object.__getattribute__(self, name) + raise RuntimeError(f"Never should be accessed: {name}") + + class B: + def foo(self, value): + return value + + @dataclass(slots=True) + class A(B): + @CustomDescriptor + def foo(self, value): + return super().foo(value) + + self.assertEqual(A().foo, "bar") + + @dataclass(slots=True) + class A(B): + @CustomDescriptor + def foo(self, value): + return super().foo(value) + + self.assertEqual(A().foo, "bar") + def test_remembered_class(self): # Apply the dataclass decorator manually (not when the class # is created), so that we can keep a reference to the diff --git a/Misc/NEWS.d/next/Library/2024-09-27-19-50-30.gh-issue-90562.HeL_JA.rst b/Misc/NEWS.d/next/Library/2024-09-27-19-50-30.gh-issue-90562.HeL_JA.rst new file mode 100644 index 00000000000000..b0bbdd406f63c7 --- /dev/null +++ b/Misc/NEWS.d/next/Library/2024-09-27-19-50-30.gh-issue-90562.HeL_JA.rst @@ -0,0 +1,2 @@ +Modify dataclasses to enable zero argument support for ``super()`` when ``slots=True`` is +specified and custom descriptor is used or ``property`` function is wrapped.