From 3b65acf8d66e7e42ad25746976b55dc0bd7d71d8 Mon Sep 17 00:00:00 2001 From: Arseny Boykov <36469655+Bobronium@users.noreply.github.com> Date: Fri, 27 Sep 2024 19:56:36 +0200 Subject: [PATCH 01/17] Improve zero argument support for `super()` in dataclasses --- Lib/dataclasses.py | 64 +++++++++++++++---- Lib/test/test_dataclasses/__init__.py | 41 ++++++++++++ ...4-09-27-19-50-30.gh-issue-90562.HeL_JA.rst | 2 + 3 files changed, 94 insertions(+), 13 deletions(-) create mode 100644 Misc/NEWS.d/next/Library/2024-09-27-19-50-30.gh-issue-90562.HeL_JA.rst diff --git a/Lib/dataclasses.py b/Lib/dataclasses.py index bdda7cc6c00f5d..5d2088d7cf9b6c 100644 --- a/Lib/dataclasses.py +++ b/Lib/dataclasses.py @@ -1222,11 +1222,6 @@ def _get_slots(cls): def _update_func_cell_for__class__(f, oldcls, newcls): - # Returns True if we update a cell, else False. - if f is None: - # f will be None in the case of a property where not all of - # fget, fset, and fdel are used. Nothing to do in that case. - return False try: idx = f.__code__.co_freevars.index("__class__") except ValueError: @@ -1235,13 +1230,36 @@ def _update_func_cell_for__class__(f, oldcls, newcls): # Fix the cell to point to the new class, if it's already pointing # at the old class. I'm not convinced that the "is oldcls" test # is needed, but other than performance can't hurt. 
- closure = f.__closure__[idx] - if closure.cell_contents is oldcls: - closure.cell_contents = newcls + cell = f.__closure__[idx] + if cell.cell_contents is oldcls: + cell.cell_contents = newcls return True return False +def _find_inner_functions(obj, _seen=None, _depth=0): + if _seen is None: + _seen = set() + if id(obj) in _seen: + return None + _seen.add(id(obj)) + + _depth += 1 + if _depth > 2: + return None + + obj = inspect.unwrap(obj) + + for attr in dir(obj): + value = getattr(obj, attr, None) + if value is None: + continue + if isinstance(obj, types.FunctionType): + yield obj + return + yield from _find_inner_functions(value, _seen, _depth) + + def _create_slots(defined_fields, inherited_slots, field_names, weakref_slot): # The slots for our class. Remove slots from our base classes. Add # '__weakref__' if weakref_slot was given, unless it is already present. @@ -1317,7 +1335,10 @@ def _add_slots(cls, is_frozen, weakref_slot, defined_fields): # (the newly created one, which we're returning) and not the # original class. We can break out of this loop as soon as we # make an update, since all closures for a class will share a - # given cell. + # given cell. First we try to find a pure function/properties, + # and then fallback to inspecting custom descriptors. + + custom_descriptors_to_check = [] for member in newcls.__dict__.values(): # If this is a wrapped function, unwrap it. 
member = inspect.unwrap(member) @@ -1326,10 +1347,27 @@ def _add_slots(cls, is_frozen, weakref_slot, defined_fields): if _update_func_cell_for__class__(member, cls, newcls): break elif isinstance(member, property): - if (_update_func_cell_for__class__(member.fget, cls, newcls) - or _update_func_cell_for__class__(member.fset, cls, newcls) - or _update_func_cell_for__class__(member.fdel, cls, newcls)): - break + for f in member.fget, member.fset, member.fdel: + if f is None: + continue + # unwrap once more in case function + # was wrapped before it became property + f = inspect.unwrap(f) + if _update_func_cell_for__class__(f, cls, newcls): + break + elif hasattr(member, "__get__") and not inspect.ismemberdescriptor( + member + ): + # we don't want to inspect custom descriptors just yet + # there's still a chance we'll encounter a pure function + # or a property + custom_descriptors_to_check.append(member) + else: + # now let's ensure custom descriptors won't be left out + for descriptor in custom_descriptors_to_check: + for f in _find_inner_functions(descriptor): + if _update_func_cell_for__class__(f, cls, newcls): + break return newcls diff --git a/Lib/test/test_dataclasses/__init__.py b/Lib/test/test_dataclasses/__init__.py index 2984f4261bd2c4..8611cb3dbf7f87 100644 --- a/Lib/test/test_dataclasses/__init__.py +++ b/Lib/test/test_dataclasses/__init__.py @@ -5031,6 +5031,47 @@ def foo(self): A().foo() + def test_wrapped_property(self): + def mydecorator(f): + @wraps(f) + def wrapper(*args, **kwargs): + return f(*args, **kwargs) + return wrapper + + class B: + @property + def foo(self): + return "bar" + + @dataclass(slots=True) + class A(B): + @property + @mydecorator + def foo(self): + return super().foo + + self.assertEqual(A().foo, "bar") + + def test_custom_descriptor(self): + class CustomDescriptor: + def __init__(self, f): + self._f = f + + def __get__(self, instance, owner): + return self._f(instance) + + class B: + def foo(self): + return "bar" + + 
@dataclass(slots=True) + class A(B): + @CustomDescriptor + def foo(cls): + return super().foo() + + self.assertEqual(A().foo, "bar") + def test_remembered_class(self): # Apply the dataclass decorator manually (not when the class # is created), so that we can keep a reference to the diff --git a/Misc/NEWS.d/next/Library/2024-09-27-19-50-30.gh-issue-90562.HeL_JA.rst b/Misc/NEWS.d/next/Library/2024-09-27-19-50-30.gh-issue-90562.HeL_JA.rst new file mode 100644 index 00000000000000..56f2e1109ce5f5 --- /dev/null +++ b/Misc/NEWS.d/next/Library/2024-09-27-19-50-30.gh-issue-90562.HeL_JA.rst @@ -0,0 +1,2 @@ +Modify dataclasses to enable zero argument support for ``super()`` when ``slots=True`` is +specified and custom descriptor is used or `property` function is wrapped. From efb60954f1dfec5855a53b2369ef6b4232782b86 Mon Sep 17 00:00:00 2001 From: Arseny Boykov <36469655+Bobronium@users.noreply.github.com> Date: Sat, 28 Sep 2024 00:41:47 +0200 Subject: [PATCH 02/17] Fix NEWS.d formatting --- .../next/Library/2024-09-27-19-50-30.gh-issue-90562.HeL_JA.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Misc/NEWS.d/next/Library/2024-09-27-19-50-30.gh-issue-90562.HeL_JA.rst b/Misc/NEWS.d/next/Library/2024-09-27-19-50-30.gh-issue-90562.HeL_JA.rst index 56f2e1109ce5f5..b0bbdd406f63c7 100644 --- a/Misc/NEWS.d/next/Library/2024-09-27-19-50-30.gh-issue-90562.HeL_JA.rst +++ b/Misc/NEWS.d/next/Library/2024-09-27-19-50-30.gh-issue-90562.HeL_JA.rst @@ -1,2 +1,2 @@ Modify dataclasses to enable zero argument support for ``super()`` when ``slots=True`` is -specified and custom descriptor is used or `property` function is wrapped. +specified and custom descriptor is used or ``property`` function is wrapped. 
From 3654f51b8295e88b062c691a88b130fb6b90052a Mon Sep 17 00:00:00 2001 From: Arseny Boykov <36469655+Bobronium@users.noreply.github.com> Date: Sat, 28 Sep 2024 00:42:26 +0200 Subject: [PATCH 03/17] Add explanation for _depth use and tests --- Lib/dataclasses.py | 3 ++ Lib/test/test_dataclasses/__init__.py | 68 +++++++++++++++++++++++++++ 2 files changed, 71 insertions(+) diff --git a/Lib/dataclasses.py b/Lib/dataclasses.py index 5d2088d7cf9b6c..215c91858465aa 100644 --- a/Lib/dataclasses.py +++ b/Lib/dataclasses.py @@ -1245,6 +1245,9 @@ def _find_inner_functions(obj, _seen=None, _depth=0): _seen.add(id(obj)) _depth += 1 + # we don't want to dive too deep in an object in search for a function. + # usually it will be stored on outer levels of nesting, but in just + # for sake of special cases when if _depth > 2: return None diff --git a/Lib/test/test_dataclasses/__init__.py b/Lib/test/test_dataclasses/__init__.py index 8611cb3dbf7f87..052e3b97df4aa6 100644 --- a/Lib/test/test_dataclasses/__init__.py +++ b/Lib/test/test_dataclasses/__init__.py @@ -5072,6 +5072,74 @@ def foo(cls): self.assertEqual(A().foo, "bar") + def test_custom_nested_descriptor(self): + class CustomFunctionWrapper: + def __init__(self, f): + self._f = f + + def __call__(self, *args, **kwargs): + return self._f(*args, **kwargs) + + class CustomDescriptor: + def __init__(self, f): + self._wrapper = CustomFunctionWrapper(f) + + def __get__(self, instance, owner): + return self._wrapper(instance) + + class B: + def foo(self): + return "bar" + + @dataclass(slots=True) + class A(B): + @CustomDescriptor + def foo(cls): + return super().foo() + + self.assertEqual(A().foo, "bar") + + def test_custom_too_nested_descriptor(self): + class UnnecessaryNestedWrapper: + def __init__(self, wrapper): + self._wrapper = wrapper + + def __call__(self, *args, **kwargs): + return self._wrapper(*args, **kwargs) + + class CustomFunctionWrapper: + def __init__(self, f): + self._f = f + + def __call__(self, *args, **kwargs): 
+ return self._f(*args, **kwargs) + + class CustomDescriptor: + def __init__(self, f): + self._wrapper = UnnecessaryNestedWrapper(CustomFunctionWrapper(f)) + + def __get__(self, instance, owner): + return self._wrapper(instance) + + class B: + def foo(self): + return "bar" + + @dataclass(slots=True) + class A(B): + @CustomDescriptor + def foo(cls): + return super().foo() + + with self.assertRaises(TypeError) as context: + A().foo + + expected_error_message = ( + 'super(type, obj): obj (instance of A) is not ' + 'an instance or subtype of type (A).' + ) + self.assertEqual(context.exception.args, (expected_error_message,)) + def test_remembered_class(self): # Apply the dataclass decorator manually (not when the class # is created), so that we can keep a reference to the From 19297bc9ee6d8cbfa2c29643f60a544b9155afa6 Mon Sep 17 00:00:00 2001 From: Arseny Boykov <36469655+Bobronium@users.noreply.github.com> Date: Sat, 28 Sep 2024 00:42:42 +0200 Subject: [PATCH 04/17] Fix incorrect variable reference --- Lib/dataclasses.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/Lib/dataclasses.py b/Lib/dataclasses.py index 215c91858465aa..ea276f761b0f23 100644 --- a/Lib/dataclasses.py +++ b/Lib/dataclasses.py @@ -1257,8 +1257,8 @@ def _find_inner_functions(obj, _seen=None, _depth=0): value = getattr(obj, attr, None) if value is None: continue - if isinstance(obj, types.FunctionType): - yield obj + if isinstance(value, types.FunctionType): + yield value return yield from _find_inner_functions(value, _seen, _depth) From 2606480ffa4ac986cf532d1a1778281d4f1d8446 Mon Sep 17 00:00:00 2001 From: Arseny Boykov <36469655+Bobronium@users.noreply.github.com> Date: Sat, 28 Sep 2024 00:55:01 +0200 Subject: [PATCH 05/17] Correctly unwrap descriptor function --- Lib/dataclasses.py | 4 +--- Lib/test/test_dataclasses/__init__.py | 21 +++++++++++++++++++++ 2 files changed, 22 insertions(+), 3 deletions(-) diff --git a/Lib/dataclasses.py b/Lib/dataclasses.py index 
ea276f761b0f23..8413c7c09438ec 100644 --- a/Lib/dataclasses.py +++ b/Lib/dataclasses.py @@ -1251,14 +1251,12 @@ def _find_inner_functions(obj, _seen=None, _depth=0): if _depth > 2: return None - obj = inspect.unwrap(obj) - for attr in dir(obj): value = getattr(obj, attr, None) if value is None: continue if isinstance(value, types.FunctionType): - yield value + yield inspect.unwrap(value) return yield from _find_inner_functions(value, _seen, _depth) diff --git a/Lib/test/test_dataclasses/__init__.py b/Lib/test/test_dataclasses/__init__.py index 052e3b97df4aa6..45f291d7aa6ad7 100644 --- a/Lib/test/test_dataclasses/__init__.py +++ b/Lib/test/test_dataclasses/__init__.py @@ -13,6 +13,7 @@ import weakref import traceback import unittest +from functools import update_wrapper from unittest.mock import Mock from typing import ClassVar, Any, List, Union, Tuple, Dict, Generic, TypeVar, Optional, Protocol, DefaultDict from typing import get_type_hints @@ -5072,6 +5073,26 @@ def foo(cls): self.assertEqual(A().foo, "bar") + def test_custom_descriptor_wrapped(self): + class CustomDescriptor: + def __init__(self, f): + self._f = update_wrapper(lambda *args, **kwargs: f(*args, **kwargs), f) + + def __get__(self, instance, owner): + return self._f(instance) + + class B: + def foo(self): + return "bar" + + @dataclass(slots=True) + class A(B): + @CustomDescriptor + def foo(cls): + return super().foo() + + self.assertEqual(A().foo, "bar") + def test_custom_nested_descriptor(self): class CustomFunctionWrapper: def __init__(self, f): From 42f9dc3f8f5e5ce1502d3e7b6059f34908249d79 Mon Sep 17 00:00:00 2001 From: Arseny Boykov <36469655+Bobronium@users.noreply.github.com> Date: Sat, 28 Sep 2024 00:55:48 +0200 Subject: [PATCH 06/17] Add test with using partial as custom wrapper --- Lib/test/test_dataclasses/__init__.py | 21 +++++++++++++++++++++ 1 file changed, 21 insertions(+) diff --git a/Lib/test/test_dataclasses/__init__.py b/Lib/test/test_dataclasses/__init__.py index 
45f291d7aa6ad7..84f049d1f67ff3 100644 --- a/Lib/test/test_dataclasses/__init__.py +++ b/Lib/test/test_dataclasses/__init__.py @@ -13,6 +13,7 @@ import weakref import traceback import unittest +from functools import partial from functools import update_wrapper from unittest.mock import Mock from typing import ClassVar, Any, List, Union, Tuple, Dict, Generic, TypeVar, Optional, Protocol, DefaultDict @@ -5120,6 +5121,26 @@ def foo(cls): self.assertEqual(A().foo, "bar") + def test_custom_nested_descriptor_with_partial(self): + class CustomDescriptor: + def __init__(self, f): + self._wrapper = partial(f, value="bar") + + def __get__(self, instance, owner): + return self._wrapper(instance) + + class B: + def foo(self, value): + return value + + @dataclass(slots=True) + class A(B): + @CustomDescriptor + def foo(self, value): + return super().foo(value) + + self.assertEqual(A().foo, "bar") + def test_custom_too_nested_descriptor(self): class UnnecessaryNestedWrapper: def __init__(self, wrapper): From 53c18d114b34c847247013a1e3e6f2b54cdbf707 Mon Sep 17 00:00:00 2001 From: Arseny Boykov <36469655+Bobronium@users.noreply.github.com> Date: Sat, 28 Sep 2024 01:03:03 +0200 Subject: [PATCH 07/17] Rephrase comment explaining _depth argument --- Lib/dataclasses.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/Lib/dataclasses.py b/Lib/dataclasses.py index 8413c7c09438ec..7f6e55112b44b6 100644 --- a/Lib/dataclasses.py +++ b/Lib/dataclasses.py @@ -1245,9 +1245,10 @@ def _find_inner_functions(obj, _seen=None, _depth=0): _seen.add(id(obj)) _depth += 1 - # we don't want to dive too deep in an object in search for a function. - # usually it will be stored on outer levels of nesting, but in just - # for sake of special cases when + # normally just inspection of a descriptor object itself should be enough, + # and we should encounter the function as its attribute, + # but in case function was wrapped (e.g. 
functools.partition was used), + # we want to dive at least one level deeper if _depth > 2: return None From 96d5315531573932f9fa99a722c14d1d2d952121 Mon Sep 17 00:00:00 2001 From: Arseny Boykov <36469655+Bobronium@users.noreply.github.com> Date: Sat, 28 Sep 2024 01:18:02 +0200 Subject: [PATCH 08/17] Add clarification for function search priority --- Lib/dataclasses.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/Lib/dataclasses.py b/Lib/dataclasses.py index 7f6e55112b44b6..826e2c97903117 100644 --- a/Lib/dataclasses.py +++ b/Lib/dataclasses.py @@ -1337,8 +1337,9 @@ def _add_slots(cls, is_frozen, weakref_slot, defined_fields): # (the newly created one, which we're returning) and not the # original class. We can break out of this loop as soon as we # make an update, since all closures for a class will share a - # given cell. First we try to find a pure function/properties, - # and then fallback to inspecting custom descriptors. + # given cell. First we try to find a pure function or a property, + # and then fallback to inspecting custom descriptors + # if no pure function or property is found. custom_descriptors_to_check = [] for member in newcls.__dict__.values(): @@ -1360,9 +1361,9 @@ def _add_slots(cls, is_frozen, weakref_slot, defined_fields): elif hasattr(member, "__get__") and not inspect.ismemberdescriptor( member ): - # we don't want to inspect custom descriptors just yet + # We don't want to inspect custom descriptors just yet # there's still a chance we'll encounter a pure function - # or a property + # or a property and won't have to use slower recursive search. 
custom_descriptors_to_check.append(member) else: # now let's ensure custom descriptors won't be left out From 22eeb8f762e1819762843e5f5b725ae89b2af26f Mon Sep 17 00:00:00 2001 From: Arseny Boykov <36469655+Bobronium@users.noreply.github.com> Date: Sat, 28 Sep 2024 01:21:52 +0200 Subject: [PATCH 09/17] Use consistent grammar throughout comments --- Lib/dataclasses.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/Lib/dataclasses.py b/Lib/dataclasses.py index 826e2c97903117..5587acffcf28b3 100644 --- a/Lib/dataclasses.py +++ b/Lib/dataclasses.py @@ -1245,10 +1245,10 @@ def _find_inner_functions(obj, _seen=None, _depth=0): _seen.add(id(obj)) _depth += 1 - # normally just inspection of a descriptor object itself should be enough, + # Normally just inspection of a descriptor object itself should be enough, # and we should encounter the function as its attribute, # but in case function was wrapped (e.g. functools.partition was used), - # we want to dive at least one level deeper + # we want to dive at least one level deeper. if _depth > 2: return None @@ -1353,8 +1353,8 @@ def _add_slots(cls, is_frozen, weakref_slot, defined_fields): for f in member.fget, member.fset, member.fdel: if f is None: continue - # unwrap once more in case function - # was wrapped before it became property + # Unwrap once more in case function + # was wrapped before it became property. f = inspect.unwrap(f) if _update_func_cell_for__class__(f, cls, newcls): break @@ -1366,7 +1366,7 @@ def _add_slots(cls, is_frozen, weakref_slot, defined_fields): # or a property and won't have to use slower recursive search. custom_descriptors_to_check.append(member) else: - # now let's ensure custom descriptors won't be left out + # Now let's ensure custom descriptors won't be left out.
for descriptor in custom_descriptors_to_check: for f in _find_inner_functions(descriptor): if _update_func_cell_for__class__(f, cls, newcls): From e91c28bcafcd2efc26be3c1822b40920b71cc6c7 Mon Sep 17 00:00:00 2001 From: Arseny Boykov <36469655+Bobronium@users.noreply.github.com> Date: Sat, 28 Sep 2024 01:35:46 +0200 Subject: [PATCH 10/17] Fix breaking out from outer loop when handling property --- Lib/dataclasses.py | 3 ++ Lib/test/test_dataclasses/__init__.py | 46 +++++++++++++++++++++++++++ 2 files changed, 49 insertions(+) diff --git a/Lib/dataclasses.py b/Lib/dataclasses.py index 5587acffcf28b3..4883cff2cd4f46 100644 --- a/Lib/dataclasses.py +++ b/Lib/dataclasses.py @@ -1358,6 +1358,9 @@ def _add_slots(cls, is_frozen, weakref_slot, defined_fields): f = inspect.unwrap(f) if _update_func_cell_for__class__(f, cls, newcls): break + else: + continue + break elif hasattr(member, "__get__") and not inspect.ismemberdescriptor( member ): diff --git a/Lib/test/test_dataclasses/__init__.py b/Lib/test/test_dataclasses/__init__.py index 84f049d1f67ff3..09aeb0077ce4cc 100644 --- a/Lib/test/test_dataclasses/__init__.py +++ b/Lib/test/test_dataclasses/__init__.py @@ -5141,6 +5141,52 @@ def foo(self, value): self.assertEqual(A().foo, "bar") + def test_pure_functions_preferred_to_custom_descriptors(self): + class CustomDescriptor: + def __init__(self, f): + self._wrapper = partial(f, value="bar") + + def __get__(self, instance, owner): + return self._wrapper(instance) + + def __dir__(self): + raise RuntimeError("Never should be accessed") + + class B: + def foo(self, value): + return value + + with self.assertRaises(RuntimeError) as context: + @dataclass(slots=True) + class A(B): + @CustomDescriptor + def foo(self, value): ... 
+ + self.assertEqual(context.exception.args, ("Never should be accessed",)) + + @dataclass(slots=True) + class A(B): + @CustomDescriptor + def foo(self, value): + return super().foo(value) + + @property + def bar(self): + return super() + + self.assertEqual(A().foo, "bar") + + @dataclass(slots=True) + class A(B): + @CustomDescriptor + def foo(self, value): + return super().foo(value) + + def bar(self): + return super() + + self.assertEqual(A().foo, "bar") + def test_custom_too_nested_descriptor(self): class UnnecessaryNestedWrapper: def __init__(self, wrapper): From 4941a3fdb2d1f23c55b8d45968a78dc2e074ebd1 Mon Sep 17 00:00:00 2001 From: Arseny Boykov <36469655+Bobronium@users.noreply.github.com> Date: Sat, 28 Sep 2024 01:39:39 +0200 Subject: [PATCH 11/17] Use generator expression with any() instead of classic loop --- Lib/dataclasses.py | 16 +++++++--------- 1 file changed, 7 insertions(+), 9 deletions(-) diff --git a/Lib/dataclasses.py b/Lib/dataclasses.py index 4883cff2cd4f46..65891c0c00e8eb 100644 --- a/Lib/dataclasses.py +++ b/Lib/dataclasses.py @@ -1349,17 +1349,15 @@ def _add_slots(cls, is_frozen, weakref_slot, defined_fields): if isinstance(member, types.FunctionType): if _update_func_cell_for__class__(member, cls, newcls): break - elif isinstance(member, property): - for f in member.fget, member.fset, member.fdel: - if f is None: - continue + elif isinstance(member, property) and ( + any( # Unwrap once more in case function # was wrapped before it became property. 
- f = inspect.unwrap(f) - if _update_func_cell_for__class__(f, cls, newcls): - break - else: - continue + _update_func_cell_for__class__(inspect.unwrap(f), cls, newcls) + for f in (member.fget, member.fset, member.fdel) + if f is not None + ) + ): break elif hasattr(member, "__get__") and not inspect.ismemberdescriptor( member From d33070f21ef0d393e57d88ce246bc3d9b1c493cf Mon Sep 17 00:00:00 2001 From: Arseny Boykov <36469655+Bobronium@users.noreply.github.com> Date: Sat, 28 Sep 2024 13:03:14 +0200 Subject: [PATCH 12/17] Apply suggestions from code review Co-authored-by: sobolevn --- Lib/dataclasses.py | 4 ++-- Lib/test/test_dataclasses/__init__.py | 3 +-- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/Lib/dataclasses.py b/Lib/dataclasses.py index 65891c0c00e8eb..ed68a61ede1582 100644 --- a/Lib/dataclasses.py +++ b/Lib/dataclasses.py @@ -1245,9 +1245,9 @@ def _find_inner_functions(obj, _seen=None, _depth=0): _seen.add(id(obj)) _depth += 1 - # Normally just inspection of a descriptor object itself should be enough, + # Normally just an inspection of a descriptor object itself should be enough, # and we should encounter the function as its attribute, - # but in case function was wrapped (e.g. functools.partition was used), + # but in case function was wrapped (e.g. functools.partial was used), # we want to dive at least one level deeper. 
if _depth > 2: return None diff --git a/Lib/test/test_dataclasses/__init__.py b/Lib/test/test_dataclasses/__init__.py index 09aeb0077ce4cc..889a669f94780d 100644 --- a/Lib/test/test_dataclasses/__init__.py +++ b/Lib/test/test_dataclasses/__init__.py @@ -13,8 +13,7 @@ import weakref import traceback import unittest -from functools import partial -from functools import update_wrapper +from functools import partial, update_wrapper from unittest.mock import Mock from typing import ClassVar, Any, List, Union, Tuple, Dict, Generic, TypeVar, Optional, Protocol, DefaultDict from typing import get_type_hints From c26e978cce86e079a6c0fdfb4ab615952536955f Mon Sep 17 00:00:00 2001 From: Arseny Boykov <36469655+Bobronium@users.noreply.github.com> Date: Sat, 28 Sep 2024 13:09:13 +0200 Subject: [PATCH 13/17] Remove leading underscores from seen and depth args --- Lib/dataclasses.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/Lib/dataclasses.py b/Lib/dataclasses.py index ed68a61ede1582..d9832f42b6cdbe 100644 --- a/Lib/dataclasses.py +++ b/Lib/dataclasses.py @@ -1237,19 +1237,19 @@ def _update_func_cell_for__class__(f, oldcls, newcls): return False -def _find_inner_functions(obj, _seen=None, _depth=0): - if _seen is None: - _seen = set() - if id(obj) in _seen: +def _find_inner_functions(obj, seen=None, depth=0): + if seen is None: + seen = set() + if id(obj) in seen: return None - _seen.add(id(obj)) + seen.add(id(obj)) - _depth += 1 + depth += 1 # Normally just an inspection of a descriptor object itself should be enough, # and we should encounter the function as its attribute, # but in case function was wrapped (e.g. functools.partial was used), # we want to dive at least one level deeper. 
- if _depth > 2: + if depth > 2: return None for attr in dir(obj): @@ -1259,7 +1259,7 @@ def _find_inner_functions(obj, _seen=None, _depth=0): if isinstance(value, types.FunctionType): yield inspect.unwrap(value) return - yield from _find_inner_functions(value, _seen, _depth) + yield from _find_inner_functions(value, seen, depth) def _create_slots(defined_fields, inherited_slots, field_names, weakref_slot): From d0173d8e39bf6d2b91995aba774bf744bb7ce193 Mon Sep 17 00:00:00 2001 From: Arseny Boykov <36469655+Bobronium@users.noreply.github.com> Date: Sat, 28 Sep 2024 22:37:06 +0200 Subject: [PATCH 14/17] Prevent user-defined code execution during attribute scanning --- Lib/dataclasses.py | 24 +++++++++++--- Lib/test/test_dataclasses/__init__.py | 46 --------------------------- 2 files changed, 20 insertions(+), 50 deletions(-) diff --git a/Lib/dataclasses.py b/Lib/dataclasses.py index d9832f42b6cdbe..b381bd07f1b7d6 100644 --- a/Lib/dataclasses.py +++ b/Lib/dataclasses.py @@ -1237,6 +1237,25 @@ def _update_func_cell_for__class__(f, oldcls, newcls): return False +def _safe_get_attributes(obj): + # we should avoid triggering any user-defined code + # when inspecting attributes if possible + + # look for __slots__ descriptors + type_dict = object.__getattribute__(type(obj), "__dict__") + for value in type_dict.values(): + if isinstance(value, types.MemberDescriptorType): + yield value.__get__(obj) + + instance_dict_descriptor = type_dict.get("__dict__", None) + if not isinstance(instance_dict_descriptor, types.GetSetDescriptorType): + # __dict__ is either not present, or redefined by user + # as custom descriptor, either way, we're done here + return + + yield from instance_dict_descriptor.__get__(obj).values() + + def _find_inner_functions(obj, seen=None, depth=0): if seen is None: seen = set() @@ -1252,10 +1271,7 @@ def _find_inner_functions(obj, seen=None, depth=0): if depth > 2: return None - for attr in dir(obj): - value = getattr(obj, attr, None) - if value is None: 
- continue + for value in _safe_get_attributes(obj): if isinstance(value, types.FunctionType): yield inspect.unwrap(value) return diff --git a/Lib/test/test_dataclasses/__init__.py b/Lib/test/test_dataclasses/__init__.py index 889a669f94780d..57622f77029668 100644 --- a/Lib/test/test_dataclasses/__init__.py +++ b/Lib/test/test_dataclasses/__init__.py @@ -5140,52 +5140,6 @@ def foo(self, value): self.assertEqual(A().foo, "bar") - def test_pure_functions_preferred_to_custom_descriptors(self): - class CustomDescriptor: - def __init__(self, f): - self._wrapper = partial(f, value="bar") - - def __get__(self, instance, owner): - return self._wrapper(instance) - - def __dir__(self): - raise RuntimeError("Never should be accessed") - - class B: - def foo(self, value): - return value - - with self.assertRaises(RuntimeError) as context: - @dataclass(slots=True) - class A(B): - @CustomDescriptor - def foo(self, value): ... - - self.assertEqual(context.exception.args, ("Never should be accessed",)) - - @dataclass(slots=True) - class A(B): - @CustomDescriptor - def foo(self, value): - return super().foo(value) - - @property - def bar(self): - return super() - - self.assertEqual(A().foo, "bar") - - @dataclass(slots=True) - class A(B): - @CustomDescriptor - def foo(self, value): - return super().foo(value) - - def bar(self): - return super() - - self.assertEqual(A().foo, "bar") - def test_custom_too_nested_descriptor(self): class UnnecessaryNestedWrapper: def __init__(self, wrapper): From c68b4cc08f3048703c92422071afa8977eae0a47 Mon Sep 17 00:00:00 2001 From: Arseny Boykov Date: Sat, 5 Oct 2024 14:41:17 +0200 Subject: [PATCH 15/17] Use inspect.getattr_static instead of _safe_get_attributes This uses builtins.dir, which will trigger custom __getattribute__ if any, and will trigger __get__ on __dict__ descriptor.
--- Lib/dataclasses.py | 39 ++++++++++++--------------- Lib/test/test_dataclasses/__init__.py | 39 +++++++++++++++++++++++++++ 2 files changed, 56 insertions(+), 22 deletions(-) diff --git a/Lib/dataclasses.py b/Lib/dataclasses.py index b381bd07f1b7d6..bf5f095b3292f0 100644 --- a/Lib/dataclasses.py +++ b/Lib/dataclasses.py @@ -1237,25 +1237,6 @@ def _update_func_cell_for__class__(f, oldcls, newcls): return False -def _safe_get_attributes(obj): - # we should avoid triggering any user-defined code - # when inspecting attributes if possible - - # look for __slots__ descriptors - type_dict = object.__getattribute__(type(obj), "__dict__") - for value in type_dict.values(): - if isinstance(value, types.MemberDescriptorType): - yield value.__get__(obj) - - instance_dict_descriptor = type_dict.get("__dict__", None) - if not isinstance(instance_dict_descriptor, types.GetSetDescriptorType): - # __dict__ is either not present, or redefined by user - # as custom descriptor, either way, we're done here - return - - yield from instance_dict_descriptor.__get__(obj).values() - - def _find_inner_functions(obj, seen=None, depth=0): if seen is None: seen = set() @@ -1271,11 +1252,25 @@ def _find_inner_functions(obj, seen=None, depth=0): if depth > 2: return None - for value in _safe_get_attributes(obj): + for attribute in dir(obj): + try: + value = inspect.getattr_static(obj, attribute) + except AttributeError: + continue + builtin_value = inspect.getattr_static(object, attribute, None) + if value is builtin_value: + # don't waste time on builtin objects + continue + if ( + # isinstance() would trigger `value.__getattribute__("__class__")` + type(value) is types.MemberDescriptorType + and type not in inspect._static_getmro(type(obj)) + ): + value = value.__get__(obj) if isinstance(value, types.FunctionType): yield inspect.unwrap(value) - return - yield from _find_inner_functions(value, seen, depth) + else: + yield from _find_inner_functions(value, seen, depth) def 
_create_slots(defined_fields, inherited_slots, field_names, weakref_slot): diff --git a/Lib/test/test_dataclasses/__init__.py b/Lib/test/test_dataclasses/__init__.py index 57622f77029668..3cb2dc8ce5e0f5 100644 --- a/Lib/test/test_dataclasses/__init__.py +++ b/Lib/test/test_dataclasses/__init__.py @@ -5181,6 +5181,45 @@ def foo(cls): ) self.assertEqual(context.exception.args, (expected_error_message,)) + def test_user_defined_code_execution(self): + class CustomDescriptor: + def __init__(self, f): + self._wrapper = partial(f, value="bar") + + def __get__(self, instance, owner): + return object.__getattribute__(self, "_wrapper")(instance) + + def __getattribute__(self, name): + if name in { + # these are the bare minimum for the feature to work + "__class__", # accessed on `isinstance(value, Field)` + "__wrapped__", # accessed by unwrap + "__get__", # is required for the descriptor protocol + "__dict__", # is accessed by dir() to work + }: + return object.__getattribute__(self, name) + raise RuntimeError(f"Never should be accessed: {name}") + + class B: + def foo(self, value): + return value + + @dataclass(slots=True) + class A(B): + @CustomDescriptor + def foo(self, value): + return super().foo(value) + + self.assertEqual(A().foo, "bar") + + @dataclass(slots=True) + class A(B): + @CustomDescriptor + def foo(self, value): + return super().foo(value) + + self.assertEqual(A().foo, "bar") + def test_remembered_class(self): # Apply the dataclass decorator manually (not when the class # is created), so that we can keep a reference to the From 7d679f226d74975a2c4498873a52fe0c1d5e0818 Mon Sep 17 00:00:00 2001 From: Arseny Boykov Date: Sat, 5 Oct 2024 14:51:01 +0200 Subject: [PATCH 16/17] Put back the comment in _update_func_cell_for__class__ --- Lib/dataclasses.py | 1 + 1 file changed, 1 insertion(+) diff --git a/Lib/dataclasses.py b/Lib/dataclasses.py index bf5f095b3292f0..8c8cd49218c21f 100644 --- a/Lib/dataclasses.py +++ b/Lib/dataclasses.py @@ -1222,6 +1222,7 @@ def 
_get_slots(cls): def _update_func_cell_for__class__(f, oldcls, newcls): + # Returns True if we update a cell, else False. try: idx = f.__code__.co_freevars.index("__class__") except ValueError: From 5bfea4d299566099b780f0bd17e616b157d41046 Mon Sep 17 00:00:00 2001 From: Arseny Boykov Date: Sat, 5 Oct 2024 15:16:54 +0200 Subject: [PATCH 17/17] Rely on inspect.getmembers_static() --- Lib/dataclasses.py | 38 +++++++++++++++++++++++--------------- 1 file changed, 23 insertions(+), 15 deletions(-) diff --git a/Lib/dataclasses.py b/Lib/dataclasses.py index 8c8cd49218c21f..9aa84d0fa15718 100644 --- a/Lib/dataclasses.py +++ b/Lib/dataclasses.py @@ -1238,6 +1238,22 @@ def _update_func_cell_for__class__(f, oldcls, newcls): return False +_object_members_values = { + value for name, value in + ( + *inspect.getmembers_static(object), + *inspect.getmembers_static(object()) + ) +} + + +def _is_not_object_member(v): + try: + return v not in _object_members_values + except TypeError: + return True + + def _find_inner_functions(obj, seen=None, depth=0): if seen is None: seen = set() @@ -1253,22 +1269,14 @@ def _find_inner_functions(obj, seen=None, depth=0): if depth > 2: return None - for attribute in dir(obj): - try: - value = inspect.getattr_static(obj, attribute) - except AttributeError: - continue - builtin_value = inspect.getattr_static(object, attribute, None) - if value is builtin_value: - # don't waste time on builtin objects - continue - if ( - # isinstance() would trigger `value.__getattribute__("__class__")` - type(value) is types.MemberDescriptorType - and type not in inspect._static_getmro(type(obj)) - ): + obj_is_type_instance = type in inspect._static_getmro(type(obj)) + for _, value in inspect.getmembers_static(obj, _is_not_object_member): + value_type = type(value) + if value_type is types.MemberDescriptorType and not obj_is_type_instance: value = value.__get__(obj) - if isinstance(value, types.FunctionType): + value_type = type(value) + + if value_type is 
types.FunctionType: yield inspect.unwrap(value) else: yield from _find_inner_functions(value, seen, depth)