diff --git a/doc/api/next_api_changes/behavior/28375-MP.rst b/doc/api/next_api_changes/behavior/28375-MP.rst new file mode 100644 index 000000000000..75d7f7cf5030 --- /dev/null +++ b/doc/api/next_api_changes/behavior/28375-MP.rst @@ -0,0 +1,5 @@ +``transforms.AffineDeltaTransform`` updates correctly on axis limit changes +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +Before this change, transform sub-graphs with ``AffineDeltaTransform`` did not update correctly. +This PR ensures that changes to the child transform are passed through correctly. diff --git a/lib/matplotlib/tests/test_transforms.py b/lib/matplotlib/tests/test_transforms.py index 959814de82db..fb8b7d74bc94 100644 --- a/lib/matplotlib/tests/test_transforms.py +++ b/lib/matplotlib/tests/test_transforms.py @@ -341,6 +341,31 @@ def test_deepcopy(self): assert_array_equal(s.get_matrix(), a.get_matrix()) +class TestAffineDeltaTransform: + def test_invalidate(self): + before = np.array([[1.0, 4.0, 0.0], + [5.0, 1.0, 0.0], + [0.0, 0.0, 1.0]]) + after = np.array([[1.0, 3.0, 0.0], + [5.0, 1.0, 0.0], + [0.0, 0.0, 1.0]]) + + # Translation and skew present + base = mtransforms.Affine2D.from_values(1, 5, 4, 1, 2, 3) + t = mtransforms.AffineDeltaTransform(base) + assert_array_equal(t.get_matrix(), before) + + # Mess with the internal structure of `base` without invalidating + # This should not affect this transform because it's a passthrough: + # it's always invalid + base.get_matrix()[0, 1:] = 3 + assert_array_equal(t.get_matrix(), after) + + # Invalidate the base + base.invalidate() + assert_array_equal(t.get_matrix(), after) + + def test_non_affine_caching(): class AssertingNonAffineTransform(mtransforms.Transform): """ diff --git a/lib/matplotlib/transforms.py b/lib/matplotlib/transforms.py index 5003e2113930..e6e7a6bca637 100644 --- a/lib/matplotlib/transforms.py +++ b/lib/matplotlib/transforms.py @@ -2711,9 +2711,12 @@ class AffineDeltaTransform(Affine2DBase): This class is experimental as of 3.3, and the API may change. """ + pass_through = True + def __init__(self, transform, **kwargs): super().__init__(**kwargs) self._base_transform = transform + self.set_children(transform) __str__ = _make_str_method("_base_transform")