diff --git a/lib/mpl_toolkits/axisartist/floating_axes.py b/lib/mpl_toolkits/axisartist/floating_axes.py index 84690ff3259b..bc171e1f9d24 100644 --- a/lib/mpl_toolkits/axisartist/floating_axes.py +++ b/lib/mpl_toolkits/axisartist/floating_axes.py @@ -8,9 +8,9 @@ import numpy as np from matplotlib import _api, cbook +import matplotlib.axes as maxes import matplotlib.patches as mpatches from matplotlib.path import Path -import matplotlib.axes as maxes from mpl_toolkits.axes_grid1.parasite_axes import host_axes_class_factory @@ -311,13 +311,9 @@ def get_boundary(self): class FloatingAxesBase: - def __init__(self, *args, **kwargs): - grid_helper = kwargs.get("grid_helper", None) - if grid_helper is None: - raise ValueError("FloatingAxes requires grid_helper argument") - if not hasattr(grid_helper, "get_boundary"): - raise ValueError("grid_helper must implement get_boundary method") - super().__init__(*args, **kwargs) + def __init__(self, *args, grid_helper, **kwargs): + _api.check_isinstance(GridHelperCurveLinear, grid_helper=grid_helper) + super().__init__(*args, grid_helper=grid_helper, **kwargs) self.set_aspect(1.) self.adjust_axes_lim()