diff --git a/.github/workflows/publish-pypi.yml b/.github/workflows/publish-pypi.yml index e5f7144c0..7c7aad35c 100644 --- a/.github/workflows/publish-pypi.yml +++ b/.github/workflows/publish-pypi.yml @@ -130,55 +130,55 @@ jobs: contents: write steps: - name: Create GitHub release - uses: softprops/action-gh-release@v2 + uses: softprops/action-gh-release@v3 with: generate_release_notes: true - publish-zenodo: - name: Publish Zenodo release - needs: publish-github-release - runs-on: ubuntu-latest - if: github.event_name == 'push' - permissions: - contents: read - env: - ZENODO_ACCESS_TOKEN: ${{ secrets.ZENODO_ACCESS_TOKEN }} - steps: - - uses: actions/checkout@v6 - with: - fetch-depth: 0 - - - uses: actions/setup-python@v6 - with: - python-version: "3.12" - - - name: Install release tooling - run: | - python -m pip install --upgrade pip - python -m pip install PyYAML - shell: bash - - - name: Download artifacts - uses: actions/download-artifact@v8 - with: - name: dist-${{ github.sha }}-${{ github.run_id }}-${{ github.run_number }} - path: dist - - - name: Generate release citation metadata - run: | - python tools/release/sync_citation.py \ - --tag "${GITHUB_REF_NAME}" \ - --output "${RUNNER_TEMP}/CITATION.cff" - shell: bash - - - name: Check tree stayed clean - run: | - git diff --quiet || (git status --short && git diff && exit 1) - shell: bash - - - name: Publish to Zenodo - run: | - python tools/release/publish_zenodo.py \ - --dist-dir dist \ - --citation-file "${RUNNER_TEMP}/CITATION.cff" - shell: bash + # publish-zenodo: + # name: Publish Zenodo release + # needs: publish-github-release + # runs-on: ubuntu-latest + # if: github.event_name == 'push' + # permissions: + # contents: read + # env: + # ZENODO_ACCESS_TOKEN: ${{ secrets.ZENODO_ACCESS_TOKEN }} + # steps: + # - uses: actions/checkout@v6 + # with: + # fetch-depth: 0 + # + # - uses: actions/setup-python@v6 + # with: + # python-version: "3.12" + # + # - name: Install release tooling + # run: | + # python -m pip install --upgrade pip + # python -m pip install PyYAML + # shell: bash + # + # - name: Download artifacts + # uses: actions/download-artifact@v8 + # with: + # name: dist-${{ github.sha }}-${{ github.run_id }}-${{ github.run_number }} + # path: dist + # + # - name: Generate release citation metadata + # run: | + # python tools/release/sync_citation.py \ + # --tag "${GITHUB_REF_NAME}" \ + # --output "${RUNNER_TEMP}/CITATION.cff" + # shell: bash + # + # - name: Check tree stayed clean + # run: | + # git diff --quiet || (git status --short && git diff && exit 1) + # shell: bash + # + # - name: Publish to Zenodo + # run: | + # python tools/release/publish_zenodo.py \ + # --dist-dir dist \ + # --citation-file "${RUNNER_TEMP}/CITATION.cff" + # shell: bash diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index a5d2db1f9..8daf1d1a2 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -11,6 +11,6 @@ ci: repos: - repo: https://github.com/psf/black-pre-commit-mirror - rev: 26.3.1 + rev: 26.5.1 hooks: - id: black diff --git a/docs/_static/custom.js b/docs/_static/custom.js index 4bd28d2d2..b57f5d3fc 100644 --- a/docs/_static/custom.js +++ b/docs/_static/custom.js @@ -329,7 +329,134 @@ function initShibuyaRightToc() { syncRightTocCodeButtons(localtoc); } +const upltApiSearchGenericTerms = new Set([ + "api", + "apis", + "attribute", + "attributes", + "class", + "classes", + "doc", + "docs", + "documentation", + "function", + "functions", + "method", + "methods", + "object", + "objects", + "reference", + "references", +]); + +function normalizeApiSearchTerm(term) { + return String(term || "") + .toLowerCase() + .replace(/\(\)$/, "") + .trim(); +} + +function isGenericApiSearchTerm(term) { + return upltApiSearchGenericTerms.has(normalizeApiSearchTerm(term)); +} + +function getApiSearchTerms(terms) { + if (terms instanceof Set) { + return Array.from(terms); + } + return Array.from(terms || []); +} + +function apiSearchResultMatchesQueryTerm(title, anchor, terms) { + const haystack = `${title || ""} ${anchor || ""}`.toLowerCase(); + const leaf = haystack.split("#").pop().split(".").pop(); + return getApiSearchTerms(terms).some((term) => { + const normalized = normalizeApiSearchTerm(term); + if (!normalized || isGenericApiSearchTerm(normalized)) return false; + return ( + leaf === normalized || + leaf.includes(normalized) || + haystack.includes("." + normalized) + ); + }); +} + +function initApiSearchScoring() { + if (typeof Search === "undefined" || typeof Scorer === "undefined") return; + if (Search.upltApiSearchScoring === "1") return; + + const previousParseQuery = Search._parseQuery; + if (typeof previousParseQuery === "function") { + Search._parseQuery = function (query) { + const parsed = previousParseQuery.call(this, query); + const queryTerms = new Set( + getApiSearchTerms(parsed && parsed[4]) + .map(normalizeApiSearchTerm) + .filter(Boolean), + ); + Search.upltQueryTerms = queryTerms; + Search.upltApiLikeQuery = + /[.()]/.test(query || "") || + getApiSearchTerms(queryTerms).some(isGenericApiSearchTerm); + return parsed; + }; + } + + const previousObjectSearch = Search.performObjectSearch; + if (typeof previousObjectSearch === "function") { + Search.performObjectSearch = function (object, objectTerms) { + const normalizedObject = normalizeApiSearchTerm(object); + const filteredTerms = new Set( + getApiSearchTerms(objectTerms) + .map(normalizeApiSearchTerm) + .filter((term) => term && !isGenericApiSearchTerm(term)), + ); + if (normalizedObject && !isGenericApiSearchTerm(normalizedObject)) { + filteredTerms.add(normalizedObject); + } + return previousObjectSearch.call(this, object, filteredTerms); + }; + } + + const previousScore = Scorer.score; + Scorer.score = function (result) { + let score = + typeof previousScore === "function" ? previousScore(result) : result[4]; + if (!Number.isFinite(score)) { + score = Number.isFinite(result[4]) ? result[4] : 0; + } + + const [docname, title, anchor, descr, _baseScore, _filename, kind] = result; + const isApiReference = String(docname || "").startsWith("api/"); + const isApiLikeQuery = !!Search.upltApiLikeQuery; + const queryTerms = Search.upltQueryTerms || new Set(); + + if (isApiReference && kind === "object") { + score += 24; + if (isApiLikeQuery) score += 16; + if (apiSearchResultMatchesQueryTerm(title, anchor, queryTerms)) { + score += 12; + } + if ( + queryTerms.has("function") && + String(descr || "").toLowerCase().includes("python function") + ) { + score += 4; + } + } else if (isApiReference && isApiLikeQuery) { + score += kind === "title" || kind === "index" ? 12 : 8; + } else if (!isApiReference && isApiLikeQuery) { + score -= 4; + } + + return score; + }; + + Search.upltApiSearchScoring = "1"; +} + document.addEventListener("DOMContentLoaded", function () { + initApiSearchScoring(); initScrollChromeFade(); if (document.querySelector(".sphx-glr-thumbcontainer")) { diff --git a/docs/colorbars_legends.py b/docs/colorbars_legends.py index 4044a9200..5b0582622 100644 --- a/docs/colorbars_legends.py +++ b/docs/colorbars_legends.py @@ -477,13 +477,19 @@ # # Legends usually annotate artists already drawn on an axes, but sometimes you need # standalone semantic keys (categories, size scales, color levels, or geometry types). -# UltraPlot provides helper methods that build these entries directly: +# UltraPlot provides helper methods that build these entries directly on both +# axes and figures: # # * :meth:`~ultraplot.axes.Axes.entrylegend` # * :meth:`~ultraplot.axes.Axes.catlegend` # * :meth:`~ultraplot.axes.Axes.sizelegend` # * :meth:`~ultraplot.axes.Axes.numlegend` # * :meth:`~ultraplot.axes.Axes.geolegend` +# * :meth:`~ultraplot.figure.Figure.entrylegend` +# * :meth:`~ultraplot.figure.Figure.catlegend` +# * :meth:`~ultraplot.figure.Figure.sizelegend` +# * :meth:`~ultraplot.figure.Figure.numlegend` +# * :meth:`~ultraplot.figure.Figure.geolegend` # # These helpers are useful whenever the legend should describe an encoding rather than # mirror artists that already happen to be drawn. In practice there are two distinct @@ -513,7 +519,8 @@ # # The helpers are intentionally composable. Each one accepts ``add=False`` and returns # ``(handles, labels)`` so you can merge semantic sections and pass the result through -# :meth:`~ultraplot.axes.Axes.legend` yourself. +# :meth:`~ultraplot.axes.Axes.legend` or :meth:`~ultraplot.figure.Figure.legend` +# yourself. # # .. code-block:: python # @@ -568,6 +575,27 @@ # # .. code-block:: python # +# # Add semantic legends around an entire subplot group. +# fig, axs = uplt.subplots(ncols=2) +# fig.catlegend( +# ["Control", "Treatment"], +# colors={"Control": "blue7", "Treatment": "red7"}, +# markers={"Control": "o", "Treatment": "^"}, +# ref=axs, +# loc="b", +# title="Group", +# ) +# fig.sizelegend( +# [10, 50, 200], +# labels=["Small", "Medium", "Large"], +# color="gray6", +# ref=axs, +# loc="r", +# title="Population", +# ) +# +# .. code-block:: python +# # # Compose multiple semantic helpers into one legend. # size_handles, size_labels = ax.sizelegend( # [10, 50, 200], @@ -685,6 +713,32 @@ ax.axis("off") +# %% +fig, axs = uplt.subplots(ncols=2, refwidth=2.8, share=False) +axs[0].scatter([0, 1, 2], [3, 1, 2], c=[0.2, 0.5, 0.8], s=[40, 120, 260]) +axs[1].scatter([0, 1, 2], [2, 3, 1], c=[0.8, 0.4, 0.1], s=[60, 90, 220]) +axs.format(title="Figure semantic legend helpers", grid=False) + +fig.catlegend( + ["Control", "Treatment"], + colors={"Control": "blue7", "Treatment": "red7"}, + markers={"Control": "o", "Treatment": "^"}, + ref=axs, + loc="bottom", + title="Group", + frameon=False, +) +fig.sizelegend( + [40, 120, 260], + labels=["Small", "Medium", "Large"], + color="gray6", + ref=axs, + loc="right", + title="Size scale", + frameon=False, +) + + # %% [raw] raw_mimetype="text/restructuredtext" # .. _ug_guides_decouple: # diff --git a/docs/contributing.rst b/docs/contributing.rst index 12d74cd19..6ccc4cc9c 100644 --- a/docs/contributing.rst +++ b/docs/contributing.rst @@ -241,6 +241,39 @@ Note that you can create the pull request before you're finished with your feature addition or bug fix. The PR will update as you add more commits. UltraPlot developers and contributors can then review your code and offer suggestions. +.. _contrib_ai: + +AI policy +========= + +UltraPlot welcomes contributions from developers at all skill levels, including +those who use AI tools as part of their workflow. To keep contributions +meaningful and to help new contributors genuinely learn the codebase, we ask +that you follow these guidelines. + +**Good first issues must be written by humans.** +Issues labelled *good first issue* are intentionally kept for people who want +to get familiar with the backend. These issues should be scoped, described, and +solved by a human — not generated or resolved wholesale by an AI assistant. +Working through them yourself is how you build the mental model of the code +that makes future contributions easier. + +**AI-assisted contributions are welcome for other issues**, provided that: + +* You understand and can explain every change you submit. Maintainers may ask + questions during review; if you cannot answer them the PR will be closed. +* You disclose AI assistance in the PR description (a one-line note is fine). +* The code meets the same quality bar as any other contribution — correct, + tested, and consistent with the existing style. + +**AI must not be used to bulk-generate issues, comments, or spam.** +Automated issue creation or low-effort AI-generated content will be removed and +may result in being blocked from the repository. + +The spirit of this policy is simple: AI is a tool, not a substitute for +understanding. We want contributions that improve UltraPlot *and* grow the +contributor. + .. _contrib_release: Release procedure diff --git a/docs/examples/legend_types/01_semantic_legends.py b/docs/examples/legend_types/01_semantic_legends.py new file mode 100644 index 000000000..470bbe93b --- /dev/null +++ b/docs/examples/legend_types/01_semantic_legends.py @@ -0,0 +1,5 @@ +""" +Semantic legends +================ +With UltraPlot semantic legends can be expressed in a flexible and cohesive manner with customg glyphs, latex and or spatial locations. +""" diff --git a/docs/examples/legends_colorbars/03_semantic_legends.py b/docs/examples/legends_colorbars/03_semantic_legends.py index a869b826e..2cd124ece 100644 --- a/docs/examples/legends_colorbars/03_semantic_legends.py +++ b/docs/examples/legends_colorbars/03_semantic_legends.py @@ -6,18 +6,117 @@ Why UltraPlot here? ------------------- -UltraPlot adds semantic legend helpers directly on axes: +UltraPlot adds semantic legend helpers on both axes and figures: ``entrylegend``, ``catlegend``, ``sizelegend``, ``numlegend``, and ``geolegend``. These are useful when you want legend meaning decoupled from plotted handles, or when you want a standalone semantic key that describes an encoding directly. -Key functions: :py:meth:`ultraplot.axes.Axes.entrylegend`, :py:meth:`ultraplot.axes.Axes.catlegend`, :py:meth:`ultraplot.axes.Axes.sizelegend`, :py:meth:`ultraplot.axes.Axes.numlegend`, :py:meth:`ultraplot.axes.Axes.geolegend`. +Key functions: :py:meth:`ultraplot.axes.Axes.entrylegend`, :py:meth:`ultraplot.axes.Axes.catlegend`, :py:meth:`ultraplot.axes.Axes.sizelegend`, :py:meth:`ultraplot.axes.Axes.numlegend`, :py:meth:`ultraplot.axes.Axes.geolegend`, :py:meth:`ultraplot.figure.Figure.entrylegend`, :py:meth:`ultraplot.figure.Figure.catlegend`, :py:meth:`ultraplot.figure.Figure.sizelegend`, :py:meth:`ultraplot.figure.Figure.numlegend`, :py:meth:`ultraplot.figure.Figure.geolegend`. See also -------- * :doc:`Colorbars and legends ` """ +# %% +# Semantic Legend with custom markers and advanced styles +import matplotlib.transforms as mtransforms +import numpy as np +from matplotlib.markers import CapStyle, JoinStyle, MarkerStyle +from matplotlib.path import Path + +import ultraplot as uplt + +star = Path.unit_regular_star(6) +circle = Path.unit_circle() +star_path = Path.unit_regular_star(5) +cut_star = Path( + vertices=np.concatenate([circle.vertices, star.vertices[::-1, ...]]), + codes=np.concatenate([circle.codes, star.codes]), +) + +fig, ax = uplt.subplots() + +# upper left legend with custom mark +ax.catlegend( + ["star", "cus_star"], + marker=[star_path, cut_star], + markersize=10, + add=True, + loc="ul", + title="Paths", + ncols=1, +) + +# upper right legend with advanced CapStyle and JoinStyle +ax.catlegend( + ["butt / round", "round / miter", "projecting / bevel"], + marker="1", + markersize=10, + markeredgecolor=list("gbr"), + markeredgewidth=4, + markerfacecoloralt="none", + marker_capstyle=[ + CapStyle.butt, + CapStyle.round, + CapStyle.projecting, + ], + marker_joinstyle=[ + JoinStyle.round, + JoinStyle.miter, + JoinStyle.bevel, + ], + marker_transform=[mtransforms.Affine2D().rotate_deg(x) for x in [0, 30, 60]], + title="Cap & Join Style", + add=True, + loc="ur", + ncols=1, +) + +# center geolegend with different styles +ax.geolegend( + ["rect", "tri", "hex", "AU"], + facecolor=["tab:red", "r", "k", "tab:blue"], + ec=["k", "g", "orange", "bright pink"], + loc="c", + title="geolegend", + ew=[0.5, 2, 1, 0.5], + markersize=10, + ncols=4, + handletextpad=0.1, + columnspacing=0.7, +) + +# lower left legend with TeX symbols and rotation transform +ax.catlegend( + ["\\infty", "\\sum", "\\int"], + marker=[r"$\infty$", r"$\sum$", r"$\int$"], + s=[6, 18, 9], # ms/markersize=[6,8,10] + title="TeX symbols\nwith rotation", + marker_transform=[mtransforms.Affine2D().rotate_deg(x) for x in [30, 90, 45]], + add=True, + loc="ll", + ncols=1, +) + +# lower right legend with different fill style +ax.catlegend( + ["top", "bottom", "left", "right"], + marker="o", + markersize=10, + mfc=["r", "g", "b", "c"], + markerfacecoloralt="lightsteelblue", + markeredgecolor=["k", "r", "y", "b"], + fillstyle=["top", "bottom", "left", "right"], + title="Half filled", + add=True, + loc="lr", + ncols=1, +) +ax.axis("off") +fig.show() + + # %% import cartopy.crs as ccrs import shapely.geometry as sg @@ -106,3 +205,29 @@ ) ax.axis("off") fig.show() + +# %% +fig, axs = uplt.subplots(ncols=2, refwidth=2.8, share=False) +axs[0].scatter([0, 1, 2], [3, 1, 2], c=[0.2, 0.5, 0.8], s=[40, 120, 260]) +axs[1].scatter([0, 1, 2], [2, 3, 1], c=[0.8, 0.4, 0.1], s=[60, 90, 220]) +axs.format(title="Figure semantic legend helpers") + +fig.catlegend( + ["Control", "Treatment"], + colors={"Control": "blue7", "Treatment": "red7"}, + markers={"Control": "o", "Treatment": "^"}, + ref=axs, + loc="bottom", + title="Group", + frameon=False, +) +fig.sizelegend( + [40, 120, 260], + labels=["Small", "Medium", "Large"], + color="gray6", + ref=axs, + loc="right", + title="Size scale", + frameon=False, +) +fig.show() diff --git a/ultraplot/axes/_formatting.py b/ultraplot/axes/_formatting.py new file mode 100644 index 000000000..298724d00 --- /dev/null +++ b/ultraplot/axes/_formatting.py @@ -0,0 +1,102 @@ +#!/usr/bin/env python3 +""" +Shared metadata for axis formatting keyword routing and persistence. +""" + +import inspect + +_AXIS_STYLE_FIELD_TEMPLATES = { + "color": ("{axis}color", "color", "{axis}ec", "ec", "{axis}edgecolor", "edgecolor"), + "linewidth": ("{axis}linewidth", "linewidth", "{axis}lw", "lw"), + "rotation": ("{axis}rotation", "rotation"), + "spineloc": ("{axis}spineloc", "{axis}loc"), + "tickloc": ("{axis}tickloc",), + "ticklabelloc": ("{axis}ticklabelloc",), + "labelloc": ("{axis}labelloc",), + "offsetloc": ("{axis}offsetloc",), + "grid": ("{axis}grid",), + "gridminor": ("{axis}gridminor",), + "gridcolor": ("{axis}gridcolor", "gridcolor"), + "tickdir": ("{axis}tickdir", "tickdir"), + "tickcolor": ("{axis}tickcolor", "tickcolor"), + "ticklen": ("{axis}ticklen", "ticklen"), + "ticklenratio": ("{axis}ticklenratio", "ticklenratio"), + "tickwidth": ("{axis}tickwidth", "tickwidth"), + "tickwidthratio": ("{axis}tickwidthratio", "tickwidthratio"), + "ticklabeldir": ("{axis}ticklabeldir", "ticklabeldir"), + "ticklabelpad": ("{axis}ticklabelpad",), + "ticklabelcolor": ("{axis}ticklabelcolor", "ticklabelcolor"), + "ticklabelsize": ("{axis}ticklabelsize", "ticklabelsize"), + "ticklabelweight": ("{axis}ticklabelweight", "ticklabelweight"), + "labelpad": ("{axis}labelpad",), + "labelcolor": ("{axis}labelcolor", "labelcolor"), + "labelsize": ("{axis}labelsize", "labelsize"), + "labelweight": ("{axis}labelweight", "labelweight"), +} + + +def _dedupe(items): + return tuple(dict.fromkeys(items)) + + +GENERIC_AXIS_FORMAT_KEYS = _dedupe( + name + for names in _AXIS_STYLE_FIELD_TEMPLATES.values() + for name in names + if "{axis}" not in name +) + + +CARTESIAN_PARENT_FILTER_KEYS = GENERIC_AXIS_FORMAT_KEYS + ( + "label_kw", + "scale_kw", + "locator_kw", + "formatter_kw", + "minorlocator_kw", +) + + +def get_axis_style_fields(axis): + """ + Return the parameter names used to store explicit style overrides. + """ + return { + field: tuple(name.format(axis=axis) for name in names) + for field, names in _AXIS_STYLE_FIELD_TEMPLATES.items() + } + + +def _signature_param_names(*funcs): + names = [] + for func in funcs: + if isinstance(func, inspect.Signature): + sig = func + elif callable(func): + sig = inspect.signature(func) + elif func is None: + continue + else: + raise RuntimeError(f"Internal error. Invalid function {func!r}.") + names.extend(sig.parameters) + return set(names) + + +def pop_axis_format_kwargs(kwargs, *funcs): + """ + Pop axis-format kwargs so they survive rc parsing. + + Returns + ------- + tuple(dict, dict) + The signature-defined keyword arguments and the generic alias keyword + arguments that are not represented in the stored signatures. + """ + signature_keys = _signature_param_names(*funcs) + signature_kwargs = {} + generic_kwargs = {} + for key in tuple(kwargs): + if key in GENERIC_AXIS_FORMAT_KEYS: + generic_kwargs[key] = kwargs.pop(key) + elif key in signature_keys: + signature_kwargs[key] = kwargs.pop(key) + return signature_kwargs, generic_kwargs diff --git a/ultraplot/axes/base.py b/ultraplot/axes/base.py index 29a962e14..96bff909b 100644 --- a/ultraplot/axes/base.py +++ b/ultraplot/axes/base.py @@ -1347,22 +1347,24 @@ def shared(paxs): # External axes sharing, sometimes overrides panel axes sharing # Share x axes within compatible groups - axes_x = self._get_share_axes("x") - for group in self.figure._partition_share_axes(axes_x, "x"): - if not group: - continue - parent, *children = group - for child in children: - child._sharex_setup(parent) + if self.figure._sharex > 0: + axes_x = self._get_share_axes("x") + for group in self.figure._partition_share_axes(axes_x, "x"): + if not group: + continue + parent, *children = group + for child in children: + child._sharex_setup(parent) # Share y axes within compatible groups - axes_y = self._get_share_axes("y") - for group in self.figure._partition_share_axes(axes_y, "y"): - if not group: - continue - parent, *children = group - for child in children: - child._sharey_setup(parent) + if self.figure._sharey > 0: + axes_y = self._get_share_axes("y") + for group in self.figure._partition_share_axes(axes_y, "y"): + if not group: + continue + parent, *children = group + for child in children: + child._sharey_setup(parent) # Global sharing, use the reference subplot where compatible ref = self.figure._subplot_dict.get(self.figure._refnum, None) @@ -2127,14 +2129,11 @@ def _legend_label(*objs): # noqa: E301 # Helper function. Translate handles in the input tuple group. Extracts # legend handles from contour sets and extracts labeled elements from # matplotlib containers (important for histogram plots). - ignore = (mcontainer.ErrorbarContainer,) containers = (cbook.silent_list, mcontainer.Container) def _legend_tuple(*objs): # noqa: E306 handles = [] for obj in objs: - if isinstance(obj, ignore) and not _legend_label(obj): - continue if hasattr(obj, "update_scalarmappable"): # for e.g. pcolor obj.update_scalarmappable() if isinstance(obj, mcontour.ContourSet): # extract single element @@ -2143,7 +2142,9 @@ def _legend_tuple(*objs): # noqa: E306 if hs: # non-empty obj = hs[len(hs) // 2] obj.set_label(label) - if isinstance(obj, containers): # extract labeled elements + if isinstance(obj, mcontainer.ErrorbarContainer): + handles.append(obj) + elif isinstance(obj, containers): # extract labeled elements hs = (obj, *guides._iter_iterables(obj)) hs = tuple(filter(_legend_label, hs)) if hs: @@ -2928,10 +2929,18 @@ def _update_title_position(self, renderer): if base_x >= ax1 + abc_title_sep: max_width = base_x - (ax1 + abc_title_sep) elif ha == "center": + # Keep the requested font size for centered titles and + # resolve collisions by shifting the title away from the + # abc label, matching the overflow-tolerant behavior of + # left/right titles in practice. if base_x >= ax1 + abc_title_sep: - max_width = 2 * (base_x - (ax1 + abc_title_sep)) + shift = (ax1 + abc_title_sep) - tx0 + if shift > 0: + title_obj.set_x(base_x + shift) elif base_x <= ax0 - abc_title_sep: - max_width = 2 * ((ax0 - abc_title_sep) - base_x) + shift = (ax0 - abc_title_sep) - tx1 + if shift < 0: + title_obj.set_x(base_x + shift) if 0 < max_width < title_bbox.width: scale = max_width / title_bbox.width title_obj.set_fontsize(title_obj.get_fontsize() * scale) @@ -3590,78 +3599,207 @@ def legend( **kwargs, ) + @docstring._snippet_manager def catlegend(self, categories, **kwargs): """ - Build categorical legend entries and optionally add a legend. + Build a categorical legend — one handle per unique category — and + optionally draw it. Parameters ---------- - categories - Category labels used to generate legend handles. - **kwargs - Forwarded to `ultraplot.legend.UltraLegend.catlegend`. - Pass ``add=False`` to return ``(handles, labels)`` without drawing. + categories : iterable + Category labels in display order. Duplicates are collapsed; the + first occurrence determines position. + color, marker + %(legend.semantic_style_arg)s + Defaults to ultraplot's color cycle for ``color`` and ``"o"`` for + ``marker`` (or :rc:`legend.cat.marker` when set). + line : bool, optional + Whether to render connector lines through the markers. Falls back + to :rc:`legend.cat.line`. Setting a non-default ``linestyle`` + implicitly enables this. + + Other parameters + ---------------- + %(legend.semantic_style_kwargs)s + %(legend.semantic_handle_kw)s + + See also + -------- + Axes.entrylegend + Axes.sizelegend """ return plegend.UltraLegend(self).catlegend(categories, **kwargs) + @docstring._snippet_manager def entrylegend(self, entries, **kwargs): """ - Build generic semantic legend entries and optionally add a legend. + Build generic semantic legend entries from explicit ``{label: style}`` + entries and optionally draw the legend. Parameters ---------- - entries - Entry specifications as handles, style dictionaries, or ``(label, spec)`` - pairs. - **kwargs - Forwarded to `ultraplot.legend.UltraLegend.entrylegend`. - Pass ``add=False`` to return ``(handles, labels)`` without drawing. + entries : iterable or mapping + Entry specifications. Either a sequence of ``{**style_kwargs}`` + dicts (each requiring at least ``label``) or a mapping from label + to style-kwargs dict. + line : bool, optional + Whether each entry shows a connector line. Falls back to + :rc:`legend.cat.line`. + marker, color + %(legend.semantic_style_arg)s + + Other parameters + ---------------- + %(legend.semantic_style_kwargs)s + %(legend.semantic_handle_kw)s + + See also + -------- + Axes.catlegend + Axes.sizelegend """ return plegend.UltraLegend(self).entrylegend(entries, **kwargs) + @docstring._snippet_manager def sizelegend(self, levels, **kwargs): """ - Build size legend entries and optionally add a legend. + Build a size legend — one handle per level, scaled by marker size — + and optionally draw it. Parameters ---------- - levels - Numeric levels used to generate marker-size entries. - **kwargs - Forwarded to `ultraplot.legend.UltraLegend.sizelegend`. - Pass ``labels=[...]`` or ``labels={level: label}`` to override the - generated labels. - Pass ``add=False`` to return ``(handles, labels)`` without drawing. + levels : iterable of float + Numeric values to render as size-scaled markers. + labels : iterable or mapping, optional + Custom labels. A mapping ``{level: label}`` overrides individual + entries (every level must be a key). When omitted, labels are + formatted from ``levels`` via ``fmt``. + color, marker + %(legend.semantic_style_arg)s + Defaults to :rc:`legend.size.color` and :rc:`legend.size.marker`. + area : bool, optional + Treat ``levels`` as marker areas (``True``, default) or + diameters (``False``). Areas are converted with + ``ms = sqrt(level) * scale``. Falls back to :rc:`legend.size.area`. + scale : float, optional + Multiplier applied after area/diameter conversion. + Falls back to :rc:`legend.size.scale`. + minsize : float, optional + Lower bound on rendered marker size. + Falls back to :rc:`legend.size.minsize`. + fmt : str or callable, optional + Format used to label levels. Falls back to :rc:`legend.size.format`. + + Other parameters + ---------------- + %(legend.semantic_style_kwargs)s + %(legend.semantic_handle_kw)s + + See also + -------- + Axes.catlegend + Axes.numlegend """ return plegend.UltraLegend(self).sizelegend(levels, **kwargs) + @docstring._snippet_manager def numlegend(self, levels=None, **kwargs): """ - Build numeric-color legend entries and optionally add a legend. + Build a numeric legend — one patch handle per level, colored from a + colormap — and optionally draw it. Parameters ---------- - levels - Numeric levels or number of levels. - **kwargs - Forwarded to `ultraplot.legend.UltraLegend.numlegend`. - Pass ``add=False`` to return ``(handles, labels)`` without drawing. + levels : iterable of float, optional + Numeric levels to render. When omitted, ``n`` evenly spaced + levels are derived from ``vmin`` / ``vmax``. + vmin, vmax : float, optional + Limits for sampling ``cmap`` when ``norm`` is not provided. + n : int, optional + Number of levels to sample when ``levels`` is omitted. + Falls back to :rc:`legend.num.n`. + cmap : str or `~matplotlib.colors.Colormap`, optional + Colormap used to color the patches. + Falls back to :rc:`legend.num.cmap`. + norm : `~matplotlib.colors.Normalize`, optional + Normalization applied to ``levels`` before colormap lookup. + fmt : str or callable, optional + Format used to label levels. + Falls back to :rc:`legend.num.format`. + facecolor, edgecolor + %(legend.semantic_style_arg)s + ``facecolor`` defaults to colormap-derived values; ``edgecolor`` + falls back to :rc:`legend.num.edgecolor`. + linewidth, linestyle, alpha + Patch outline width, style, and transparency. ``linewidth`` / + ``alpha`` fall back to :rc:`legend.num.linewidth` / + :rc:`legend.num.alpha`. + + Other parameters + ---------------- + %(legend.semantic_num_style_kwargs)s + %(legend.semantic_handle_kw)s + + See also + -------- + Axes.sizelegend + Axes.geolegend """ return plegend.UltraLegend(self).numlegend(levels=levels, **kwargs) + @docstring._snippet_manager def geolegend(self, entries, labels=None, **kwargs): """ - Build geometry legend entries and optionally add a legend. + Build a geometry legend — one patch handle per geometry entry — and + optionally draw it. Parameters ---------- - entries - Geometry entries (mapping, ``(label, geometry)`` pairs, or geometries). - labels - Optional labels for geometry sequences. - **kwargs - Forwarded to `ultraplot.legend.UltraLegend.geolegend`. - Pass ``add=False`` to return ``(handles, labels)`` without drawing. + entries : iterable or mapping + Either a sequence of ``(label, geometry)`` pairs or a mapping + from label to geometry specification (string keyword, shapely + geometry, ``cartopy`` feature, or a country name when + ``country_reso`` is set). + labels : iterable, optional + Labels overriding those derived from ``entries``. + country_reso : str, optional + Natural Earth resolution for country geometries (e.g. ``"110m"``). + Falls back to :rc:`legend.geo.country_reso`. + country_territories : bool, optional + Whether country lookups include overseas territories. + Falls back to :rc:`legend.geo.country_territories`. + country_proj : any, optional + Projection used to render country geometries; ignored for non- + country entries. Falls back to :rc:`legend.geo.country_proj`. + handlesize : float, optional + Multiplier applied to legend ``handlelength`` / ``handleheight`` + to enlarge geometry handles. Falls back to + :rc:`legend.geo.handlesize`. Must be positive. + facecolor, edgecolor + %(legend.semantic_style_arg)s + Default to :rc:`legend.geo.facecolor` / :rc:`legend.geo.edgecolor`. + linewidth, alpha, fill + Patch outline width, transparency, and fill toggle. + Defaults from :rc:`legend.geo.linewidth` / :rc:`legend.geo.alpha` / + :rc:`legend.geo.fill`. + + Other parameters + ---------------- + %(legend.semantic_num_style_kwargs)s + %(legend.semantic_handle_kw)s + + Notes + ----- + Geometry legend entries use normalized patch proxies inside the legend + handle box rather than reusing the original map artist directly. This + preserves the general geometry shape and copied patch styling, but very + small or high-aspect-ratio handles can still make hatches difficult to + read at legend scale. + + See also + -------- + Axes.numlegend """ return plegend.UltraLegend(self).geolegend(entries, labels=labels, **kwargs) diff --git a/ultraplot/axes/cartesian.py b/ultraplot/axes/cartesian.py index 696639beb..677f81989 100644 --- a/ultraplot/axes/cartesian.py +++ b/ultraplot/axes/cartesian.py @@ -4,6 +4,7 @@ """ import copy +import functools import inspect from dataclasses import dataclass, field from typing import Any, Dict, Optional, Tuple, Union @@ -20,6 +21,7 @@ from ..config import rc from ..internals import ( _not_none, + _pop_params, _pop_rc, _version_mpl, docstring, @@ -28,6 +30,11 @@ warnings, ) from ..utils import units +from ._formatting import ( + CARTESIAN_PARENT_FILTER_KEYS, + get_axis_style_fields, + pop_axis_format_kwargs, +) from . import plot, shared __all__ = ["CartesianAxes"] @@ -431,6 +438,8 @@ def __init__(self, *args, **kwargs): self._yaxis_current_rotation = "horizontal" self._xaxis_isdefault_rotation = True # whether to auto rotate the axis self._yaxis_isdefault_rotation = True + self._xaxis_style_state = {} + self._yaxis_style_state = {} super().__init__(*args, **kwargs) # Apply default formatter @@ -447,6 +456,37 @@ def __init__(self, *args, **kwargs): self._dualy_funcscale = None self._dualy_prevstate = None + def _get_axis_style_state(self, axis): + """ + Return the cached explicit style overrides for this axis. + """ + return getattr(self, f"_{axis}axis_style_state") + + def _merge_axis_style_state(self, axis, params): + """ + Merge the current explicit style overrides with the cached overrides. + """ + state = self._get_axis_style_state(axis).copy() + explicit_keys = set(params.get("_explicit_format_keys", ())) + for field, names in get_axis_style_fields(axis).items(): + if any(name in explicit_keys for name in names) and all( + params.get(name, None) is None for name in names + ): + state.pop(field, None) + continue + value = _not_none(*(params.get(name) for name in names)) + if value is not None: + state[field] = value + return state + + def _set_axis_style_state(self, axis, params): + """ + Cache the explicit style overrides for this axis. + """ + setattr( + self, f"_{axis}axis_style_state", self._merge_axis_style_state(axis, params) + ) + def _apply_axis_sharing(self): """ Enforce the "shared" axis labels and axis tick labels. If this is not @@ -1204,13 +1244,12 @@ def _format_axis(self, s: str, config: _AxisFormatConfig, fixticks: bool): self.margins(**{s: config.margin}) # Axis spine settings - # NOTE: This sets spine-specific color and linewidth settings. For - # non-specific settings _update_background is called in Axes.format() self._update_spines(s, loc=config.spineloc, bounds=config.bounds) - self._update_background( + self._update_frame( s, edgecolor=config.color, linewidth=config.linewidth, + tickcolor=config.tickcolor, tickwidth=tickwidth, tickwidthratio=config.tickwidthratio, ) @@ -1297,27 +1336,84 @@ def _resolve_axis_format(self, axis, params, rc_kw): Resolve formatting parameters for a single axis (x or y). """ p = params - - # Color resolution - color = p.get("color") - axis_color = _not_none(p.get(f"{axis}color"), color) + prev = self._merge_axis_style_state(axis, p) # Helper to get axis-specific or generic param def get(name): - return p.get(f"{axis}{name}") + return _not_none(p.get(f"{axis}{name}"), p.get(name)) + + # Color resolution + axis_color_arg = prev.get("color", None) + axis_color = _not_none( + axis_color_arg, + rc.find("axes.edgecolor", context=True), + rc["axes.edgecolor"], + ) + linewidth = _not_none( + prev.get("linewidth", None), + rc.find("axes.linewidth", context=True), + rc["axes.linewidth"], + ) # Resolve colors tickcolor = get("tickcolor") if "tick.color" not in rc_kw: - tickcolor = _not_none(tickcolor, axis_color) + tickcolor = _not_none( + prev.get("tickcolor", None), + axis_color_arg, + rc.find(f"{axis}tick.color", context=True), + rc[f"{axis}tick.color"], + ) ticklabelcolor = get("ticklabelcolor") if "tick.labelcolor" not in rc_kw: - ticklabelcolor = _not_none(ticklabelcolor, axis_color) + ticklabelcolor = _not_none( + prev.get("ticklabelcolor", None), + axis_color_arg, + ) labelcolor = get("labelcolor") if "label.color" not in rc_kw: - labelcolor = _not_none(labelcolor, axis_color) + labelcolor = _not_none( + prev.get("labelcolor", None), + axis_color_arg, + ) + + ticklen = _not_none( + get("ticklen"), + prev.get("ticklen", None), + rc.find("tick.len", context=True), + rc["tick.len"], + ) + ticklenratio = _not_none( + get("ticklenratio"), + prev.get("ticklenratio", None), + rc.find("tick.lenratio", context=True), + rc["tick.lenratio"], + ) + tickwidth = _not_none( + get("tickwidth"), + prev.get("tickwidth", None), + prev.get("linewidth", None), + rc.find("tick.width", context=True), + rc["tick.width"], + ) + tickwidthratio = _not_none( + get("tickwidthratio"), + prev.get("tickwidthratio", None), + rc.find("tick.widthratio", context=True), + rc["tick.widthratio"], + ) + ticklabelsize = prev.get("ticklabelsize", None) + ticklabelweight = prev.get("ticklabelweight", None) + labelsize = prev.get("labelsize", None) + labelweight = prev.get("labelweight", None) + grid = prev.get("grid", None) + gridminor = prev.get("gridminor", None) + gridcolor = prev.get("gridcolor", None) + rotation = prev.get("rotation", None) + ticklabelpad = prev.get("ticklabelpad", None) + labelpad = prev.get("labelpad", None) # Flexible keyword args margin = _not_none( @@ -1325,7 +1421,8 @@ def get(name): ) tickdir = _not_none( - get("tickdir"), rc.find(f"{axis}tick.direction", context=True) + prev.get("tickdir", None), + rc.find(f"{axis}tick.direction", context=True), ) locator = _not_none(get("locator"), p.get(f"{axis}ticks")) @@ -1345,31 +1442,32 @@ def get(name): tickminor = _not_none( tickminor, + prev.get("tickminor", None), tickminor_default, rc.find(f"{axis}tick.minor.visible", context=True), ) # Tick label dir logic - ticklabeldir = p.get("ticklabeldir") + ticklabeldir = prev.get("ticklabeldir", None) axis_ticklabeldir = _not_none(get("ticklabeldir"), ticklabeldir) tickdir = _not_none(tickdir, axis_ticklabeldir) # Spine locations loc = get("loc") - spineloc = get("spineloc") + spineloc = prev.get("spineloc", None) spineloc = _not_none(loc, spineloc) # Spine side inference side = self._get_spine_side(axis, spineloc) - tickloc = get("tickloc") + tickloc = prev.get("tickloc", None) if side is not None and side not in ("zero", "center", "both"): tickloc = _not_none(tickloc, side) # Infer other locations - ticklabelloc = get("ticklabelloc") - labelloc = get("labelloc") - offsetloc = get("offsetloc") + ticklabelloc = prev.get("ticklabelloc", None) + labelloc = prev.get("labelloc", None) + offsetloc = prev.get("offsetloc", None) if tickloc != "both": ticklabelloc = _not_none(ticklabelloc, tickloc) @@ -1396,16 +1494,42 @@ def get(name): val = p.get(f"{axis}max") case "color": val = axis_color + case "linewidth": + val = linewidth case "tickcolor": val = tickcolor + case "ticklen": + val = ticklen + case "ticklenratio": + val = ticklenratio + case "tickwidth": + val = tickwidth + case "tickwidthratio": + val = tickwidthratio case "ticklabelcolor": val = ticklabelcolor + case "ticklabelsize": + val = ticklabelsize + case "ticklabelweight": + val = ticklabelweight case "labelcolor": val = labelcolor + case "labelsize": + val = labelsize + case "labelweight": + val = labelweight case "margin": val = margin case "tickdir": val = tickdir + case "grid": + val = grid + case "gridminor": + val = gridminor + case "gridcolor": + val = gridcolor + case "rotation": + val = rotation case "locator": val = locator case "minorlocator": @@ -1416,6 +1540,8 @@ def get(name): val = tickminor case "ticklabeldir": val = axis_ticklabeldir + case "ticklabelpad": + val = ticklabelpad case "spineloc": val = spineloc case "tickloc": @@ -1426,6 +1552,8 @@ def get(name): val = labelloc case "offsetloc": val = offsetloc + case "labelpad": + val = labelpad case _: # Direct mapping (e.g. xlinewidth -> linewidth) val = get(field) @@ -1569,11 +1697,26 @@ def format( or `datetime.datetime` array as the x or y axis coordinate, the axis ticks and tick labels will be automatically formatted as dates. """ + explicit_format_keys = set(kwargs) + explicit_format_keys.update(kwargs.pop("_explicit_format_keys", ())) + signature_axis_kwargs, generic_axis_kwargs = pop_axis_format_kwargs( + kwargs, self._format_signatures[CartesianAxes] + ) + explicit_format_keys.update(signature_axis_kwargs) + explicit_format_keys.update(generic_axis_kwargs) rc_kw, rc_mode = _pop_rc(kwargs) + kwargs.update(signature_axis_kwargs) + kwargs.update(generic_axis_kwargs) + base_kwargs = kwargs.copy() + _pop_params(base_kwargs, self._format_signatures[CartesianAxes]) + for key in CARTESIAN_PARENT_FILTER_KEYS: + base_kwargs.pop(key, None) + with rc.context(rc_kw, mode=rc_mode): # Resolve parameters for x and y axes # We capture locals() to pass all named arguments to the helper params = locals() + params["_explicit_format_keys"] = explicit_format_keys params.update(kwargs) # Include any extras in kwargs x_config = self._resolve_axis_format("x", params, rc_kw) @@ -1582,6 +1725,8 @@ def format( # Format axes self._format_axis("x", x_config, fixticks=fixticks) self._format_axis("y", y_config, fixticks=fixticks) + self._set_axis_style_state("x", params) + self._set_axis_style_state("y", params) if rc.find("formatter.log", context=True): if ( @@ -1603,10 +1748,9 @@ def format( ): self._update_formatter("y", "log") - # Parent format method if aspect is not None: self.set_aspect(aspect) - super().format(rc_kw=rc_kw, rc_mode=rc_mode, **kwargs) + super().format(rc_kw=rc_kw, rc_mode=rc_mode, **base_kwargs) @docstring._snippet_manager def altx(self, **kwargs): @@ -1678,10 +1822,24 @@ def get_tightbbox(self, renderer, *args, **kwargs): return super().get_tightbbox(renderer, *args, **kwargs) +def _capture_explicit_format_keys(func): + """ + Preserve raw keyword names before Python binds them to the format signature. + """ + + @functools.wraps(func) + def wrapper(self, *args, **kwargs): + kwargs.setdefault("_explicit_format_keys", set(kwargs)) + return func(self, *args, **kwargs) + + return wrapper + + # tmp # Apply signature obfuscation after storing previous signature # NOTE: This is needed for __init__, altx, and alty CartesianAxes._format_signatures[CartesianAxes] = inspect.signature( CartesianAxes.format ) # noqa: E501 +CartesianAxes.format = _capture_explicit_format_keys(CartesianAxes.format) CartesianAxes.format = docstring._obfuscate_kwargs(CartesianAxes.format) diff --git a/ultraplot/axes/plot.py b/ultraplot/axes/plot.py index 39e55a896..e9078a1a4 100644 --- a/ultraplot/axes/plot.py +++ b/ultraplot/axes/plot.py @@ -6057,7 +6057,9 @@ def _apply_bar( kw = self._parse_cycle(n, **kw) # Adjust x or y coordinates for grouped and stacked bars w = _not_none(w, np.array([0.8])) # same as mpl but in *relative* units - b = _not_none(b, np.array([0.0])) # same as mpl + b = np.atleast_1d( + _not_none(b, np.array([0.0])) + ) # tolerate scalar `bottom`/`left` if not absolute_width: w = self._convert_bar_width(x, w) if stack: diff --git a/ultraplot/axes/polar.py b/ultraplot/axes/polar.py index bf62e010c..c56b2bc98 100644 --- a/ultraplot/axes/polar.py +++ b/ultraplot/axes/polar.py @@ -11,17 +11,30 @@ from typing_extensions import override import matplotlib.projections.polar as mpolar +import matplotlib.transforms as mtransforms import numpy as np +from matplotlib.font_manager import FontProperties from .. import constructor from .. import ticker as pticker from ..config import rc -from ..internals import ic # noqa: F401 -from ..internals import _not_none, _pop_rc, docstring +from ..internals import ( + _not_none, + _pop_rc, + docstring, + ic, # noqa: F401 +) from . import plot, shared __all__ = ["PolarAxes"] +# CurvedText sampling resolution along the label arc / spoke. +_POLAR_LABEL_NPOINTS = 50 +# Half-span (degrees) used when the label sits on a closed (full) circle. +_POLAR_LABEL_FULL_HALFSPAN_DEG = 15.0 +# Fraction of an open sector occupied by `thetalabel`; remainder is endpoint margin. +_POLAR_LABEL_SECTOR_FRAC = 0.8 + # Format docstring _format_docstring = """ @@ -71,7 +84,9 @@ thetaminorlocator_kw, rminorlocator_kw As for `thetalocator_kw`, `rlocator_kw`, but for the minor locator. rlabelpos : float, optional - The azimuth at which radial coordinates are labeled. + The azimuth at which radial coordinates are labeled. Also used as the + spoke angle for ``rlabel`` when you want an explicit radial-label + position. thetaformatter, rformatter : formatter-spec, optional Used to determine the azimuthal and radial label format. Passed to the `~ultraplot.constructor.Formatter` constructor. @@ -82,6 +97,32 @@ thetaformatter_kw, rformatter_kw : dict-like, optional The azimuthal and radial label formatter settings. Passed to `~ultraplot.constructor.Formatter`. +thetalabel, rlabel : str, optional + Polar-aware axis labels rendered via `~ultraplot.text.CurvedText`. + ``thetalabel`` follows the outer arc just beyond ``r=rmax``. + ``rlabel`` follows a radial spoke, centered between ``rmin`` and + ``rmax``. On a full circle it uses ``get_rlabel_position()`` unless + ``rlabelpos`` is explicit; on a sector it uses the spoke selected by + ``rlabelloc`` unless ``rlabelpos`` is explicit. Both labels include a + built-in tick-clearance offset, and ``labelpad`` adds extra padding in + points on top of that offset. Pass ``""`` to clear a previously set + label. +thetalabelloc : float, optional + Center theta angle (in degrees) for ``thetalabel``. Defaults to the + midpoint of the directed ``thetalim`` interval (or ``0`` for a full + circle). +rlabelloc : {'right', 'left'}, default: 'right' + Where to place ``rlabel``. When the spoke angle is fixed by a full + circle or by explicit ``rlabelpos``, ``rlabelloc`` selects the + perpendicular side of that spoke and ``'left'`` flips the default + side. On a sector with no explicit ``rlabelpos``, ``'right'`` + (default) anchors to ``thetamin`` and ``'left'`` anchors to + ``thetamax``; the label is then offset outward from the sector. +thetalabel_kw, rlabel_kw : dict-like, optional + Additional `~ultraplot.text.CurvedText` settings for the polar-aware + labels (e.g. ``border``, ``bbox``, or rendering hints like + ``min_advance``). See also `labelpad`, `labelcolor`, `labelsize`, + and `labelweight`. color : color-spec, default: :rc:`meta.color` Color for the axes edge. Propagates to `labelcolor` unless specified otherwise (similar to :func:`~ultraplot.axes.CartesianAxes.format`). @@ -89,6 +130,8 @@ Color for the gridline labels. labelpad, gridlabelpad : unit-spec, default: :rc:`grid.labelpad` The padding between the axes edge and the radial and azimuthal labels. + For ``thetalabel`` and ``rlabel``, this is added on top of the built-in + tick-clearance offset. %(units.pt)s labelsize, gridlabelsize : unit-spec or str, default: :rc:`grid.labelsize` Font size for the gridline labels. @@ -143,6 +186,8 @@ def __init__(self, *args, **kwargs): self.yaxis.isDefault_majfmt = True for axis in (self.xaxis, self.yaxis): axis.set_tick_params(which="both", size=0) + self._thetalabel_artist = None + self._rlabel_artist = None @override def _apply_axis_sharing(self): @@ -212,6 +257,248 @@ def _update_locators( else: axis.set_minor_locator(loc) + def _get_directed_thetalim(self): + """Return the directed theta interval in degrees from the raw x-limits.""" + thetamin, thetamax = np.rad2deg(self.get_xlim()) + return float(thetamin), float(thetamax) + + @staticmethod + def _is_full_circle_thetalim(thetamin, thetamax): + """Return whether the directed theta interval spans a full circle.""" + return np.isclose((thetamax - thetamin) % 360.0, 0.0) + + def _polar_tick_clearance_in(self, axis): + """Tick mark + tick pad + ~font height(s), in inches.""" + axis_obj = getattr(self, f"{axis}axis") + size_pt = rc[f"{axis}tick.major.size"] + pad_pt = rc[f"{axis}tick.major.pad"] + label_pt = FontProperties(size=rc[f"{axis}tick.labelsize"]).get_size_in_points() + ticks = axis_obj.get_major_ticks() + if ticks: + tick = ticks[0] + size_pt = max( + tick.tick1line.get_markersize(), tick.tick2line.get_markersize() + ) + pad_pt = ( + tick.get_pad() + if hasattr(tick, "get_pad") + else getattr(tick, "_pad", pad_pt) + ) + label_pt = max(tick.label1.get_size(), tick.label2.get_size(), label_pt) + labels = axis_obj.get_ticklabels() + if labels: + label_pt = max(float(label.get_size()) for label in labels) + n = 2 if axis == "x" else 1.5 + return (size_pt + pad_pt + n * label_pt) / 72.0 + + def _build_thetalabel_curve(self, loc, total_pad_in): + """ + Curve along the outer arc at r = rmax + delta_r (data coords). The + radial offset is computed in data space so clearance is angle- + independent — figure-space ScaledTranslation undershoots when the + outward direction points toward a tight bbox edge (e.g. 180–230°). + """ + thetamin, thetamax = self._get_directed_thetalim() + span = (thetamax - thetamin) % 360.0 + is_full_circle = self._is_full_circle_thetalim(thetamin, thetamax) + if is_full_circle: + mid = 0.0 if loc is None else float(loc) + half_span = _POLAR_LABEL_FULL_HALFSPAN_DEG + elif loc is None: + mid = thetamin + 0.5 * span + half_span = 0.5 * span * _POLAR_LABEL_SECTOR_FRAC + else: + # Explicit thetalabelloc on a sector: localize the label around + # the requested angle instead of spanning the whole sector arc. + mid = float(loc) + half_span = _POLAR_LABEL_FULL_HALFSPAN_DEG + x = np.deg2rad( + np.linspace(mid - half_span, mid + half_span, _POLAR_LABEL_NPOINTS) + ) + rmax_val = self.get_rmax() + p0 = self.transData.transform(np.array([0.0, rmax_val])) + p1 = self.transData.transform(np.array([0.0, rmax_val + 1.0])) + px_per_r = float(np.linalg.norm(np.asarray(p1) - np.asarray(p0))) + delta_r = total_pad_in * self.figure.dpi / px_per_r if px_per_r > 1e-6 else 0.0 + y = np.full_like(x, rmax_val + delta_r) + return x, y, self.transData + + def _get_sector_rlabel_outside_sign(self, rpos): + """Return the sign that offsets a sector rlabel outside the wedge.""" + thetamin, thetamax = self._get_directed_thetalim() + span = (thetamax - thetamin) % 360.0 + inside_step = min(1.0, 0.25 * span) + inside_theta = ( + rpos - inside_step + if np.isclose((rpos - thetamax) % 360.0, 0.0) + else rpos + inside_step + ) + rmid = 0.5 * (self.get_rmin() + self.get_rmax()) + edge = self.transData.transform(np.array([np.deg2rad(rpos), rmid])) + inside = self.transData.transform(np.array([np.deg2rad(inside_theta), rmid])) + normal = self._get_rlabel_right_normal(np.deg2rad(rpos)) + return ( + -1.0 if np.dot(np.asarray(inside) - np.asarray(edge), normal) > 0.0 else 1.0 + ) + + def _resolve_rlabel_geometry(self, loc, rlabelpos): + """ + Resolve ``(rpos, sign)`` for the radial label given ``rlabelloc`` and + an optional explicit ``rlabelpos``. On a full circle, ``loc`` flips + the perpendicular offset; on a sector with no explicit ``rlabelpos``, + ``loc`` instead selects the spoke (``thetamin`` vs ``thetamax``) and + the perpendicular sign is auto-chosen to fall outside the wedge. + """ + if loc not in (None, "left", "right"): + raise ValueError(f"rlabelloc must be 'right' or 'left'; got {loc!r}") + thetamin, thetamax = self._get_directed_thetalim() + is_full_circle = self._is_full_circle_thetalim(thetamin, thetamax) + if rlabelpos is not None: + rpos = float(rlabelpos) + if is_full_circle: + base_sign = 1.0 + else: + base_sign = -1.0 if np.isclose((rpos - thetamax) % 360.0, 0.0) else 1.0 + elif is_full_circle: + rpos = self.get_rlabel_position() + base_sign = 1.0 + else: + rpos = thetamax if loc == "left" else thetamin + base_sign = self._get_sector_rlabel_outside_sign(rpos) + flip = loc == "left" and (is_full_circle or rlabelpos is not None) + sign = -base_sign if flip else base_sign + return rpos, sign + + def _get_rlabel_right_normal(self, rad): + """Return the display-space right normal for the radial spoke at ``rad``.""" + rmin, rmax = self.get_rmin(), self.get_rmax() + p0 = self.transData.transform(np.array([rad, rmin])) + p1 = self.transData.transform(np.array([rad, rmax])) + tangent = np.asarray(p1, dtype=float) - np.asarray(p0, dtype=float) + norm = np.linalg.norm(tangent) + if norm <= 1e-6: + return np.array([np.sin(rad), -np.cos(rad)]) + tangent /= norm + return np.array([tangent[1], -tangent[0]]) + + def _build_rlabel_curve(self, loc, pad_in, rlabelpos): + """ + Curve along the radial spoke from rmin to rmax with a perpendicular + ScaledTranslation offset so the label clears the r-tick labels. + """ + rpos, sign = self._resolve_rlabel_geometry(loc, rlabelpos) + rad = np.deg2rad(rpos) + x = np.full(_POLAR_LABEL_NPOINTS, rad) + y = np.linspace(self.get_rmin(), self.get_rmax(), _POLAR_LABEL_NPOINTS) + normal = self._get_rlabel_right_normal(rad) + tick_clearance_in = self._polar_tick_clearance_in("y") + total_pad_in = pad_in + tick_clearance_in + dx_in, dy_in = sign * total_pad_in * normal + transform = self.transData + mtransforms.ScaledTranslation( + dx_in, dy_in, self.figure.dpi_scale_trans + ) + return x, y, transform + + def _refresh_polar_label_geometry(self, kind): + """Refresh the stored curve and transform for an existing polar label.""" + attr = f"_{kind}label_artist" + artist = getattr(self, attr, None) + if artist is None: + return + state = getattr(self, f"_{kind}label_state", None) or {} + loc = state.get("loc") + labelpad = state.get("labelpad") + pad_in = _not_none(labelpad, rc["grid.labelpad"]) / 72.0 + axis = "x" if kind == "theta" else "y" + total_pad_in = pad_in + self._polar_tick_clearance_in(axis) + if kind == "theta": + x, y, transform = self._build_thetalabel_curve(loc, total_pad_in) + else: + x, y, transform = self._build_rlabel_curve( + loc, pad_in, state.get("rlabelpos") + ) + artist.set_curve(x, y) + artist.set_transform(transform) + + def _update_polar_label( + self, kind, text, *, loc=None, labelpad=None, rlabelpos=None, **kwargs + ): + """ + Apply a polar-aware axis label along the outer arc (`thetalabel`) or + along the radial spoke (`rlabel`), both via CurvedText. + """ + # NOTE: Critical to test whether arguments are None or else we'd + # overwrite styling and clear text on every format() call. + kwargs = rc._get_label_props(**kwargs) + kwargs.pop("labelpad", None) # injected by _get_label_props; not a Text prop + attr = f"_{kind}label_artist" + artist = getattr(self, attr, None) + # Sticky state: previously-applied loc/labelpad/rlabelpos so a generic + # format() call (e.g. ``axs.format(suptitle=...)``) doesn't reset them + # back to the default when the user didn't pass them again. + state_attr = f"_{kind}label_state" + state = getattr(self, state_attr, None) or {} + nothing_to_do = ( + text is None + and loc is None + and labelpad is None + and rlabelpos is None + and all(v is None for v in kwargs.values()) + ) + if artist is None and nothing_to_do: + return + + if loc is not None: + state["loc"] = loc + if labelpad is not None: + state["labelpad"] = labelpad + if kind == "r" and rlabelpos is not None: + state["rlabelpos"] = rlabelpos + setattr(self, state_attr, state) + loc = state.get("loc") + labelpad = state.get("labelpad") + rlabelpos = state.get("rlabelpos") if kind == "r" else None + + pad_in = _not_none(labelpad, rc["grid.labelpad"]) / 72.0 + style_props = {k: v for k, v in kwargs.items() if v is not None} + if kind == "theta": + total_pad_in = pad_in + self._polar_tick_clearance_in("x") + x, y, transform = self._build_thetalabel_curve(loc, total_pad_in) + else: + x, y, transform = self._build_rlabel_curve(loc, pad_in, rlabelpos) + + if artist is None: + artist = self.text( + x, + y, + text or "", + transform=transform, + ha="center", + va="center", + clip_on=False, + **style_props, + ) + setattr(self, attr, artist) + return + artist.set_curve(x, y) + artist.set_transform(transform) + if text is not None: + artist.set_text(text) + if style_props: + artist._apply_label_props(style_props) + + @override + def draw(self, renderer=None, *args, **kwargs): + self._refresh_polar_label_geometry("theta") + self._refresh_polar_label_geometry("r") + super().draw(renderer, *args, **kwargs) + + @override + def get_tightbbox(self, renderer, *args, **kwargs): + self._refresh_polar_label_geometry("theta") + self._refresh_polar_label_geometry("r") + return super().get_tightbbox(renderer, *args, **kwargs) + @docstring._snippet_manager def format( self, @@ -256,6 +543,12 @@ def format( labelsize=None, labelcolor=None, labelweight=None, + thetalabel=None, + rlabel=None, + thetalabelloc=None, + rlabelloc=None, + thetalabel_kw=None, + rlabel_kw=None, **kwargs, ): """ @@ -282,6 +575,34 @@ def format( rc_kw, rc_mode = _pop_rc(kwargs) labelcolor = _not_none(labelcolor, kwargs.get("color", None)) with rc.context(rc_kw, mode=rc_mode): + edgecolor = _not_none( + kwargs.get("color", None), + rc.find("axes.edgecolor", context=True), + rc["axes.edgecolor"], + ) + linewidth = _not_none( + kwargs.get("linewidth", None), + rc.find("axes.linewidth", context=True), + rc["axes.linewidth"], + ) + tickcolor = _not_none( + kwargs.get("tickcolor", None), + kwargs.get("color", None), + rc.find("xtick.color", context=True), + rc["xtick.color"], + ) + tickwidth = _not_none( + kwargs.get("tickwidth", None), + kwargs.get("linewidth", None) and linewidth, + rc.find("tick.width", context=True), + rc["tick.width"], + ) + tickwidthratio = _not_none( + kwargs.get("tickwidthratio", None), + rc.find("tick.widthratio", context=True), + rc["tick.widthratio"], + ) + # Not mutable default args thetalocator_kw = thetalocator_kw or {} thetaminorlocator_kw = thetaminorlocator_kw or {} @@ -320,6 +641,23 @@ def format( if thetadir is not None: self.set_theta_direction(thetadir) + # Polar frame styling used to come from the shared background helper. + # Apply it explicitly now that patch and frame styling are separated. + self._update_frame( + "x", + edgecolor=edgecolor, + linewidth=linewidth, + tickcolor=tickcolor, + tickwidth=tickwidth, + tickwidthratio=tickwidthratio, + ) + self._update_frame( + "y", + tickcolor=tickcolor, + tickwidth=tickwidth, + tickwidthratio=tickwidthratio, + ) + # Loop over axes for ( x, @@ -382,6 +720,23 @@ def format( x, formatter=formatter, formatter_kw=formatter_kw ) + # Polar-aware axis labels (rendered along the arc / radial spoke) + for kind, text, loc, label_kw in ( + ("theta", thetalabel, thetalabelloc, thetalabel_kw), + ("r", rlabel, rlabelloc, rlabel_kw), + ): + kw = dict( + loc=loc, + labelpad=labelpad, + color=labelcolor, + size=labelsize, + weight=labelweight, + ) + if kind == "r": + kw["rlabelpos"] = rlabelpos + kw.update(label_kw or {}) + self._update_polar_label(kind, text, **kw) + # Parent format method super().format(rc_kw=rc_kw, rc_mode=rc_mode, **kwargs) diff --git a/ultraplot/axes/shared.py b/ultraplot/axes/shared.py index 6b66c6219..dadc33b77 100644 --- a/ultraplot/axes/shared.py +++ b/ultraplot/axes/shared.py @@ -40,41 +40,55 @@ def _min_max_lim(key, min_=None, max_=None, lim=None): max_ = _not_none(**{f"{key}max": max_, f"{key}lim_1": lim[1]}) return min_, max_ - def _update_background(self, x=None, tickwidth=None, tickwidthratio=None, **kwargs): + def _update_background(self, **kwargs): """ - Update the background patch and spines. + Update the background patch. """ - # Update the background patch kw_face, kw_edge = rc._get_background_props(**kwargs) self.patch.update(kw_face) - if x is None: - opts = self.spines - elif x == "x": - opts = ("bottom", "top", "inner", "polar") - else: - opts = ("left", "right", "start", "end") - for opt in opts: - self.spines.get(opt, {}).update(kw_edge) + return kw_face, kw_edge - # Update the tick colors - axis = "both" if x is None else x - x = _not_none(x, "x") - obj = getattr(self, x + "axis") - edgecolor = kw_edge.get("edgecolor", None) + def _update_frame( + self, + x, + *, + edgecolor=None, + linewidth=None, + tickcolor=None, + tickwidth=None, + tickwidthratio=None, + ): + """ + Update the axis frame, including spines and tick line appearance. + """ + opts = ( + ("bottom", "top", "inner", "polar") + if x == "x" + else ( + "left", + "right", + "start", + "end", + ) + ) + kw_edge = {"capstyle": "projecting"} if edgecolor is not None: - self.tick_params(axis=axis, which="both", color=edgecolor) + kw_edge["edgecolor"] = edgecolor + if linewidth is not None: + kw_edge["linewidth"] = linewidth + if len(kw_edge) > 1: + for opt in opts: + self.spines.get(opt, {}).update(kw_edge) + + obj = getattr(self, x + "axis") + if tickcolor is None: + tickcolor = edgecolor + if tickcolor is not None: + self.tick_params(axis=x, which="both", color=tickcolor) # Update the tick widths - # NOTE: Only use 'linewidth' if it was explicitly passed. Do not - # include 'linewidth' inferred from rc['axes.linewidth'] setting. kwmajor = getattr(obj, "_major_tick_kw", {}) # graceful fallback if API changes kwminor = getattr(obj, "_minor_tick_kw", {}) - if "linewidth" in kwargs: - tickwidth = _not_none(tickwidth, kwargs["linewidth"]) - tickwidth = _not_none(tickwidth, rc.find("tick.width", context=True)) - tickwidthratio = _not_none( - tickwidthratio, rc.find("tick.widthratio", context=True) - ) # noqa: E501 tickwidth_prev = kwmajor.get("width", rc[x + "tick.major.width"]) if tickwidth_prev == 0: tickwidthratio_prev = rc["tick.widthratio"] # no other way of knowing @@ -92,7 +106,7 @@ def _update_background(self, x=None, tickwidth=None, tickwidthratio=None, **kwar elif which == "minor": tickwidthratio = _not_none(tickwidthratio, tickwidthratio_prev) kwticks["width"] *= tickwidthratio - self.tick_params(axis=axis, which=which, **kwticks) + self.tick_params(axis=x, which=which, **kwticks) def _update_ticks( self, diff --git a/ultraplot/colors.py b/ultraplot/colors.py index becca6851..6d1c18cc4 100644 --- a/ultraplot/colors.py +++ b/ultraplot/colors.py @@ -27,7 +27,6 @@ import matplotlib as mpl import matplotlib.cm as mcm import matplotlib.colors as mcolors -import matplotlib.colors as mcolors import numpy as np import numpy.ma as ma diff --git a/ultraplot/config.py b/ultraplot/config.py index 4ac429af7..b3270a492 100644 --- a/ultraplot/config.py +++ b/ultraplot/config.py @@ -63,6 +63,50 @@ # Constants COLORS_KEEP = ("red", "green", "blue", "cyan", "yellow", "magenta", "white", "black") +_ULTRAPLOT_STYLES = { + "poster": { + "font.size": 14, + "axes.titlesize": 18, + "axes.labelsize": 16, + "xtick.labelsize": 13, + "ytick.labelsize": 13, + "legend.fontsize": 13, + "figure.titlesize": 20, + "lines.linewidth": 2.0, + "lines.markersize": 6, + "figure.facecolor": "none", + "savefig.facecolor": "none", + "savefig.edgecolor": "none", + }, + "dark_background": { + "figure.facecolor": "#000000", + "figure.edgecolor": "#000000", + "axes.facecolor": "#000000", + "axes.edgecolor": "#cbd5e1", + "axes.labelcolor": "#f8fafc", + "text.color": "#f8fafc", + "xtick.color": "#cbd5e1", + "ytick.color": "#cbd5e1", + "grid.color": "#475569", + "grid.alpha": 0.35, + "legend.facecolor": "#000000", + "legend.edgecolor": "#475569", + "savefig.facecolor": "#000000", + "savefig.edgecolor": "#000000", + "axes.prop_cycle": cycler.cycler( + color=( + "#60a5fa", + "#f59e0b", + "#34d399", + "#f472b6", + "#a78bfa", + "#f87171", + ) + ), + }, +} +_ULTRAPLOT_STYLES["dark"] = _ULTRAPLOT_STYLES["dark_background"] + # Configurator docstrings _rc_docstring = """ local : bool, default: True @@ -305,6 +349,7 @@ def _get_style_dict(style, filter=True): # copying the entire rcParams dict we just track the keys that were changed. style_aliases = { "538": "fivethirtyeight", + "dark": "dark_background", "mpl20": "default", "mpl15": "classic", "original": mpl.matplotlib_fname(), @@ -333,7 +378,9 @@ def _get_style_dict(style, filter=True): kw = style elif isinstance(style, str): style = style_aliases.get(style, style) - if style in mstyle.library: + if style in _ULTRAPLOT_STYLES: + kw = _ULTRAPLOT_STYLES[style] + elif style in mstyle.library: kw = mstyle.library[style] else: try: @@ -1249,6 +1296,13 @@ def _get_tickline_props(self, axis=None, which="major", native=True, rebuild=Fal context = not rebuild and (native or self._context_mode == 2) kwticks = self.category(f"{axis}tick.{which}", context=context) kwticks.pop("visible", None) + + # NOTE: We pop visibility properties from the styling dictionary so that + # stylistic updates (like applying a dark_background theme) do not override + # the tick visibility logic strictly managed by ax._update_locs() and alternate axes. + for key in ("bottom", "top", "left", "right"): + kwticks.pop(key, None) + for key in ("color", "direction"): value = self.find(f"{axis}tick.{key}", context=context) if value is not None: @@ -1260,15 +1314,19 @@ def _get_ticklabel_props(self, axis=None, native=True, rebuild=False): Return the tick label properties, optionally filtering the output dictionary based on the context. """ - # NOTE: 'tick.label' properties are now synonyms of 'grid.label' properties + # Geographic gridline labels use the ultraplot-only grid.label* settings, + # while native matplotlib tick labels use x/y tick rcParams. sprefix = axis or "" cprefix = sprefix if _version_mpl >= "3.4" else "" # new settings context = not rebuild and (native or self._context_mode == 2) + color_key = f"{cprefix}tick.labelcolor" if native else "grid.labelcolor" + size_key = f"{sprefix}tick.labelsize" if native else "grid.labelsize" + weight_key = "tick.labelweight" if native else "grid.labelweight" kwtext = self.fill( { - "color": f"{cprefix}tick.labelcolor", # native setting sometimes avail - "size": f"{sprefix}tick.labelsize", # native setting always avail - "weight": "tick.labelweight", # native setting never avail + "color": color_key, # native setting sometimes avail + "size": size_key, + "weight": weight_key, # native setting never avail "family": "font.family", # apply manually }, context=context, diff --git a/ultraplot/figure.py b/ultraplot/figure.py index 334c61a44..028ed3d13 100644 --- a/ultraplot/figure.py +++ b/ultraplot/figure.py @@ -29,8 +29,10 @@ from typing_extensions import override from . import axes as paxes +from .axes._formatting import pop_axis_format_kwargs from . import constructor from . import gridspec as pgridspec +from . import legend as plegend from .config import rc, rc_matplotlib from .internals import ( _not_none, @@ -411,6 +413,124 @@ ) # noqa: E501 +# Figure semantic legend helpers +_figure_semantic_legend_common_docstring = """ +**legend_kwargs + Placement and legend styling keywords forwarded to + `~ultraplot.figure.Figure.legend` when ``add=True``. This includes figure legend + placement keywords like ``loc=``, ``ref=``, ``ax=``, ``rows=``, ``cols=``, and + ``span=``. Pass ``add=False`` to return ``(handles, labels)`` without drawing. +""" +docstring._snippet_manager["figure.semantic_legend_common"] = ( + _figure_semantic_legend_common_docstring +) + +_figure_entrylegend_docstring = """ +Build generic semantic legend entries and optionally add a figure legend. + +Parameters +---------- +entries + Entry specifications as handles, style dictionaries, or ``(label, spec)`` + pairs. + +Other parameters +---------------- +%(figure.semantic_legend_common)s + +Notes +----- +Handle generation currently reuses the semantic legend builder used by +`~ultraplot.axes.Axes.entrylegend`, then routes the final draw step through +`~ultraplot.figure.Figure.legend`. +""" +docstring._snippet_manager["figure.entrylegend"] = _figure_entrylegend_docstring + +_figure_catlegend_docstring = """ +Build categorical legend entries and optionally add a figure legend. + +Parameters +---------- +categories + Category labels used to generate legend handles. + +Other parameters +---------------- +%(figure.semantic_legend_common)s + +Notes +----- +Handle generation currently reuses the semantic legend builder used by +`~ultraplot.axes.Axes.catlegend`, then routes the final draw step through +`~ultraplot.figure.Figure.legend`. +""" +docstring._snippet_manager["figure.catlegend"] = _figure_catlegend_docstring + +_figure_sizelegend_docstring = """ +Build size legend entries and optionally add a figure legend. + +Parameters +---------- +levels + Numeric levels used to generate marker-size entries. + +Other parameters +---------------- +%(figure.semantic_legend_common)s + +Notes +----- +Handle generation currently reuses the semantic legend builder used by +`~ultraplot.axes.Axes.sizelegend`, then routes the final draw step through +`~ultraplot.figure.Figure.legend`. + +Pass ``labels=[...]`` or ``labels={level: label}`` to override the generated labels. +""" +docstring._snippet_manager["figure.sizelegend"] = _figure_sizelegend_docstring + +_figure_numlegend_docstring = """ +Build numeric-color legend entries and optionally add a figure legend. + +Parameters +---------- +levels + Numeric levels or number of levels. + +Other parameters +---------------- +%(figure.semantic_legend_common)s + +Notes +----- +Handle generation currently reuses the semantic legend builder used by +`~ultraplot.axes.Axes.numlegend`, then routes the final draw step through +`~ultraplot.figure.Figure.legend`. +""" +docstring._snippet_manager["figure.numlegend"] = _figure_numlegend_docstring + +_figure_geolegend_docstring = """ +Build geometry legend entries and optionally add a figure legend. + +Parameters +---------- +entries + Geometry entries (mapping, ``(label, geometry)`` pairs, or geometries). +labels + Optional labels for geometry sequences. + +Other parameters +---------------- +%(figure.semantic_legend_common)s + +Notes +----- +Handle generation currently reuses the semantic legend builder used by +`~ultraplot.axes.Axes.geolegend`, then routes the final draw step through +`~ultraplot.figure.Figure.legend`. +""" +docstring._snippet_manager["figure.geolegend"] = _figure_geolegend_docstring + + # Save docstring _save_docstring = """ Save the figure. @@ -496,6 +616,12 @@ def _canvas_preprocess(self, *args, **kwargs): skip_autolayout = getattr(fig, "_skip_autolayout", False) layout_dirty = getattr(fig, "_layout_dirty", False) + saving_frame_count = getattr(fig, "_saving_frame_count", 0) + lock_tight_during_save = ( + getattr(self, "_is_saving", False) + and saving_frame_count > 0 + and getattr(fig, "_tight_active", False) + ) if ( skip_autolayout and getattr(fig, "_layout_initialized", False) @@ -514,14 +640,20 @@ def _canvas_preprocess(self, *args, **kwargs): with ctx1, ctx2, ctx3: needs_post_layout = False if not fig._layout_initialized or layout_dirty: - fig.auto_layout() + fig.auto_layout(tight=False if lock_tight_during_save else None) fig._layout_initialized = True fig._layout_dirty = False - needs_post_layout = _needs_post_tight_layout(fig) + needs_post_layout = ( + not lock_tight_during_save and _needs_post_tight_layout(fig) + ) result = func(self, *args, **kwargs) if needs_post_layout: fig.auto_layout() result = func(self, *args, **kwargs) + if method == "print_figure" and getattr(self, "_is_saving", False): + fig._saving_frame_count = saving_frame_count + 1 + elif not getattr(self, "_is_saving", False): + fig._saving_frame_count = 0 return result # Add preprocessor @@ -1325,6 +1457,25 @@ def _share_ticklabels(self, *, axis: str) -> None: # Process each group independently for _, group_axes in groups.items(): + # Singleton groups can still need border masking reapplied for + # supported axes (e.g. GeoAxes split by guides), but unsupported + # singleton groups like a single PolarAxes should not warn. + main_axes = [ + axi for axi in group_axes if not getattr(axi, "_panel_side", None) + ] + supported_main_axes = any( + isinstance( + axi, (paxes.CartesianAxes, paxes._CartopyAxes, paxes._BasemapAxes) + ) + for axi in main_axes + ) + if len(group_axes) < 2 and not supported_main_axes: + continue + if all( + self._effective_share_level(axi, axis, sides) < 3 for axi in group_axes + ): + continue + # Build baseline from MAIN axes only (exclude panels) baseline, skip_group = self._compute_baseline_tick_state( group_axes, axis, label_keys @@ -3052,6 +3203,254 @@ def _update_super_title(self, title, **kwargs): if title is not None: self._suptitle.set_text(title) + @staticmethod + def _iter_semantic_legend_axes(candidate): + """ + Yield axes objects from nested axis containers. + """ + if candidate is None or isinstance(candidate, str): + return + if isinstance(candidate, maxes.Axes): + yield candidate + return + if np.iterable(candidate): + for item in candidate: + yield from Figure._iter_semantic_legend_axes(item) + + def _semantic_legend_axes(self, ax=None, ref=None): + """ + Pick an axes instance for semantic legend handle generation. + """ + for candidate in (ax, ref, self.axes): + for axis in self._iter_semantic_legend_axes(candidate): + return axis + raise RuntimeError( + "Figure semantic legend helpers require an existing axes. " + "Create an axes first or pass ax=... or ref=...." + ) + + @docstring._snippet_manager + def entrylegend( + self, + entries, + *, + line=None, + marker=None, + color=None, + linestyle=None, + linewidth=None, + markersize=None, + alpha=None, + markeredgecolor=None, + markeredgewidth=None, + markerfacecolor=None, + handle_kw=None, + add=True, + **legend_kwargs, + ): + """ + %(figure.entrylegend)s + """ + axes = self._semantic_legend_axes( + ax=legend_kwargs.get("ax"), ref=legend_kwargs.get("ref") + ) + handles, labels = plegend.UltraLegend(axes).entrylegend( + entries, + line=line, + marker=marker, + color=color, + linestyle=linestyle, + linewidth=linewidth, + markersize=markersize, + alpha=alpha, + markeredgecolor=markeredgecolor, + markeredgewidth=markeredgewidth, + markerfacecolor=markerfacecolor, + handle_kw=handle_kw, + add=False, + ) + if not add: + return handles, labels + return self.legend(handles, labels, **legend_kwargs) + + @docstring._snippet_manager + def catlegend( + self, + categories, + *, + colors=None, + markers=None, + line=None, + linestyle=None, + linewidth=None, + markersize=None, + alpha=None, + markeredgecolor=None, + markeredgewidth=None, + markerfacecolor=None, + handle_kw=None, + add=True, + **legend_kwargs, + ): + """ + %(figure.catlegend)s + """ + axes = self._semantic_legend_axes( + ax=legend_kwargs.get("ax"), ref=legend_kwargs.get("ref") + ) + handles, labels = plegend.UltraLegend(axes).catlegend( + categories, + colors=colors, + markers=markers, + line=line, + linestyle=linestyle, + linewidth=linewidth, + markersize=markersize, + alpha=alpha, + markeredgecolor=markeredgecolor, + markeredgewidth=markeredgewidth, + markerfacecolor=markerfacecolor, + handle_kw=handle_kw, + add=False, + ) + if not add: + return handles, labels + return self.legend(handles, labels, **legend_kwargs) + + @docstring._snippet_manager + def sizelegend( + self, + levels, + *, + labels=None, + color=None, + marker=None, + area=None, + scale=None, + minsize=None, + fmt=None, + alpha=None, + markeredgecolor=None, + markeredgewidth=None, + markerfacecolor=None, + handle_kw=None, + add=True, + **legend_kwargs, + ): + """ + %(figure.sizelegend)s + """ + axes = self._semantic_legend_axes( + ax=legend_kwargs.get("ax"), ref=legend_kwargs.get("ref") + ) + handles, labels = plegend.UltraLegend(axes).sizelegend( + levels, + labels=labels, + color=color, + marker=marker, + area=area, + scale=scale, + minsize=minsize, + fmt=fmt, + alpha=alpha, + markeredgecolor=markeredgecolor, + markeredgewidth=markeredgewidth, + markerfacecolor=markerfacecolor, + handle_kw=handle_kw, + add=False, + ) + if not add: + return handles, labels + return self.legend(handles, labels, **legend_kwargs) + + @docstring._snippet_manager + def numlegend( + self, + levels=None, + *, + vmin=None, + vmax=None, + n=None, + cmap=None, + norm=None, + fmt=None, + facecolor=None, + edgecolor=None, + linewidth=None, + linestyle=None, + alpha=None, + handle_kw=None, + add=True, + **legend_kwargs, + ): + """ + %(figure.numlegend)s + """ + axes = self._semantic_legend_axes( + ax=legend_kwargs.get("ax"), ref=legend_kwargs.get("ref") + ) + handles, labels = plegend.UltraLegend(axes).numlegend( + levels=levels, + vmin=vmin, + vmax=vmax, + n=n, + cmap=cmap, + norm=norm, + fmt=fmt, + facecolor=facecolor, + edgecolor=edgecolor, + linewidth=linewidth, + linestyle=linestyle, + alpha=alpha, + handle_kw=handle_kw, + add=False, + ) + if not add: + return handles, labels + return self.legend(handles, labels, **legend_kwargs) + + @docstring._snippet_manager + def geolegend( + self, + entries, + labels=None, + *, + country_reso=None, + country_territories=None, + country_proj=None, + handlesize=None, + facecolor=None, + edgecolor=None, + linewidth=None, + alpha=None, + fill=None, + add=True, + **legend_kwargs, + ): + """ + %(figure.geolegend)s + """ + axes = self._semantic_legend_axes( + ax=legend_kwargs.get("ax"), ref=legend_kwargs.get("ref") + ) + handles, labels = plegend.UltraLegend(axes).geolegend( + entries, + labels=labels, + country_reso=country_reso, + country_territories=country_territories, + country_proj=country_proj, + handlesize=handlesize, + facecolor=facecolor, + edgecolor=edgecolor, + linewidth=linewidth, + alpha=alpha, + fill=fill, + add=False, + ) + if not add: + return handles, labels + return self.legend(handles, labels, **legend_kwargs) + @_clear_border_cache @docstring._concatenate_inherited @docstring._snippet_manager @@ -3225,7 +3624,14 @@ def format( # Initiate context block axs = axs or self._subplot_dict.values() skip_axes = kwargs.pop("skip_axes", False) # internal keyword arg + explicit_format_keys = set(kwargs) + signature_axis_kwargs, generic_axis_kwargs = pop_axis_format_kwargs( + kwargs, *paxes.Axes._format_signatures.values() + ) + explicit_format_keys.update(signature_axis_kwargs) + explicit_format_keys.update(generic_axis_kwargs) rc_kw, rc_mode = _pop_rc(kwargs) + kwargs.update(signature_axis_kwargs) with rc.context(rc_kw, mode=rc_mode): # Update background patch kw = rc.fill({"facecolor": "figure.facecolor"}, context=True) @@ -3277,11 +3683,15 @@ def format( if skip_axes: # avoid recursion return - # Remove all keywords that are not in the allowed signature parameters + # Collect each class's matching kwargs without popping, then drop the union — + # shared params (e.g. xlabel/ylabel, accepted by both CartesianAxes and + # PolarAxes) need to reach every matching class. kws = { - cls: _pop_params(kwargs, sig) + cls: {k: kwargs[k] for k in sig.parameters if kwargs.get(k) is not None} for cls, sig in paxes.Axes._format_signatures.items() } + for k in {k for cls_kw in kws.values() for k in cls_kw}: + kwargs.pop(k, None) classes = set() # track used dictionaries def _axis_has_share_label_text(ax, axis): @@ -3312,13 +3722,27 @@ def _axis_has_label_text(ax, axis): if kw.get("ylabel") is not None and self._has_share_label_groups("y"): if _axis_has_share_label_text(ax, "y") or _axis_has_label_text(ax, "y"): kw.pop("ylabel", None) - ax.format(rc_kw=rc_kw, rc_mode=rc_mode, skip_figure=True, **kw, **kwargs) + explicit_kw = {} + if isinstance(ax, paxes.CartesianAxes): + explicit_kw["_explicit_format_keys"] = explicit_format_keys + ax.format( + rc_kw=rc_kw, + rc_mode=rc_mode, + skip_figure=True, + **explicit_kw, + **kw, + **kwargs, + **generic_axis_kwargs, + ) ax.number = store_old_number - # Warn unused keyword argument(s) + # Warn unused keyword argument(s). Shared params (those in multiple + # signatures) are considered "used" if any matched class consumed them. + used_keys = {k for cls in classes for k in kws[cls]} kw = { key: value for name in kws.keys() - classes for key, value in kws[name].items() + if key not in used_keys } if kw: warnings._warn_ultraplot( diff --git a/ultraplot/gridspec.py b/ultraplot/gridspec.py index 2f4761f23..29ee8c5bd 100644 --- a/ultraplot/gridspec.py +++ b/ultraplot/gridspec.py @@ -17,6 +17,7 @@ import numpy as np from . import axes as paxes +from .axes._formatting import pop_axis_format_kwargs from .config import rc from .internals import ( _not_none, @@ -2079,6 +2080,20 @@ def format(self, **kwargs): ultraplot.figure.Figure.format ultraplot.config.Configurator.context """ + + def _supports_implicit_label_share(target): + compatible_sides = { + "x": {"top", "bottom"}, + "y": {"left", "right"}, + } + for ax in axes: + side = getattr(ax, "_panel_side", None) + if side is None: + continue + if side not in compatible_sides[target]: + return False + return True + # Implicit label sharing for subset format calls share_xlabels = kwargs.get("share_xlabels", None) share_ylabels = kwargs.get("share_ylabels", None) @@ -2100,8 +2115,25 @@ def format(self, **kwargs): else: shared_title_loc = None shared_title_pad = None + signature_axis_kwargs, generic_axis_kwargs = pop_axis_format_kwargs( + kwargs, *paxes.Axes._format_signatures.values() + ) rc_kw, rc_mode = _pop_rc(kwargs) + kwargs.update(signature_axis_kwargs) + kwargs.update(generic_axis_kwargs) with rc.context(rc_kw, mode=rc_mode): + implicit_share_xlabels = ( + is_subset + and share_xlabels is None + and xlabel is not None + and _supports_implicit_label_share("x") + ) + implicit_share_ylabels = ( + is_subset + and share_ylabels is None + and ylabel is not None + and _supports_implicit_label_share("y") + ) if len(self) > 1: if share_xlabels is False: self.figure._clear_share_label_groups(self, target="x") @@ -2111,9 +2143,9 @@ def format(self, **kwargs): self.figure._clear_share_label_groups(self, target="x") if not is_subset and share_ylabels is None and ylabel is not None: self.figure._clear_share_label_groups(self, target="y") - if is_subset and share_xlabels is None and xlabel is not None: + if implicit_share_xlabels: self.figure._register_share_label_group(self, target="x") - if is_subset and share_ylabels is None and ylabel is not None: + if implicit_share_ylabels: self.figure._register_share_label_group(self, target="y") self.figure.format(axs=self, **kwargs) if shared_subset_title: @@ -2126,9 +2158,9 @@ def format(self, **kwargs): ) # Refresh groups after labels are set if len(self) > 1: - if is_subset and share_xlabels is None and xlabel is not None: + if implicit_share_xlabels: self.figure._register_share_label_group(self, target="x") - if is_subset and share_ylabels is None and ylabel is not None: + if implicit_share_ylabels: self.figure._register_share_label_group(self, target="y") def share_labels(self, *, axis="x"): diff --git a/ultraplot/internals/__init__.py b/ultraplot/internals/__init__.py index 487f73c60..16bdd4501 100644 --- a/ultraplot/internals/__init__.py +++ b/ultraplot/internals/__init__.py @@ -340,6 +340,7 @@ def _pop_rc(src, *, ignore_conflicts=True): "tight", "span", ) + kw = src.pop("rc_kw", None) or {} if "mode" in src: src["rc_mode"] = src.pop("mode") diff --git a/ultraplot/internals/docstring.py b/ultraplot/internals/docstring.py index 39b2938f6..13bdff625 100644 --- a/ultraplot/internals/docstring.py +++ b/ultraplot/internals/docstring.py @@ -121,6 +121,7 @@ class _SnippetManager(dict): "plot": "ultraplot.axes.plot", "figure": "ultraplot.figure", "gridspec": "ultraplot.gridspec", + "legend": "ultraplot.legend", "ticker": "ultraplot.ticker", "proj": "ultraplot.proj", "colors": "ultraplot.colors", diff --git a/ultraplot/internals/inputs.py b/ultraplot/internals/inputs.py index 0f8ac4e46..14438c686 100644 --- a/ultraplot/internals/inputs.py +++ b/ultraplot/internals/inputs.py @@ -686,7 +686,7 @@ def _meta_coords(*args, which="x", **kwargs): if data.ndim > 1: raise ValueError("Non-1D string coordinate input is unsupported.") ticks = np.arange(len(data)) - labels = list(map(str, data)) + labels = list(map(str, _to_numpy_array(data))) kwargs.setdefault(which + "locator", Locator(ticks)) kwargs.setdefault(which + "formatter", Formatter(labels, index=True)) kwargs.setdefault(which + "minorlocator", Locator("null")) diff --git a/ultraplot/internals/labels.py b/ultraplot/internals/labels.py index 9cb49d2ec..51c0e32c5 100644 --- a/ultraplot/internals/labels.py +++ b/ultraplot/internals/labels.py @@ -11,6 +11,34 @@ from . import ic # noqa: F401 +# Pseudo-properties handled by `_update_label`. These are not valid +# `matplotlib.text.Text` constructor kwargs, so they must be filtered before +# instantiating Text and re-applied via `_update_label` afterwards. +LABEL_PSEUDO_PROPS = frozenset( + { + "border", + "bordercolor", + "borderinvert", + "borderwidth", + "borderstyle", + "bbox", + "bboxcolor", + "bboxstyle", + "bboxalpha", + "bboxpad", + } +) + + +def _split_label_props(kwargs): + """ + Split a kwargs dict into (label_props, text_kwargs) so the latter can be + passed to `mtext.Text(...)` and the former applied via `_update_label`. + """ + label_props = {k: kwargs[k] for k in kwargs if k in LABEL_PSEUDO_PROPS} + text_kwargs = {k: v for k, v in kwargs.items() if k not in LABEL_PSEUDO_PROPS} + return label_props, text_kwargs + def merge_font_properties( dest_fp: FontProperties, src_fp: FontProperties diff --git a/ultraplot/legend.py b/ultraplot/legend.py index c8c5c579d..8f1096cb0 100644 --- a/ultraplot/legend.py +++ b/ultraplot/legend.py @@ -9,12 +9,14 @@ import numpy as np from matplotlib import cm as mcm from matplotlib import colors as mcolors +from matplotlib.colors import is_color_like as _mpl_is_color_like from matplotlib import lines as mlines from matplotlib import legend as mlegend from matplotlib import legend_handler as mhandler +from matplotlib.markers import MarkerStyle from .config import rc -from .internals import _not_none, _pop_props, guides, rcsetup +from .internals import _not_none, _pop_props, docstring, guides, rcsetup from .utils import _fontsize_to_pt, units try: @@ -91,8 +93,33 @@ def __init__( markeredgecolor=None, markeredgewidth=None, alpha=None, + marker_capstyle=None, + marker_joinstyle=None, + marker_transform=None, **kwargs, ): + # ``Line2D`` exposes capstyle/joinstyle/transform/fillstyle only via the + # marker object, not as kwargs. Wrap the marker spec in a ``MarkerStyle`` + # so these properties survive into the rendered legend entry. Pop + # ``fillstyle`` from kwargs first so it doesn't reach ``Line2D.__init__`` + # twice when ``MarkerStyle`` consumes it. + fillstyle = kwargs.pop("fillstyle", None) + if ( + marker_capstyle is not None + or marker_joinstyle is not None + or marker_transform is not None + or fillstyle is not None + ) and not isinstance(marker, MarkerStyle): + marker_kw = {} + if marker_capstyle is not None: + marker_kw["capstyle"] = marker_capstyle + if marker_joinstyle is not None: + marker_kw["joinstyle"] = marker_joinstyle + if marker_transform is not None: + marker_kw["transform"] = marker_transform + if fillstyle is not None: + marker_kw["fillstyle"] = fillstyle + marker = MarkerStyle(marker, **marker_kw) marker = "o" if marker is None and not line else marker linestyle = "none" if not line else linestyle if markerfacecolor is None and color is not None: @@ -138,11 +165,14 @@ def marker(cls, label=None, marker="o", **kwargs): "pentagon": mpath.Path.unit_regular_polygon(5), "hexagon": mpath.Path.unit_regular_polygon(6), "star": mpath.Path.unit_regular_star(5), + "rectangle": mpath.Path( + [[0, 0], [2, 0], [2, 1], [0, 1], [0, 0]], closed=True, readonly=True + ), } _GEOMETRY_SHAPE_ALIASES = { "box": "square", - "rect": "square", - "rectangle": "square", + "rect": "rectangle", + "rec": "rectangle", "tri": "triangle", "pent": "pentagon", "hex": "hexagon", @@ -772,11 +802,7 @@ def _geo_legend_entries( country_reso: str = "110m", country_territories: bool = False, country_proj: Any = None, - facecolor: Any = "none", - edgecolor: Any = "0.25", - linewidth: float = 1.0, - alpha: Optional[float] = None, - fill: Optional[bool] = None, + patch_kw: dict = None, ): """ Build geometry semantic legend handles and labels. @@ -834,29 +860,125 @@ def _geo_legend_entries( "Labels and geometry entries must have the same length. " f"Got {len(label_list)} labels and {len(geometry_list)} entries." ) + if patch_kw is None: + patch_kw = {} + facecolor = patch_kw.get("facecolor", "none") + edgecolor = patch_kw.get("edgecolor", "0.25") + linewidth = patch_kw.get("linewidth", 1.0) + alpha = patch_kw.get("alpha", None) + fill = patch_kw.get("fill", None) + handles = [] - for geometry, label, options in zip(geometry_list, label_list, entry_options): + for idx, (geometry, label, options) in enumerate( + zip(geometry_list, label_list, entry_options) + ): + # Resolve per-entry values (scalar → all; list → cycled; dict → matched by label) + fc = _style_lookup(facecolor, label, idx, default="none", prop="facecolor") + ec = _style_lookup(edgecolor, label, idx, default="0.25", prop="edgecolor") + lw = _style_lookup(linewidth, label, idx, default=1.0, prop=None) + a = _style_lookup(alpha, label, idx, default=None, prop=None) + fl = _style_lookup(fill, label, idx, default=None, prop=None) + geo_kwargs = { "country_reso": country_reso, "country_territories": country_territories, "country_proj": country_proj, - "facecolor": facecolor, - "edgecolor": edgecolor, - "linewidth": linewidth, - "alpha": alpha, - "fill": fill, + "facecolor": fc, + "edgecolor": ec, + "linewidth": lw, + "alpha": a, + "fill": fl, } + # Apply any remaining patch properties (hatch, linestyle, capstyle, etc.) + for k, v in patch_kw.items(): + if k not in geo_kwargs: + geo_kwargs[k] = _style_lookup(v, label, idx, default=None, prop=k) geo_kwargs.update(options or {}) handles.append(GeometryEntry(geometry, label=label, **geo_kwargs)) + return handles, label_list -def _style_lookup(style, key, index, default=None): +# _is_color_like should only check the following args +_COLOR_KEYS = { + "color", + "facecolor", + "edgecolor", + "markerfacecolor", + "markeredgecolor", + "markerfacecoloralt", +} + + +def _is_color_like(value): + """ + Determine whether a value can be interpreted as a single color. + + A tuple or list of 3 or 4 numbers in ``[0, 1]`` is treated as one RGB(A) + color rather than a per-entry style sequence — matching matplotlib's + color parser and giving tuple/list symmetric behavior. Other lists fall + through to per-entry resolution by ``_style_lookup``. + """ + if value is None: + return False + if isinstance(value, (tuple, list)): + if len(value) in (3, 4) and all( + isinstance(v, (int, float)) and 0.0 <= v <= 1.0 for v in value + ): + return True + return False + return _mpl_is_color_like(value) + + +# Line2D / LegendEntry alias mapping. ``ec`` / ``fc`` are deliberately +# omitted: they already resolve to ``markeredgecolor`` / ``markerfacecolor`` +# via ultraplot's internal ``_pop_props(kwargs, "line")``. +_LINE_ALIAS_MAP = { + "c": "color", + "m": "marker", + "ms": "markersize", + "ls": "linestyle", + "lw": "linewidth", + "mec": "markeredgecolor", + "mew": "markeredgewidth", + "mfc": "markerfacecolor", + "mfcalt": "markerfacecoloralt", + "aa": "antialiased", + "fs": "fillstyle", +} + +# Patch alias mapping +_PATCH_ALIAS_MAP = { + "c": "color", + "fc": "facecolor", + "ec": "edgecolor", + "ls": "linestyle", + "lw": "linewidth", + "aa": "antialiased", +} + + +def _style_lookup(style, key, index, default=None, *, prop=None): """ - Resolve style values from scalar, mapping, or sequence inputs. + Resolve a style value from scalar, mapping, or sequence inputs. + + Parameters + ---------- + style : the style value (scalar, list, dict) + key : dict key when `style` is a mapping (typically a label) + index : list index when `style` is a sequence + default : fallback value + prop : optional attribute name; if it belongs to _COLOR_KEYS, + the function treats color-like sequences as single colors. """ if style is None: return default + + # Only perform color detection for known color properties + check_color = prop is not None and prop in _COLOR_KEYS + + if check_color and _is_color_like(style): + return style if isinstance(style, dict): return style.get(key, default) if isinstance(style, str): @@ -901,24 +1023,90 @@ def _default_cycle_colors(): "linestyles": "linestyle", "linewidths": "markeredgewidth", "sizes": "markersize", + "size": "markersize", } +def _pop_aliases(kwargs: dict[str, Any], alias_map: dict[str, str]) -> dict[str, Any]: + """Pop short aliases (``c``, ``ls``, …) from ``kwargs`` mapped to full names.""" + resolved = {} + for alias in list(kwargs): + if alias in alias_map: + resolved[alias_map[alias]] = kwargs.pop(alias) + return resolved + + +def _pop_plurals(kwargs: dict[str, Any], plural_map: dict[str, str]) -> dict[str, Any]: + """Pop collection-style plurals (``colors``, ``sizes``, …) from ``kwargs``.""" + explicit = {} + for key in plural_map: + if key in kwargs: + explicit[key] = kwargs.pop(key) + return explicit + + +def _pop_line2d_setters(kwargs: dict[str, Any]) -> dict[str, Any]: + """ + Pop remaining kwargs that correspond to ``Line2D`` setters. + + Catches properties that ``_pop_props(..., "line")`` does not know about + (e.g. ``fillstyle``, ``solid_capstyle``) so they survive into the + ``LegendEntry`` constructor instead of leaking through to ``Axes.legend``, + where matplotlib rejects them. + + ``label``/``labels`` look like Line2D setters but are intentionally not + consumed here — the semantic-legend validator (covered by + ``test_semantic_legend_rejects_label{,s}_kwarg``) needs them to surface + as ``TypeError`` from the public ``legend()`` call. + """ + extracted = {} + for key in list(kwargs): + if key in ("labels", "label") or key.startswith("_"): + continue + if hasattr(mlines.Line2D, "set_" + key): + extracted[key] = kwargs.pop(key) + return extracted + + def _pop_entry_props(kwargs: dict[str, Any]) -> dict[str, Any]: """ - Pop style properties with line/scatter aliases for LegendEntry objects. + Extract ``LegendEntry`` style properties from ``kwargs``. + + Resolution order (highest → lowest priority): + + 1. Full-name properties recognised by ``_pop_props(kwargs, "line")``. + 2. Collection-style plurals (``colors`` → ``color``, ``sizes`` → ``markersize``, …). + 3. Short aliases (``c`` → ``color``, ``ls`` → ``linestyle``, …). + 4. Any other valid ``Line2D`` setter still in ``kwargs``. + + Advanced ``MarkerStyle`` properties (``marker_capstyle``/``_joinstyle``/ + ``_transform``) are pulled out first so ``_pop_props`` does not consume + them, and merged back at the end with full priority. """ - explicit_collection = {} - for key in _ENTRY_STYLE_FROM_COLLECTION: + advanced_marker = {} + for key in ("marker_capstyle", "marker_joinstyle", "marker_transform"): if key in kwargs: - explicit_collection[key] = kwargs.pop(key) + advanced_marker[key] = kwargs.pop(key) + + resolved_aliases = _pop_aliases(kwargs, _LINE_ALIAS_MAP) + explicit_collection = _pop_plurals(kwargs, _ENTRY_STYLE_FROM_COLLECTION) + props = _pop_props(kwargs, "line") collection_props = _pop_props(kwargs, "collection") collection_props.update(explicit_collection) + for source, target in _ENTRY_STYLE_FROM_COLLECTION.items(): value = collection_props.get(source, None) if value is not None and target not in props: props[target] = value + + for full_key, value in resolved_aliases.items(): + props.setdefault(full_key, value) + + for full_key, value in _pop_line2d_setters(kwargs).items(): + props.setdefault(full_key, value) + + props.update(advanced_marker) return props @@ -933,19 +1121,24 @@ def _pop_entry_props(kwargs: dict[str, Any]) -> dict[str, Any]: def _pop_num_props(kwargs: dict[str, Any]) -> dict[str, Any]: """ - Pop patch/collection style aliases for numeric semantic legend entries. + Extract patch-style properties (and collection-plural / short aliases) for + numeric semantic legend entries (``numlegend`` / ``geolegend``). """ - explicit_collection = {} - for key in _NUM_STYLE_FROM_COLLECTION: - if key in kwargs: - explicit_collection[key] = kwargs.pop(key) + resolved_aliases = _pop_aliases(kwargs, _PATCH_ALIAS_MAP) + explicit_collection = _pop_plurals(kwargs, _NUM_STYLE_FROM_COLLECTION) + props = _pop_props(kwargs, "patch") collection_props = _pop_props(kwargs, "collection") collection_props.update(explicit_collection) + for source, target in _NUM_STYLE_FROM_COLLECTION.items(): value = collection_props.get(source, None) if value is not None and target not in props: props[target] = value + + for full_key, value in resolved_aliases.items(): + props.setdefault(full_key, value) + return props @@ -959,17 +1152,17 @@ def _resolve_style_values( """ output = {} for key, value in styles.items(): - resolved = _style_lookup(value, label, index, default=None) + resolved = _style_lookup(value, label, index, default=None, prop=key) if resolved is not None: output[key] = resolved return output def _cat_legend_entries( - categories: Iterable[Any], + categories, *, - colors=None, - markers="o", + color=None, + marker="o", line=False, linestyle="-", linewidth=2.0, @@ -999,18 +1192,27 @@ def _cat_legend_entries( handles = [] for idx, label in enumerate(labels): styles = _resolve_style_values(base_styles, label, idx) - color = _style_lookup(colors, label, idx, default=palette[idx % len(palette)]) - marker = _style_lookup(markers, label, idx, default="o") line_value = bool(styles.pop("line", False)) - if line_value and marker in (None, ""): - marker = None - styles.pop("marker", None) + linestyle_value = styles.pop("linestyle", "-") + marker_value = styles.pop("marker", None) + + # If line=False but user provides a non-default linestyle, automatically enable line=True + if not line_value and linestyle_value not in (None, "-", "none", "None"): + line_value = True + + color_val = _style_lookup( + color, label, idx, default=palette[idx % len(palette)], prop="color" + ) + marker_val = _style_lookup(marker, label, idx, default="o", prop="marker") + if line_value and marker_val in (None, ""): + marker_val = None handles.append( LegendEntry( label=str(label), - color=color, + color=color_val, line=line_value, - marker=marker, + marker=marker_val, + linestyle=linestyle_value, **styles, ) ) @@ -1169,8 +1371,12 @@ def _size_legend_entries( handles = [] for idx, (value, label, size) in enumerate(zip(values, label_list, ms)): styles = _resolve_style_values(base_styles, float(value), idx) - color_value = _style_lookup(color, float(value), idx, default="0.35") - marker_value = _style_lookup(marker, float(value), idx, default="o") + color_value = _style_lookup( + color, float(value), idx, default="0.35", prop="color" + ) + marker_value = _style_lookup( + marker, float(value), idx, default="o", prop="marker" + ) line_value = bool(styles.pop("line", False)) if line_value and marker_value in ("", None): marker_value = None @@ -1394,6 +1600,81 @@ def _normalize_em_kwargs(kwargs: dict[str, Any], *, fontsize: float) -> dict[str return kwargs +_semantic_style_arg_docstring = """\ +A style value resolved per legend entry. Accepts a **scalar** (applied + to every entry), a **list / tuple / ndarray** (one value per entry, + cycled to match the number of entries), or a **dict** (mapping from + label — or from numeric value for ``sizelegend`` / ``numlegend`` — to + style; missing keys fall back to the default). A 3- or 4-element + sequence of floats in ``[0, 1]`` is treated as a single RGB(A) color + rather than as per-entry values, so ``color=[0.5, 0.5, 0.5]`` and + ``color=(0.5, 0.5, 0.5)`` behave the same.""" + +_semantic_style_kwargs_docstring = """\ +Common style keywords accepted via ``handle_kw`` or ``**kwargs``: + +``color`` / ``c`` + Marker (and line, when ``line=True``) color. ``c`` is the short alias. +``marker`` / ``m`` + Marker spec. Set to ``None`` or ``""`` to suppress the marker. +``markersize`` / ``ms``, ``markeredgewidth`` / ``mew`` + Marker dimensions. +``markerfacecolor`` / ``mfc``, ``markeredgecolor`` / ``mec``, ``markerfacecoloralt`` / ``mfcalt`` + Marker fills and edges. +``linestyle`` / ``ls``, ``linewidth`` / ``lw`` + Connector line styling. Setting a non-default ``linestyle`` implicitly + enables ``line=True``. +``alpha``, ``antialiased`` / ``aa``, ``fillstyle`` / ``fs`` + Generic appearance. +``marker_capstyle``, ``marker_joinstyle``, ``marker_transform`` + Advanced ``MarkerStyle`` properties; wrapped into the rendered marker. + +Plural forms (``colors``, ``markers``, ``sizes``, ``edgecolors``, +``facecolors``, ``linestyles``, ``linewidths``) are accepted as +synonyms for the singular per-entry form for backward compatibility. +Each value accepts the scalar / sequence / mapping forms described in +``%(legend.semantic_style_arg)s``.""" + +_semantic_num_style_kwargs_docstring = """\ +Patch-style keywords accepted via ``handle_kw`` or ``**kwargs``: + +``facecolor`` / ``fc``, ``edgecolor`` / ``ec``, ``color`` / ``c`` + Patch fills and edges. +``linewidth`` / ``lw``, ``linestyle`` / ``ls`` + Patch outline styling. +``alpha``, ``antialiased`` / ``aa``, ``hatch``, ``fill``, +``joinstyle``, ``capstyle`` + Generic patch appearance. + +Plural collection forms (``colors``, ``facecolors``, ``edgecolors``, +``linestyles``, ``linewidths``) map to the singular per-entry form. +Each value accepts the scalar / sequence / mapping forms described in +``%(legend.semantic_style_arg)s``.""" + +_semantic_handle_kw_docstring = """\ +handle_kw : dict, optional + Style overrides applied to each generated handle. Same vocabulary as + ``**kwargs``; useful when style kwargs would otherwise collide with + matplotlib's :class:`~matplotlib.legend.Legend` keywords (``loc``, + ``title``, …). +add : bool, default: True + When ``True`` (default), draw the legend on the axes and return the + legend artist. When ``False``, return ``(handles, labels)`` without + drawing — useful for composing into a parent legend. +**kwargs + Style keywords applied per entry (see above), plus any + :class:`~matplotlib.legend.Legend` keyword.""" + +docstring._snippet_manager["legend.semantic_style_arg"] = _semantic_style_arg_docstring +docstring._snippet_manager["legend.semantic_style_kwargs"] = ( + _semantic_style_kwargs_docstring +) +docstring._snippet_manager["legend.semantic_num_style_kwargs"] = ( + _semantic_num_style_kwargs_docstring +) +docstring._snippet_manager["legend.semantic_handle_kw"] = _semantic_handle_kw_docstring + + class UltraLegend: """ Centralized legend builder for axes. @@ -1425,55 +1706,36 @@ def entrylegend( line: Optional[bool] = None, marker=None, color=None, - linestyle=None, - linewidth: Optional[float] = None, - markersize: Optional[float] = None, - alpha=None, - markeredgecolor=None, - markeredgewidth=None, - markerfacecolor=None, handle_kw: Optional[dict[str, Any]] = None, add: bool = True, - **legend_kwargs: Any, + **kwargs: Any, ): """ Build generic semantic legend entries and optionally draw a legend. + Public docs live on :meth:`Axes.entrylegend`. """ - styles = dict(handle_kw or {}) - styles.update(_pop_entry_props(styles)) + styles = {} + if handle_kw: + styles.update(_pop_entry_props(handle_kw)) + styles.update(_pop_entry_props(kwargs)) + line = _not_none(line, styles.pop("line", None), rc["legend.cat.line"]) marker = _not_none(marker, styles.pop("marker", None), rc["legend.cat.marker"]) color = _not_none(color, styles.pop("color", None)) - linestyle = _not_none( - linestyle, - styles.pop("linestyle", None), - rc["legend.cat.linestyle"], - ) - linewidth = _not_none( - linewidth, - styles.pop("linewidth", None), - rc["legend.cat.linewidth"], - ) + linestyle = _not_none(styles.pop("linestyle", None), rc["legend.cat.linestyle"]) + linewidth = _not_none(styles.pop("linewidth", None), rc["legend.cat.linewidth"]) markersize = _not_none( - markersize, - styles.pop("markersize", None), - rc["legend.cat.markersize"], + styles.pop("markersize", None), rc["legend.cat.markersize"] ) - alpha = _not_none(alpha, styles.pop("alpha", None), rc["legend.cat.alpha"]) + alpha = _not_none(styles.pop("alpha", None), rc["legend.cat.alpha"]) markeredgecolor = _not_none( - markeredgecolor, - styles.pop("markeredgecolor", None), - rc["legend.cat.markeredgecolor"], + styles.pop("markeredgecolor", None), rc["legend.cat.markeredgecolor"] ) markeredgewidth = _not_none( - markeredgewidth, - styles.pop("markeredgewidth", None), - rc["legend.cat.markeredgewidth"], - ) - markerfacecolor = _not_none( - markerfacecolor, - styles.pop("markerfacecolor", None), + styles.pop("markeredgewidth", None), rc["legend.cat.markeredgewidth"] ) + markerfacecolor = _not_none(styles.pop("markerfacecolor", None), None) + handles, labels = _entry_legend_entries( entries, line=line, @@ -1490,71 +1752,52 @@ def entrylegend( ) if not add: return handles, labels - self._validate_semantic_kwargs("entrylegend", legend_kwargs) - return self.axes.legend(handles, labels, **legend_kwargs) + self._validate_semantic_kwargs("entrylegend", kwargs) + return self.axes.legend(handles, labels, **kwargs) def catlegend( self, categories: Iterable[Any], *, - colors=None, - markers=None, + color=None, + marker=None, line: Optional[bool] = None, - linestyle=None, - linewidth: Optional[float] = None, - markersize: Optional[float] = None, - alpha=None, - markeredgecolor=None, - markeredgewidth=None, - markerfacecolor=None, handle_kw: Optional[dict[str, Any]] = None, add: bool = True, - **legend_kwargs: Any, + **kwargs: Any, ): """ Build categorical legend entries and optionally draw a legend. + Public docs live on :meth:`Axes.catlegend`. """ - styles = dict(handle_kw or {}) - styles.update(_pop_entry_props(styles)) + styles = {} + if handle_kw: + styles.update(_pop_entry_props(handle_kw)) + styles.update(_pop_entry_props(kwargs)) + line = _not_none(line, styles.pop("line", None), rc["legend.cat.line"]) - colors = _not_none(colors, styles.pop("color", None)) - markers = _not_none( - markers, styles.pop("marker", None), rc["legend.cat.marker"] - ) - linestyle = _not_none( - linestyle, - styles.pop("linestyle", None), - rc["legend.cat.linestyle"], - ) - linewidth = _not_none( - linewidth, - styles.pop("linewidth", None), - rc["legend.cat.linewidth"], - ) + color = _not_none(color, styles.pop("color", None)) + marker = _not_none(marker, styles.pop("marker", None), rc["legend.cat.marker"]) + linestyle = _not_none(styles.pop("linestyle", None), rc["legend.cat.linestyle"]) + linewidth = _not_none(styles.pop("linewidth", None), rc["legend.cat.linewidth"]) markersize = _not_none( - markersize, - styles.pop("markersize", None), - rc["legend.cat.markersize"], + styles.pop("markersize", None), rc["legend.cat.markersize"] ) - alpha = _not_none(alpha, styles.pop("alpha", None), rc["legend.cat.alpha"]) + alpha = _not_none(styles.pop("alpha", None), rc["legend.cat.alpha"]) markeredgecolor = _not_none( - markeredgecolor, - styles.pop("markeredgecolor", None), - rc["legend.cat.markeredgecolor"], + styles.pop("markeredgecolor", None), rc["legend.cat.markeredgecolor"] ) markeredgewidth = _not_none( - markeredgewidth, - styles.pop("markeredgewidth", None), - rc["legend.cat.markeredgewidth"], - ) - markerfacecolor = _not_none( - markerfacecolor, - styles.pop("markerfacecolor", None), + styles.pop("markeredgewidth", None), rc["legend.cat.markeredgewidth"] ) + markerfacecolor = _not_none(styles.pop("markerfacecolor", None), None) + + # Remaining styles are passed as additional entry properties + # (e.g., 'markerfacecoloralt') to _cat_legend_entries handles, labels = _cat_legend_entries( categories, - colors=colors, - markers=markers, + color=color, + marker=marker, line=line, linestyle=linestyle, linewidth=linewidth, @@ -1567,10 +1810,8 @@ def catlegend( ) if not add: return handles, labels - self._validate_semantic_kwargs("catlegend", legend_kwargs) - # Route through Axes.legend so location shorthands (e.g. 'r', 'b') - # and queued guide keyword handling behave exactly like the public API. - return self.axes.legend(handles, labels, **legend_kwargs) + self._validate_semantic_kwargs("catlegend", kwargs) + return self.axes.legend(handles, labels, **kwargs) def sizelegend( self, @@ -1583,40 +1824,32 @@ def sizelegend( scale: Optional[float] = None, minsize: Optional[float] = None, fmt=None, - alpha=None, - markeredgecolor=None, - markeredgewidth=None, - markerfacecolor=None, handle_kw: Optional[dict[str, Any]] = None, add: bool = True, - **legend_kwargs: Any, + **kwargs: Any, ): """ Build size legend entries and optionally draw a legend. + Public docs live on :meth:`Axes.sizelegend`. """ - styles = dict(handle_kw or {}) - styles.update(_pop_entry_props(styles)) + styles = {} + if handle_kw: + styles.update(_pop_entry_props(handle_kw)) + styles.update(_pop_entry_props(kwargs)) color = _not_none(color, styles.pop("color", None), rc["legend.size.color"]) marker = _not_none(marker, styles.pop("marker", None), rc["legend.size.marker"]) area = _not_none(area, rc["legend.size.area"]) scale = _not_none(scale, rc["legend.size.scale"]) minsize = _not_none(minsize, rc["legend.size.minsize"]) fmt = _not_none(fmt, rc["legend.size.format"]) - alpha = _not_none(alpha, styles.pop("alpha", None), rc["legend.size.alpha"]) + alpha = _not_none(styles.pop("alpha", None), rc["legend.size.alpha"]) markeredgecolor = _not_none( - markeredgecolor, - styles.pop("markeredgecolor", None), - rc["legend.size.markeredgecolor"], + styles.pop("markeredgecolor", None), rc["legend.size.markeredgecolor"] ) markeredgewidth = _not_none( - markeredgewidth, - styles.pop("markeredgewidth", None), - rc["legend.size.markeredgewidth"], - ) - markerfacecolor = _not_none( - markerfacecolor, - styles.pop("markerfacecolor", None), + styles.pop("markeredgewidth", None), rc["legend.size.markeredgewidth"] ) + markerfacecolor = _not_none(styles.pop("markerfacecolor", None), None) handles, labels = _size_legend_entries( levels, labels=labels, @@ -1634,8 +1867,8 @@ def sizelegend( ) if not add: return handles, labels - self._validate_semantic_kwargs("sizelegend", legend_kwargs) - return self.axes.legend(handles, labels, **legend_kwargs) + self._validate_semantic_kwargs("sizelegend", kwargs) + return self.axes.legend(handles, labels, **kwargs) def numlegend( self, @@ -1654,30 +1887,32 @@ def numlegend( alpha=None, handle_kw: Optional[dict[str, Any]] = None, add: bool = True, - **legend_kwargs: Any, + **kwargs: Any, ): """ Build numeric-color legend entries and optionally draw a legend. + Public docs live on :meth:`Axes.numlegend`. """ - styles = dict(handle_kw or {}) - styles.update(_pop_num_props(styles)) + styles = {} + if handle_kw: + styles.update(_pop_num_props(handle_kw)) + styles.update(_pop_num_props(kwargs)) + color = styles.pop("color", None) n = _not_none(n, rc["legend.num.n"]) cmap = _not_none(cmap, rc["legend.num.cmap"]) facecolor = _not_none(facecolor, styles.pop("facecolor", None), color) edgecolor = _not_none( - edgecolor, - styles.pop("edgecolor", None), - rc["legend.num.edgecolor"], + edgecolor, styles.pop("edgecolor", None), rc["legend.num.edgecolor"] ) linewidth = _not_none( - linewidth, - styles.pop("linewidth", None), - rc["legend.num.linewidth"], + linewidth, styles.pop("linewidth", None), rc["legend.num.linewidth"] ) linestyle = _not_none(linestyle, styles.pop("linestyle", None)) alpha = _not_none(alpha, styles.pop("alpha", None), rc["legend.num.alpha"]) fmt = _not_none(fmt, rc["legend.num.format"]) + + # Remaining styles (e.g. hatch, joinstyle, capstyle, fill) pass through. handles, labels = _num_legend_entries( levels=levels, vmin=vmin, @@ -1693,10 +1928,11 @@ def numlegend( alpha=alpha, **styles, ) + if not add: return handles, labels - self._validate_semantic_kwargs("numlegend", legend_kwargs) - return self.axes.legend(handles, labels, **legend_kwargs) + self._validate_semantic_kwargs("numlegend", kwargs) + return self.axes.legend(handles, labels, **kwargs) def geolegend( self, @@ -1712,55 +1948,69 @@ def geolegend( linewidth: Optional[float] = None, alpha: Optional[float] = None, fill: Optional[bool] = None, + handle_kw: Optional[dict[str, Any]] = None, add: bool = True, - **legend_kwargs: Any, + **kwargs: Any, ): """ Build geometry legend entries and optionally draw a legend. - - Notes - ----- - Geometry legend entries use normalized patch proxies inside the legend - handle box rather than reusing the original map artist directly. This - preserves the general geometry shape and copied patch styling, but very - small or high-aspect-ratio handles can still make hatches difficult to - read at legend scale. + Public docs live on :meth:`Axes.geolegend`. """ - facecolor = _not_none(facecolor, rc["legend.geo.facecolor"]) - edgecolor = _not_none(edgecolor, rc["legend.geo.edgecolor"]) - linewidth = _not_none(linewidth, rc["legend.geo.linewidth"]) - alpha = _not_none(alpha, rc["legend.geo.alpha"]) - fill = _not_none(fill, rc["legend.geo.fill"]) + styles = {} + if handle_kw: + styles.update(_pop_num_props(handle_kw)) + styles.update(_pop_num_props(kwargs)) + + facecolor = _not_none( + facecolor, styles.pop("facecolor", None), rc["legend.geo.facecolor"] + ) + edgecolor = _not_none( + edgecolor, styles.pop("edgecolor", None), rc["legend.geo.edgecolor"] + ) + linewidth = _not_none( + linewidth, styles.pop("linewidth", None), rc["legend.geo.linewidth"] + ) + alpha = _not_none(alpha, styles.pop("alpha", None), rc["legend.geo.alpha"]) + fill = _not_none(fill, styles.pop("fill", None), rc["legend.geo.fill"]) + + # Carry remaining styles (hatch, linestyle, joinstyle, …) through + # to per-entry resolution in ``_geo_legend_entries``. + patch_kw = { + "facecolor": facecolor, + "edgecolor": edgecolor, + "linewidth": linewidth, + "alpha": alpha, + "fill": fill, + } + patch_kw.update(styles) + country_reso = _not_none(country_reso, rc["legend.geo.country_reso"]) country_territories = _not_none( country_territories, rc["legend.geo.country_territories"] ) country_proj = _not_none(country_proj, rc["legend.geo.country_proj"]) handlesize = _not_none(handlesize, rc["legend.geo.handlesize"]) + handles, labels = _geo_legend_entries( entries, labels=labels, country_reso=country_reso, country_territories=country_territories, country_proj=country_proj, - facecolor=facecolor, - edgecolor=edgecolor, - linewidth=linewidth, - alpha=alpha, - fill=fill, + patch_kw=patch_kw, ) if not add: return handles, labels - self._validate_semantic_kwargs("geolegend", legend_kwargs) + self._validate_semantic_kwargs("geolegend", kwargs) if handlesize is not None: handlesize = float(handlesize) if handlesize <= 0: raise ValueError("geolegend handlesize must be positive.") - if "handlelength" not in legend_kwargs: - legend_kwargs["handlelength"] = rc["legend.handlelength"] * handlesize - if "handleheight" not in legend_kwargs: - legend_kwargs["handleheight"] = rc["legend.handleheight"] * handlesize - return self.axes.legend(handles, labels, **legend_kwargs) + if "handlelength" not in kwargs: + kwargs["handlelength"] = rc["legend.handlelength"] * handlesize + if "handleheight" not in kwargs: + kwargs["handleheight"] = rc["legend.handleheight"] * handlesize + return self.axes.legend(handles, labels, **kwargs) @staticmethod def _align_map() -> dict[Optional[str], dict[str, str]]: diff --git a/ultraplot/tests/test_1dplots.py b/ultraplot/tests/test_1dplots.py index d63256a52..d57309161 100644 --- a/ultraplot/tests/test_1dplots.py +++ b/ultraplot/tests/test_1dplots.py @@ -138,6 +138,24 @@ def test_bar_width(rng): return fig +def test_bar_scalar_bottom(): + """ + Regression for #731: pandas dispatches Series.plot(kind="barh") via + matplotlib with a scalar ``bottom`` (or ``left``), which previously hit + ``AttributeError: 'int' object has no attribute 'size'`` inside + ``_inbounds_xylim``. + """ + # Direct scalar `bottom` / `left` + fig, ax = uplt.subplots() + ax.bar([1, 2, 3], [4, 5, 6], bottom=0) + ax.barh([1, 2, 3], [4, 5, 6], left=0) + + # The original failing reproducer from the issue + series = pd.Series({"a": 1, "b": 2, "c": 3}) + fig, ax = uplt.subplots() + series.plot(kind="barh", ax=ax[0]) + + @pytest.mark.mpl_image_compare def test_bar_vectors(): """ diff --git a/ultraplot/tests/test_animation.py b/ultraplot/tests/test_animation.py index 6e8ad2efc..00ee7c007 100644 --- a/ultraplot/tests/test_animation.py +++ b/ultraplot/tests/test_animation.py @@ -1,5 +1,6 @@ from unittest.mock import MagicMock +import matplotlib import numpy as np import pytest from matplotlib.animation import FuncAnimation @@ -58,3 +59,51 @@ def update(frame): ani = FuncAnimation(fig, update, frames=10) # The test passes if no exception is raised fig.canvas.draw() + + +def test_animation_save_only_tightens_first_frame(tmp_path): + """ + Saving an animation should not rerun tight layout on every frame after the + first saved frame, or frame geometry can shift between outputs. + """ + matplotlib.use("Agg") + state = np.random.RandomState(51423) + + fig, axs = uplt.subplots(nrows=1, ncols=2, width="14cm") + mappables = [] + for ax in axs: + m = ax.heatmap(state.rand(10, 10), cmap="dusk") + ax.colorbar(m, loc="t", tickdir="out", label="Axes Colorbars") + mappables.append(m) + + axs.format( + abc="(a)", + abcloc="ul", + xlabel="xlabel", + ylabel="ylabel", + toplabels=("Left Axes", "Right Axes"), + urtitle="1", + suptitle="Test Animation", + ) + + auto_layout_calls = [] + original_auto_layout = fig.auto_layout + + def wrapped_auto_layout(*args, **kwargs): + auto_layout_calls.append(kwargs.get("tight", None)) + return original_auto_layout(*args, **kwargs) + + fig.auto_layout = wrapped_auto_layout + + def update(frame): + for m in mappables: + m.set_array(state.rand(10, 10)) + axs.format(urtitle=f"{frame + 1}") + return mappables + + ani = FuncAnimation(fig, update, frames=3, interval=150) + ani.save(tmp_path / "test_animation.gif", writer="pillow") + + assert auto_layout_calls + assert auto_layout_calls[0] is not False + assert auto_layout_calls[1:] == [False] * (len(auto_layout_calls) - 1) diff --git a/ultraplot/tests/test_axes.py b/ultraplot/tests/test_axes.py index 2c432b515..c90474142 100644 --- a/ultraplot/tests/test_axes.py +++ b/ultraplot/tests/test_axes.py @@ -342,16 +342,18 @@ def test_title_manual_size_ignores_auto_shrink(): assert title_obj.get_fontsize() == 20 -def test_title_shrinks_when_abc_overlaps_different_loc(): +def test_title_shifts_when_abc_overlaps_different_loc(): """ - Ensure long titles shrink when overlapping abc at a different location. + Ensure centered titles keep their requested size by shifting away from abc. """ fig, axs = uplt.subplots(figsize=(3, 2)) - axs.format(abc=True, title="X" * 200, titleloc="center", abcloc="left") + axs.format(abc=True, title="X" * 100, titleloc="center", abcloc="left") title_obj = axs[0]._title_dict["center"] original_size = title_obj.get_fontsize() + original_x = title_obj.get_position()[0] fig.canvas.draw() - assert title_obj.get_fontsize() < original_size + assert title_obj.get_fontsize() == original_size + assert title_obj.get_position()[0] > original_x def test_title_shrinks_right_aligned_same_location(): @@ -418,17 +420,18 @@ def test_title_no_shrink_when_no_overlap(): assert title_obj.get_fontsize() == original_size -def test_title_shrinks_centered_left_of_abc(): +def test_title_shifts_centered_left_of_abc(): """ - Test that centered titles shrink when they are to the left of abc label. - This covers the specific case where base_x <= ax0 - pad for centered titles. + Centered titles should also shift left to avoid a right-side abc label. """ fig, axs = uplt.subplots(figsize=(3, 2)) axs.format(abc=True, title="X" * 100, titleloc="center", abcloc="right") title_obj = axs[0]._title_dict["center"] original_size = title_obj.get_fontsize() + original_x = title_obj.get_position()[0] fig.canvas.draw() - assert title_obj.get_fontsize() < original_size + assert title_obj.get_fontsize() == original_size + assert title_obj.get_position()[0] < original_x ticks = axs[0].get_xticks() assert ticks.size > 0 xy = np.column_stack([ticks, np.zeros_like(ticks)]) diff --git a/ultraplot/tests/test_axes_alt_styles.py b/ultraplot/tests/test_axes_alt_styles.py new file mode 100644 index 000000000..492872db2 --- /dev/null +++ b/ultraplot/tests/test_axes_alt_styles.py @@ -0,0 +1,256 @@ +import matplotlib.colors as mcolors +import pytest +import ultraplot as uplt + + +def _all_match_color(colors, expected): + expected = mcolors.to_rgba(expected) + return all(mcolors.to_rgba(color) == expected for color in colors) + + +def test_alt_axes_styling_dark_background(): + """ + Test that applying dark_background style does not leak tick visibility + settings and correctly preserves alternative axes tick locations. + """ + with uplt.rc.context(style="dark_background"): + fig, ax = uplt.subplots() + ax.format(ycolor="C0", ylabel="Left Axis") + + ax2 = ax.alty(color="C1") + ax2.format(ycolor="C1", ylabel="Right Axis", ylim=(0, 1)) + + # The left axis should ONLY have visible ticks on the left + left_ax_left_ticks = sum( + 1 + for t in ax.yaxis.get_ticklines() + if t.get_visible() and t.get_xdata()[0] == 0 + ) + left_ax_right_ticks = sum( + 1 + for t in ax.yaxis.get_ticklines() + if t.get_visible() and t.get_xdata()[0] == 1 + ) + + # The right axis (ax2) should ONLY have visible ticks on the right + right_ax_left_ticks = sum( + 1 + for t in ax2.yaxis.get_ticklines() + if t.get_visible() and t.get_xdata()[0] == 0 + ) + right_ax_right_ticks = sum( + 1 + for t in ax2.yaxis.get_ticklines() + if t.get_visible() and t.get_xdata()[0] == 1 + ) + + assert left_ax_left_ticks > 0, "Left axis should have left ticks" + assert left_ax_right_ticks == 0, "Left axis should NOT have right ticks" + + assert right_ax_left_ticks == 0, "Right axis should NOT have left ticks" + assert right_ax_right_ticks > 0, "Right axis should have right ticks" + + assert _all_match_color( + [ + line.get_color() + for line in ax2.yaxis.get_ticklines() + if line.get_visible() + ], + "C1", + ) + assert { + mcolors.to_rgba(ax2.spines[side].get_edgecolor()) + for side in ("left", "right") + if ax2.spines[side].get_visible() + } == {mcolors.to_rgba("C1")} + + +@pytest.mark.parametrize( + ("setup", "format_kwargs", "expected_color", "expected_linewidth"), + [ + ( + lambda ax: ax, + {"ycolor": "C0", "ylinewidth": 3, "ylabel": "Left Axis"}, + "C0", + 3, + ), + ( + lambda ax: ax.alty(color="C1", linewidth=3), + {"ylabel": "Right Axis", "ylim": (0, 1)}, + "C1", + 3, + ), + ], +) +def test_dark_background_preserves_axis_colors_on_reformat( + setup, format_kwargs, expected_color, expected_linewidth +): + with uplt.rc.context(style="dark_background"): + fig, ax = uplt.subplots() + target = setup(ax) + target.format(**format_kwargs) + target.format(ylabel="Updated Label") + + assert _all_match_color( + [label.get_color() for label in target.get_yticklabels()], expected_color + ) + assert mcolors.to_rgba(target.yaxis.label.get_color()) == mcolors.to_rgba( + expected_color + ) + assert _all_match_color( + [ + line.get_color() + for line in target.yaxis.get_ticklines() + if line.get_visible() + ], + expected_color, + ) + assert { + mcolors.to_rgba(target.spines[side].get_edgecolor()) + for side in ("left", "right") + if target.spines[side].get_visible() + } == {mcolors.to_rgba(expected_color)} + assert { + target.spines[side].get_linewidth() + for side in ("left", "right") + if target.spines[side].get_visible() + } == {expected_linewidth} + + +def test_dark_background_updates_unspecified_axis_frame_style(): + fig, ax = uplt.subplots() + + with uplt.rc.context(style="dark_background"): + ax.format(ylabel="Updated Label") + + expected = mcolors.to_rgba(uplt.rc["axes.edgecolor"]) + assert { + mcolors.to_rgba(ax.spines[side].get_edgecolor()) + for side in ("left", "right") + if ax.spines[side].get_visible() + } == {expected} + assert _all_match_color( + [ + line.get_color() + for line in ax.yaxis.get_ticklines() + if line.get_visible() + ], + expected, + ) + + +@pytest.mark.parametrize( + ("format_kwargs", "getter", "expected_color"), + [ + ( + {"ytickcolor": "red"}, + lambda ax: [ + line.get_color() + for line in ax.yaxis.get_ticklines() + if line.get_visible() + ], + "red", + ), + ( + {"yticklabelcolor": "blue"}, + lambda ax: [label.get_color() for label in ax.get_yticklabels()], + "blue", + ), + ( + {"ylabelcolor": "green"}, + lambda ax: [ax.yaxis.label.get_color()], + "green", + ), + ], +) +def test_subplots_preserve_explicit_axis_property_overrides_on_reformat( + format_kwargs, getter, expected_color +): + with uplt.rc.context(style="dark_background"): + fig, axs = uplt.subplots() + ax = axs[0] + axs.format(**format_kwargs) + axs.format(ylabel="Updated Label") + + assert _all_match_color(getter(ax), expected_color) + + +def test_subplots_preserve_generic_tickcolor_across_later_axis_color(): + with uplt.rc.context(style="dark_background"): + fig, axs = uplt.subplots() + ax = axs[0] + axs.format(tickcolor="red") + axs.format(ycolor="C1") + + assert _all_match_color( + [ + line.get_color() + for line in ax.yaxis.get_ticklines() + if line.get_visible() + ], + "red", + ) + assert { + mcolors.to_rgba(ax.spines[side].get_edgecolor()) + for side in ("left", "right") + if ax.spines[side].get_visible() + } == {mcolors.to_rgba("C1")} + + +def test_subplots_apply_generic_labelcolor(): + fig, axs = uplt.subplots() + ax = axs[0] + + axs.format(labelcolor="green") + + assert _all_match_color( + [ax.xaxis.label.get_color(), ax.yaxis.label.get_color()], "green" + ) + + +@pytest.mark.parametrize("format_kwargs", [{"ytickcolor": "red"}, {"tickcolor": "red"}]) +def test_subplots_can_clear_explicit_tickcolor_override(format_kwargs): + with uplt.rc.context(style="dark_background"): + fig, axs = uplt.subplots() + ax = axs[0] + axs.format(**format_kwargs) + clear_kwargs = {key: None for key in format_kwargs} + axs.format(**clear_kwargs) + + assert _all_match_color( + [ + line.get_color() + for line in ax.yaxis.get_ticklines() + if line.get_visible() + ], + uplt.rc["ytick.color"], + ) + + +@pytest.mark.parametrize("format_kwargs", [{"ytickcolor": "red"}, {"tickcolor": "red"}]) +def test_direct_axes_can_clear_explicit_tickcolor_override(format_kwargs): + with uplt.rc.context(style="dark_background"): + fig = uplt.figure() + ax = fig.subplot(111) + ax.format(**format_kwargs) + clear_kwargs = {key: None for key in format_kwargs} + ax.format(ylabel="Updated Label", **clear_kwargs) + + assert _all_match_color( + [ + line.get_color() + for line in ax.yaxis.get_ticklines() + if line.get_visible() + ], + uplt.rc["ytick.color"], + ) + + +def test_polar_format_updates_frame_style(): + fig = uplt.figure() + ax = fig.subplot(111, proj="polar") + + ax.format(color="C3", linewidth=3) + + assert mcolors.to_rgba(ax.spines["polar"].get_edgecolor()) == mcolors.to_rgba("C3") + assert ax.spines["polar"].get_linewidth() == 3 diff --git a/ultraplot/tests/test_config_helpers_extra.py b/ultraplot/tests/test_config_helpers_extra.py index eef60df0b..0a76eb713 100644 --- a/ultraplot/tests/test_config_helpers_extra.py +++ b/ultraplot/tests/test_config_helpers_extra.py @@ -25,6 +25,17 @@ def test_style_dict_and_inference_helpers(): inline_style = config._get_style_dict({"axes.facecolor": "black"}) assert inline_style["axes.facecolor"] == "black" + poster_style = config._get_style_dict("poster") + assert poster_style["figure.facecolor"] == "none" + assert poster_style["font.size"] > config._get_style_dict("default")["font.size"] + + dark_style = config._get_style_dict("dark_background") + assert dark_style["axes.facecolor"] == "#000000" + assert dark_style["text.color"] == "#f8fafc" + + dark_alias_style = config._get_style_dict("dark") + assert dark_alias_style["axes.facecolor"] == dark_style["axes.facecolor"] + combined = {"xtick.labelsize": 9, "axes.titlesize": 14, "text.color": "red"} inferred = config._infer_ultraplot_dict(combined) assert inferred["tick.labelsize"] == 9 @@ -75,6 +86,10 @@ def test_configurator_validation_item_dicts_and_context(tmp_path): assert kw_ultraplot["title.size"] == pytest.approx(14) assert kw_ultraplot["grid.labelcolor"] == "red" + kw_ultraplot, kw_matplotlib = cfg._get_item_dicts("style", "dark_background") + assert kw_matplotlib["axes.facecolor"] == "#000000" + assert kw_ultraplot["grid.labelcolor"] == "#f8fafc" + kw_ultraplot, kw_matplotlib = cfg._get_item_dicts("font.size", 12) assert "abc.size" in kw_ultraplot assert kw_matplotlib["font.size"] == 12 diff --git a/ultraplot/tests/test_docs_search.py b/ultraplot/tests/test_docs_search.py new file mode 100644 index 000000000..023640bf5 --- /dev/null +++ b/ultraplot/tests/test_docs_search.py @@ -0,0 +1,119 @@ +import json +import shutil +import subprocess +import textwrap +from pathlib import Path + +import pytest + +ROOT = Path(__file__).resolve().parents[2] + + +def test_docs_search_prioritizes_api_references_for_generic_function_queries(): + if shutil.which("node") is None: + pytest.skip("Node.js is required to exercise the docs search JavaScript.") + + script = textwrap.dedent(r""" + const fs = require("fs"); + const vm = require("vm"); + const listeners = []; + const classList = { + add() {}, + remove() {}, + contains() { return false; }, + toggle() {}, + }; + const context = { + console, + Scorer: {}, + Search: { + _parseQuery(query) { + const objectTerms = new Set(query.toLowerCase().split(/\s+/).filter(Boolean)); + return [query, new Set(), new Set(), new Set(), objectTerms]; + }, + performObjectSearch(_object, objectTerms) { + this.lastObjectTerms = Array.from(objectTerms); + return this.lastObjectTerms; + }, + }, + }; + context.window = { + innerWidth: 1024, + location: { hash: "", pathname: "/search.html" }, + requestAnimationFrame(callback) { callback(); }, + scrollY: 0, + addEventListener() {}, + }; + context.document = { + body: { + classList, + dataset: {}, + appendChild() {}, + getAttribute() { return ""; }, + }, + documentElement: { classList, dataset: {} }, + addEventListener(type, callback) { + if (type === "DOMContentLoaded") listeners.push(callback); + }, + querySelector() { return null; }, + querySelectorAll() { return []; }, + }; + context.localStorage = { + getItem() { return null; }, + setItem() {}, + }; + + vm.runInNewContext(fs.readFileSync("docs/_static/custom.js", "utf8"), context); + for (const callback of listeners) callback.call(context.document); + + const parsed = context.Search._parseQuery("format function"); + const filteredTerms = context.Search.performObjectSearch("format", parsed[4]); + const apiObjectScore = context.Scorer.score([ + "api/ultraplot.axes.Axes", + "ultraplot.axes.Axes.format", + "#ultraplot.axes.Axes.format", + "Python method, in Axes", + 16, + "api/ultraplot.axes.Axes.html", + "object", + ]); + const apiTextScore = context.Scorer.score([ + "api/ultraplot.axes.Axes", + "Axes", + "", + null, + 16, + "api/ultraplot.axes.Axes.html", + "text", + ]); + const guideScore = context.Scorer.score([ + "basics", + "The basics", + "", + null, + 16, + "basics.html", + "text", + ]); + + console.log(JSON.stringify({ + apiLikeQuery: context.Search.upltApiLikeQuery, + apiObjectScore, + apiTextScore, + filteredTerms, + guideScore, + })); + """) + + result = subprocess.run( + ["node", "-e", script], + cwd=ROOT, + text=True, + capture_output=True, + check=True, + ) + data = json.loads(result.stdout) + + assert data["apiLikeQuery"] is True + assert data["filteredTerms"] == ["format"] + assert data["apiObjectScore"] > data["apiTextScore"] > data["guideScore"] diff --git a/ultraplot/tests/test_figure.py b/ultraplot/tests/test_figure.py index a78ac0402..6ae668c4f 100644 --- a/ultraplot/tests/test_figure.py +++ b/ultraplot/tests/test_figure.py @@ -354,6 +354,56 @@ def test_explicit_share_warns_for_mixed_cartesian_polar(): assert len(incompatible) == 1 +def test_share_zero_polar_emits_no_warnings(recwarn): + fig, axs = uplt.subplots(proj="polar", ncols=2, nrows=3, share=0) + fig.canvas.draw() + + ultra = [ + w + for w in recwarn + if issubclass(w.category, uplt.internals.warnings.UltraPlotWarning) + ] + assert ultra == [], [str(w.message) for w in ultra] + + +def test_share_zero_mixed_cartesian_polar_emits_no_warnings(recwarn): + fig, axs = uplt.subplots(ncols=2, proj=("cart", "polar"), share=0) + fig.canvas.draw() + + ultra = [ + w + for w in recwarn + if issubclass(w.category, uplt.internals.warnings.UltraPlotWarning) + ] + assert ultra == [], [str(w.message) for w in ultra] + + +def test_share_default_single_polar_emits_no_warnings(recwarn): + """A single polar axis has nothing to share — must not warn at default share.""" + fig, ax = uplt.subplots(proj="polar") + fig.canvas.draw() + + ultra = [ + w + for w in recwarn + if issubclass(w.category, uplt.internals.warnings.UltraPlotWarning) + ] + assert ultra == [], [str(w.message) for w in ultra] + + +def test_share_default_single_polar_subplot_singular_emits_no_warnings(recwarn): + """``uplt.subplot(proj='polar')`` (singular) has nothing to share either.""" + fig, ax = uplt.subplot(proj="polar") + fig.canvas.draw() + + ultra = [ + w + for w in recwarn + if issubclass(w.category, uplt.internals.warnings.UltraPlotWarning) + ] + assert ultra == [], [str(w.message) for w in ultra] + + def test_auto_share_local_yscale_change_splits_group(): fig, axs = uplt.subplots(ncols=2, share="auto") fig.canvas.draw() diff --git a/ultraplot/tests/test_geographic.py b/ultraplot/tests/test_geographic.py index 1d917f201..7a3c27844 100644 --- a/ultraplot/tests/test_geographic.py +++ b/ultraplot/tests/test_geographic.py @@ -188,6 +188,35 @@ def test_geoticks_label_shorthand_lb_no_warning(recwarn): uplt.close(fig) +def test_geo_labelsize_updates_gridliner_labels(): + fig, ax = uplt.subplots(proj="cyl") + ax = ax[0] + ax.format(labels=True, lonlines=30, latlines=30, labelsize=30) + fig.canvas.draw() + + labels = ( + ax.gridlines_major.bottom_label_artists + ax.gridlines_major.left_label_artists + ) + assert labels + assert {label.get_fontsize() for label in labels} == {30} + uplt.close(fig) + + +def test_subplotgrid_geo_labelsize_updates_gridliner_labels(): + fig, ax = uplt.subplots(proj="cyl") + ax.format(labels=True, lonlines=30, latlines=30, labelsize=30) + fig.canvas.draw() + + geo = ax[0] + labels = ( + geo.gridlines_major.bottom_label_artists + + geo.gridlines_major.left_label_artists + ) + assert labels + assert {label.get_fontsize() for label in labels} == {30} + uplt.close(fig) + + def test_toggle_ticks_supports_bool_and_sequence_specs(): fig, ax = uplt.subplots(proj="cyl") geo = ax[0] diff --git a/ultraplot/tests/test_inputs_helpers.py b/ultraplot/tests/test_inputs_helpers.py index 2288e07a0..299275672 100644 --- a/ultraplot/tests/test_inputs_helpers.py +++ b/ultraplot/tests/test_inputs_helpers.py @@ -169,6 +169,28 @@ def test_mask_range_and_metadata_helpers(): assert inputs._meta_units(np.array([1, 2, 3])) is None +def test_meta_coords_xarray_string_coord(): + """ + Regression test: passing an xarray.DataArray with a string coordinate + to _meta_coords must yield plain string tick labels, not the multi-line + repr of each scalar DataArray element. + """ + xr = pytest.importorskip("xarray") + + da = xr.DataArray( + np.array(["a", "b", "c"]), + coords={"ens": ["a", "b", "c"]}, + dims=["ens"], + name="ens", + ) + + coords, kwargs = inputs._meta_coords(da, which="x") + + assert np.array_equal(coords, np.array([0, 1, 2])) + formatter = kwargs["xformatter"] + assert [formatter(i) for i in coords] == ["a", "b", "c"] + + def test_geographic_helpers_cover_clipping_bounds_and_globes(): clipped = inputs._geo_clip(np.array([-100.0, 0.0, 100.0])) assert np.allclose(clipped, [-90.0, 0.0, 90.0]) diff --git a/ultraplot/tests/test_legend.py b/ultraplot/tests/test_legend.py index 43fae2a81..ab7cc2fba 100644 --- a/ultraplot/tests/test_legend.py +++ b/ultraplot/tests/test_legend.py @@ -1,12 +1,16 @@ import numpy as np import pandas as pd import pytest +from matplotlib import collections as mcollections from matplotlib import colors as mcolors +from matplotlib import container as mcontainer from matplotlib import legend_handler as mhandler +from matplotlib import lines as mlines from matplotlib import patches as mpatches import ultraplot as uplt from ultraplot.axes import Axes as UAxes +from ultraplot.internals import guides @pytest.mark.mpl_image_compare @@ -183,6 +187,28 @@ def test_tuple_handles(rng): return fig +@pytest.mark.parametrize("kwarg", ["barstd", "boxstd"]) +def test_mean_errorbar_handles_are_preserved_in_legends(kwarg, rng): + fig, axs = uplt.subplots() + ax = axs[0] + data = rng.random((10, 4)).cumsum(axis=0) + + handles = ax.plot(data, means=True, label="label", **{kwarg: 1}) + handles, labels = ax._parse_legend_group(handles, None) + + assert labels == ["label"] + assert len(handles) == 1 + assert isinstance(handles[0], tuple) + assert any(isinstance(obj, mcontainer.ErrorbarContainer) for obj in handles[0]) + + leg = ax.legend(handles) + legend_children = list(guides._iter_children(leg._legend_handle_box)) + assert any(isinstance(obj, mcollections.LineCollection) for obj in legend_children) + assert any(isinstance(obj, mlines.Line2D) for obj in legend_children) + + uplt.close(fig) + + @pytest.mark.mpl_image_compare def test_legend_col_spacing(rng): """ @@ -634,6 +660,77 @@ def test_semantic_legend_rejects_labels_kwarg(builder, args, kwargs): uplt.close(fig) +@pytest.mark.parametrize( + "builder, args, kwargs", + ( + ( + "entrylegend", + ([{"label": "Trend", "line": True}, {"label": "Samples", "line": False}],), + {}, + ), + ("catlegend", (["A", "B"],), {"colors": ["red7", "blue7"]}), + ( + "sizelegend", + ([10, 50],), + {"labels": ["small", "large"], "color": "gray6"}, + ), + ("numlegend", tuple(), {"levels": [0, 1], "cmap": "viridis"}), + ( + "geolegend", + ([("Triangle", "triangle"), ("Hex", "hexagon")],), + {}, + ), + ), +) +def test_figure_semantic_legend_helpers(builder, args, kwargs): + fig, axs = uplt.subplots(ncols=2) + ax = axs[0] + figure_method = getattr(fig, builder) + axes_method = getattr(ax, builder) + + expected_handles, expected_labels = axes_method(*args, add=False, **kwargs) + leg = figure_method(*args, ref=axs, loc="bottom", title=builder, **kwargs) + + assert leg is not None + assert [text.get_text() for text in leg.get_texts()] == expected_labels + assert leg.get_title().get_text() == builder + assert len(leg.legend_handles) == len(expected_handles) + uplt.close(fig) + + +@pytest.mark.parametrize( + "builder, args, kwargs", + ( + ("entrylegend", ([{"label": "Trend", "line": True}],), {}), + ("catlegend", (["A", "B"],), {}), + ("sizelegend", ([10, 50],), {"labels": ["small", "large"]}), + ("numlegend", tuple(), {"levels": [0, 1]}), + ("geolegend", (["triangle"], ["Triangle"]), {}), + ), +) +def test_figure_semantic_legend_add_false_matches_axes(builder, args, kwargs): + fig, ax = uplt.subplots() + figure_method = getattr(fig, builder) + axes_method = getattr(ax, builder) + + fig_handles, fig_labels = figure_method(*args, add=False, **kwargs) + ax_handles, ax_labels = axes_method(*args, add=False, **kwargs) + + assert fig_labels == ax_labels + assert len(fig_handles) == len(ax_handles) + assert [handle.get_label() for handle in fig_handles] == [ + handle.get_label() for handle in ax_handles + ] + uplt.close(fig) + + +def test_figure_semantic_legend_without_axes_raises(): + fig = uplt.figure() + with pytest.raises(RuntimeError, match="require an existing axes"): + fig.catlegend(["A"], loc="right") + uplt.close(fig) + + def test_geo_legend_handlesize_scales_handle_box(): fig, ax = uplt.subplots() leg = ax.geolegend([("shape", "triangle")], loc="best", handlesize=2.0) diff --git a/ultraplot/tests/test_projections.py b/ultraplot/tests/test_projections.py index a52d11318..b32d1e289 100644 --- a/ultraplot/tests/test_projections.py +++ b/ultraplot/tests/test_projections.py @@ -3,12 +3,15 @@ Test projection features. """ +import warnings + import cartopy.crs as ccrs import matplotlib.pyplot as plt -import numpy as np, warnings -import ultraplot as uplt +import numpy as np import pytest +import ultraplot as uplt + @pytest.mark.mpl_image_compare def test_aspect_ratios(): @@ -154,6 +157,327 @@ def test_polar_projections(): return fig +def test_polar_format_thetalabel_rlabel(): + """ + `thetalabel` and `rlabel` both create CurvedText artists. + `thetalabel` follows the outer arc at r=rmax. + `rlabel` follows the radial spoke at rlabel_position, spanning rmin→rmax. + """ + fig, axs = uplt.subplots(proj="polar") + ax = axs[0] + ax.format( + thetalim=(0, 90), + rlim=(0, 1), + thetalabel="thetalabel", + rlabel="rlabel", + ) + assert ax._thetalabel_artist is not None + assert ax._rlabel_artist is not None + assert ax._thetalabel_artist.get_text() == "thetalabel" + assert ax._rlabel_artist.get_text() == "rlabel" + # thetalabel: CurvedText arc at r >= rmax (offset for tick clearance), + # centered on midpoint, 80% of span. + tx, ty = ax._thetalabel_artist.get_curve() + assert np.allclose(ty, ty[0]) + assert ty[0] >= ax.get_rmax() + mid = 0.5 * (0.0 + 90.0) + half_span = 0.5 * 90.0 * 0.8 + assert np.isclose(np.rad2deg(tx[0]), mid - half_span) + assert np.isclose(np.rad2deg(tx[-1]), mid + half_span) + # rlabel: CurvedText along spoke at thetamin (sector default), rmin→rmax + rx, ry = ax._rlabel_artist.get_curve() + assert np.allclose(np.rad2deg(rx), 0.0) # thetamin for (0, 90) sector + assert np.isclose(ry[0], ax.get_rmin()) + assert np.isclose(ry[-1], ax.get_rmax()) + + +def test_polar_format_thetalabel_full_circle(): + """`thetalabel` on a full-range polar axes centers on theta=0.""" + fig, axs = uplt.subplots(proj="polar") + ax = axs[0] + ax.format(thetalabel="thetalabel") + tx, _ = ax._thetalabel_artist.get_curve() + mid_deg = np.rad2deg(0.5 * (tx[0] + tx[-1])) + assert np.isclose(mid_deg % 360, 0.0) + + +def test_polar_format_thetalabel_clear(): + """Passing thetalabel='' clears an existing label.""" + fig, axs = uplt.subplots(proj="polar") + ax = axs[0] + ax.format(thetalabel="x") + ax.format(thetalabel="") + assert ax._thetalabel_artist.get_text() == "" + + +def test_polar_format_thetalabelloc(): + """`thetalabelloc=` overrides the default midpoint center.""" + fig, axs = uplt.subplots(proj="polar") + ax = axs[0] + ax.format(thetalim=(0, 90), thetalabel="thetalabel", thetalabelloc=30) + tx, _ = ax._thetalabel_artist.get_curve() + mid_deg = np.rad2deg(0.5 * (tx[0] + tx[-1])) + assert np.isclose(mid_deg, 30.0) + + +def test_polar_thetalabel_stays_radially_outside_under_theta_transform(): + """The thetalabel offset must stay radially outward after theta transforms.""" + fig, axs = uplt.subplots(ncols=2, proj="polar") + for ax, kwargs in zip( + axs, + ({}, {"theta0": "N", "thetadir": -1}), + ): + ax.format( + thetalim=(0, 180), + rlim=(0.2, 1), + thetalabel="thetalabel", + thetalabelloc=135, + **kwargs, + ) + fig.canvas.draw() + for ax in axs: + tx, ty = ax._thetalabel_artist.get_curve() + idx = len(tx) // 2 + base = ax.transData.transform((tx[idx], ax.get_rmax())) + disp = ax._thetalabel_artist.get_transform().transform((tx[idx], ty[idx])) + outward = ax.transData.transform((tx[idx], ax.get_rmax() + 1.0)) - base + offset = disp - base + assert np.dot(offset, outward) > 0 + + +def test_polar_annular_labels_draw_without_nan_positions(): + """Annular polar labels must resolve finite character positions after draw.""" + fig, axs = uplt.subplots(proj="polar") + ax = axs[0] + ax.format( + thetalim=(30, 120), + rlim=(0.4, 1.2), + thetalabel="Annular sector", + rlabel="rlabel", + ) + fig.canvas.draw() + for artist in (ax._thetalabel_artist, ax._rlabel_artist): + positions = [ + np.asarray(text.get_position(), dtype=float) + for char, text in artist._characters + if char.strip() + ] + assert positions + assert all(np.all(np.isfinite(position)) for position in positions) + + +def test_polar_format_wrapped_sector_uses_directed_interval(): + """Wrapped sectors must use the directed theta interval, not sorted extrema.""" + fig, axs = uplt.subplots(proj="polar") + ax = axs[0] + ax.format(thetalim=(300, 60), rlim=(0, 1), thetalabel="thetalabel", rlabel="rlabel") + tx, _ = ax._thetalabel_artist.get_curve() + mid_deg = np.rad2deg(0.5 * (tx[0] + tx[-1])) % 360 + assert np.isclose(mid_deg, 0.0) + rx, _ = ax._rlabel_artist.get_curve() + assert np.allclose(np.rad2deg(rx) % 360, 300.0) + ax.format(rlabelloc="left") + rx, _ = ax._rlabel_artist.get_curve() + assert np.allclose(np.rad2deg(rx) % 360, 60.0) + + +def test_polar_format_rlabelloc_full_circle_flips_offset(): + """On a full circle, `rlabelloc='left'` flips the perpendicular offset.""" + fig, axs = uplt.subplots(proj="polar") + ax = axs[0] + ax.format(rlabel="rlabel", rlabelloc="right") + fig.canvas.draw() + rpos_deg = ax.get_rlabel_position() + rmid = 0.5 * (ax.get_rmin() + ax.get_rmax()) + test_point = (np.deg2rad(rpos_deg), rmid) + right_base_disp = ax.transData.transform(test_point) + right_disp = ax._rlabel_artist.get_transform().transform(test_point) + ax.format(rlabelloc="left") + fig.canvas.draw() + left_base_disp = ax.transData.transform(test_point) + left_disp = ax._rlabel_artist.get_transform().transform(test_point) + right_off = right_disp - right_base_disp + left_off = left_disp - left_base_disp + assert np.allclose(right_off, -left_off) + assert not np.allclose(right_off, 0) + + +def test_polar_format_rlabelloc_sector_selects_spoke(): + """On a sector, `rlabelloc='right'` anchors to thetamin and `'left'` to thetamax.""" + fig, axs = uplt.subplots(proj="polar") + ax = axs[0] + ax.format(thetalim=(0, 90), rlabel="rlabel", rlabelloc="right") + rx_right, _ = ax._rlabel_artist.get_curve() + assert np.allclose(np.rad2deg(rx_right), 0.0) # thetamin spoke + ax.format(rlabelloc="left") + rx_left, _ = ax._rlabel_artist.get_curve() + assert np.allclose(np.rad2deg(rx_left), 90.0) # thetamax spoke + + +def test_polar_format_rlabelloc_sector_stays_outside_under_theta_transform(): + """Sector-default `rlabelloc` must stay outside after theta transforms.""" + fig, axs = uplt.subplots(ncols=2, proj="polar") + for ax, loc in zip(axs, ("right", "left")): + ax.format( + thetalim=(0, 180), + rlim=(0, 1), + theta0="N", + thetadir=-1, + rlabel="rlabel", + rlabelloc=loc, + ) + fig.canvas.draw() + for ax, rpos_deg, inside_deg in zip(axs, (0.0, 180.0), (1.0, 179.0)): + rmid = 0.5 * (ax.get_rmin() + ax.get_rmax()) + point = (np.deg2rad(rpos_deg), rmid) + base_disp = ax.transData.transform(point) + rlabel_disp = ax._rlabel_artist.get_transform().transform(point) + inside_disp = ax.transData.transform((np.deg2rad(inside_deg), rmid)) + off = rlabel_disp - base_disp + inside = inside_disp - base_disp + assert np.dot(off, inside) < 0 + + +def test_polar_format_loc_persists_across_format_calls(): + """ + A subsequent `format()` call without `thetalabelloc`/`rlabelloc`/`rlabelpos` + must not reset the previously-applied values. Regression test for trailing + `axs.format(suptitle=...)`-style calls. + """ + fig, axs = uplt.subplots(proj="polar") + ax = axs[0] + ax.format( + thetalim=(0, 90), + thetalabel="t", + thetalabelloc=30, + rlabel="r", + rlabelloc="left", + ) + tx0, _ = ax._thetalabel_artist.get_curve() + rx0, _ = ax._rlabel_artist.get_curve() + # Trailing generic format() call — must preserve the previous loc/pos. + ax.format(title="anything") + tx1, _ = ax._thetalabel_artist.get_curve() + rx1, _ = ax._rlabel_artist.get_curve() + assert np.isclose(np.rad2deg(0.5 * (tx1[0] + tx1[-1])), 30.0) + assert np.allclose(np.rad2deg(rx1), 90.0) + # Geometry is recomputed but anchors stay put. + assert np.allclose( + np.rad2deg(0.5 * (tx0[0] + tx0[-1])), np.rad2deg(0.5 * (tx1[0] + tx1[-1])) + ) + assert np.allclose(rx0, rx1) + + +def test_polar_format_rlabelpos_sector_auto_outside(): + """`rlabelpos=thetamax` on a sector offsets *outside* the wedge.""" + fig, axs = uplt.subplots(proj="polar") + ax = axs[0] + ax.format(thetalim=(0, 180), rlim=(0, 1), rlabel="rlabel", rlabelpos=180) + fig.canvas.draw() + rmid = 0.5 * (ax.get_rmin() + ax.get_rmax()) + test_point = (np.deg2rad(180.0), rmid) + base_disp = ax.transData.transform(test_point) + rlabel_disp = ax._rlabel_artist.get_transform().transform(test_point) + off = rlabel_disp - base_disp + # Spoke at theta=180 lies along the −x axis; the upper half-disk is +y, so + # outside-the-wedge means the perpendicular offset must be in −y. + assert off[1] < 0 + + +def test_polar_rlabel_offset_stays_perpendicular_under_theta_transform(): + """The rlabel offset must stay perpendicular to the spoke after theta transforms.""" + fig, axs = uplt.subplots(ncols=2, proj="polar") + for ax, kwargs in zip( + axs, + ({}, {"theta0": "N", "thetadir": -1}), + ): + ax.format( + thetalim=(0, 180), rlim=(0.2, 1), rlabel="rlabel", rlabelpos=135, **kwargs + ) + fig.canvas.draw() + for ax in axs: + rmid = 0.5 * (ax.get_rmin() + ax.get_rmax()) + point = (np.deg2rad(135.0), rmid) + base = ax.transData.transform(point) + disp = ax._rlabel_artist.get_transform().transform(point) + offset = disp - base + tangent = ax.transData.transform( + (np.deg2rad(135.0), ax.get_rmax()) + ) - ax.transData.transform((np.deg2rad(135.0), ax.get_rmin())) + tangent /= np.linalg.norm(tangent) + assert np.isclose(np.dot(offset, tangent), 0.0, atol=1e-6) + + +def test_polar_rlabel_refresh_tracks_tick_params(): + """Refreshing the rlabel must honor later tick-param changes.""" + fig, axs = uplt.subplots(proj="polar") + ax = axs[0] + ax.format(thetalim=(0, 180), rlim=(0.2, 1), rlabel="rlabel") + fig.canvas.draw() + rpos_deg = ax.get_rlabel_position() + rmid = 0.5 * (ax.get_rmin() + ax.get_rmax()) + point = (np.deg2rad(rpos_deg), rmid) + base = ax.transData.transform(point) + disp0 = ax._rlabel_artist.get_transform().transform(point) + off0 = np.linalg.norm(disp0 - base) + ax.tick_params(axis="y", which="major", pad=30, labelsize=20) + fig.canvas.draw() + disp1 = ax._rlabel_artist.get_transform().transform(point) + off1 = np.linalg.norm(disp1 - base) + assert off1 > off0 + + +def test_polar_labels_refresh_after_plot_draw(): + """Polar-aware label geometry must refresh when later plotting changes draw state.""" + fig, axs = uplt.subplots(proj="polar") + ax = axs[0] + ax.format( + thetalim=(0, 90), + rlim=(0, 2), + thetalabel="thetalabel", + rlabel="rlabel", + ) + tx0, ty0 = ax._thetalabel_artist.get_curve() + rx0, ry0 = ax._rlabel_artist.get_curve() + ax.plot(np.linspace(0, 2 * np.pi, 200), np.linspace(0, 100, 200)) + fig.canvas.draw() + tx1, ty1 = ax._thetalabel_artist.get_curve() + rx1, ry1 = ax._rlabel_artist.get_curve() + assert np.allclose(np.rad2deg(rx1), 0.0) + assert np.isclose(ry1[0], ax.get_rmin()) + assert np.isclose(ry1[-1], ax.get_rmax()) + assert np.allclose(ty1, ty1[0]) + assert ty1[0] >= ax.get_rmax() + assert ( + not np.allclose(ty0, ty1) + or not np.allclose(ry0, ry1) + or not np.allclose(tx0, tx1) + ) + + +def test_polar_labels_refresh_for_tightbbox(): + """Polar-aware labels must also refresh during tight-bbox queries.""" + fig, axs = uplt.subplots(proj="polar") + ax = axs[0] + ax.format(thetalim=(0, 90), rlim=(0, 1), thetalabel="thetalabel", rlabel="rlabel") + fig.canvas.draw() + tx0, ty0 = ax._thetalabel_artist.get_curve() + rx0, ry0 = ax._rlabel_artist.get_curve() + ax.set_rmax(3) + ax.get_tightbbox(fig.canvas.get_renderer()) + tx1, ty1 = ax._thetalabel_artist.get_curve() + rx1, ry1 = ax._rlabel_artist.get_curve() + assert np.allclose(np.rad2deg(rx1), 0.0) + assert np.isclose(ry1[-1], ax.get_rmax()) + assert np.allclose(ty1, ty1[0]) + assert ty1[0] >= ax.get_rmax() + assert ( + not np.allclose(ty0, ty1) + or not np.allclose(ry0, ry1) + or not np.allclose(tx0, tx1) + ) + + def test_sharing_axes(): """ Test sharing axes for GeoAxes diff --git a/ultraplot/tests/test_release_metadata.py b/ultraplot/tests/test_release_metadata.py index c6c570242..f5af944c2 100644 --- a/ultraplot/tests/test_release_metadata.py +++ b/ultraplot/tests/test_release_metadata.py @@ -131,19 +131,13 @@ def test_readme_citation_section_uses_repository_metadata(): assert "@software{" not in text -def test_publish_workflow_creates_github_release_and_pushes_to_zenodo(): +def test_publish_workflow_creates_github_release(): """ - Release tags should sync citation metadata, create a GitHub release, and - publish the same dist to Zenodo. + Release tags should sync citation metadata and create a GitHub release. """ text = PUBLISH_WORKFLOW.read_text(encoding="utf-8") assert 'tags: ["v*"]' in text assert "tools/release/sync_citation.py" in text assert "--tag" in text assert "--output" in text - assert "softprops/action-gh-release@v2" in text - assert "publish-zenodo:" in text - assert "ZENODO_ACCESS_TOKEN" in text - assert "tools/release/publish_zenodo.py" in text - assert "--dist-dir dist" in text - assert '--citation-file "${RUNNER_TEMP}/CITATION.cff"' in text + assert "softprops/action-gh-release@" in text diff --git a/ultraplot/tests/test_semantic_legend.py b/ultraplot/tests/test_semantic_legend.py new file mode 100644 index 000000000..f2091def1 --- /dev/null +++ b/ultraplot/tests/test_semantic_legend.py @@ -0,0 +1,544 @@ +""" +Unit tests for semantic legend style aliases, color parsing, and advanced markers. +These tests focus on functionality not covered by test_legend.py. +""" + +import matplotlib + +matplotlib.use("Agg") # non-interactive backend + +import numpy as np +import pytest +from matplotlib import colors as mcolors +from matplotlib import patches as mpatches +from matplotlib.markers import CapStyle, JoinStyle, MarkerStyle +import matplotlib.transforms as mtransforms + +import ultraplot as uplt + + +def _make_fig(): + """Helper to create a figure and axis with axes turned off.""" + fig, ax = uplt.subplots() + ax.axis("off") + return fig, ax + + +# ----------------------------------------------------------------------------- +# Non-color properties: scalar, list, dict (single catlegend call) +# ----------------------------------------------------------------------------- +def test_non_color_properties(): + """Non-color properties (marker, markersize, linewidth, alpha, fillstyle, + antialiased, markerfacecoloralt, markerfacecolor, markeredgecolor, size) + are correctly parsed and applied when passed together.""" + fig, ax = _make_fig() + try: + # Combine many non-color properties in one catlegend call. + h, _ = ax.catlegend( + ["A", "B", "C"], + marker="o", + ms=[10, 20, 30], # alias list – overrides above for each entry + lw=[1.5, 2.5, 3.5], # linewidth via alias list + alpha=[0.2, 0.5, 0.8], # length-3 list, not a color + fs="full", # fillstyle + aa=False, # antialiased scalar + markerfacecolor="green", # full name + markeredgecolor="black", # full name + markerfacecoloralt="orange", + line=True, # enable lines + add=False, + ) + # markersize from ms list + assert h[0].get_markersize() == 10 + assert h[1].get_markersize() == 20 + assert h[2].get_markersize() == 30 + # linewidth from lw list + assert h[0].get_linewidth() == 1.5 + assert h[1].get_linewidth() == 2.5 + assert h[2].get_linewidth() == 3.5 + # alpha + assert h[0].get_alpha() == 0.2 + assert h[1].get_alpha() == 0.5 + assert h[2].get_alpha() == 0.8 + # antialiased + for hh in h: + assert hh.get_antialiased() is False + for hh in h: + assert hh.get_markerfacecoloralt() == "orange" + assert hh.get_fillstyle() == "full" + finally: + uplt.close(fig) + + +def test_size_alias_and_markersize_dict(): + """'size' (collection style) maps to markersize, and dict works.""" + fig, ax = _make_fig() + try: + # size as list and dict + h, _ = ax.catlegend( + ["X", "Y", "Z"], + marker="s", + ms={"X": 5, "Y": 12, "Z": 20}, # dict should override per label + add=False, + ) + assert h[0].get_markersize() == 5 + assert h[1].get_markersize() == 12 + assert h[2].get_markersize() == 20 + finally: + uplt.close(fig) + + +def test_markerfacecolor_and_edgecolor(): + """Test full-name markerfacecolor and markeredgecolor with fillstyle='full'.""" + fig, ax = _make_fig() + try: + h, _ = ax.catlegend( + ["A", "B"], + marker="o", + markerfacecolor="green", + markeredgecolor="black", + add=False, + ) + for hh in h: + assert np.allclose( + mcolors.to_rgba(hh.get_markerfacecolor()), mcolors.to_rgba("green") + ) + assert np.allclose( + mcolors.to_rgba(hh.get_markeredgecolor()), mcolors.to_rgba("black") + ) + finally: + uplt.close(fig) + + +# ----------------------------------------------------------------------------- +# Alias resolution and conflicts +# ----------------------------------------------------------------------------- +def test_alias_resolution_and_conflicts(): + """Aliases (c, m, ms, ls, lw, mec, mew, mfc, mfcalt, aa, fs) work, + and full names override aliases when both are given.""" + fig, ax = _make_fig() + try: + # All aliases in one catlegend call + h, _ = ax.catlegend( + ["A", "B"], + c="red", + m="^", + ms=15, + ls="--", + lw=3.0, + mec="blue", + mew=2.0, + mfc="yellow", + mfcalt="orange", + aa=False, + fs="full", + add=False, + ) + for hh in h: + assert hh.get_color() == "red" + assert hh.get_marker() == "^" + assert hh.get_markersize() == 15 + assert hh.get_linestyle() == "--" + assert hh.get_linewidth() == 3.0 + assert hh.get_markeredgecolor() == "blue" + assert hh.get_markeredgewidth() == 2.0 + assert hh.get_markerfacecolor() == "yellow" + assert hh.get_markerfacecoloralt() == "orange" + assert hh.get_antialiased() is False + assert hh.get_fillstyle() == "full" + + # Conflict: full name overrides alias (markersize vs ms) + h, _ = ax.catlegend(["U", "V"], markersize=15, ms=99, add=False) + assert h[0].get_markersize() == 15 + + # Dict styles with aliases + h, _ = ax.catlegend( + ["red", "green", "blue"], + c={"red": "red", "green": "green", "blue": "blue"}, + ms={"red": 10, "green": 20, "blue": 30}, + add=False, + ) + assert h[0].get_color() == "red" + assert h[1].get_color() == "green" + assert h[2].get_color() == "blue" + assert h[0].get_markersize() == 10 + assert h[1].get_markersize() == 20 + assert h[2].get_markersize() == 30 + + # sizelegend aliases + h, _ = ax.sizelegend([1, 2, 3], c="purple", mec="green", add=False) + for hh in h: + assert hh.get_color() == "purple" + assert hh.get_markeredgecolor() == "green" + finally: + uplt.close(fig) + + +# ----------------------------------------------------------------------------- +# Color parsing: many formats (scalar, list, dict, tuple, etc.) +# ----------------------------------------------------------------------------- +def test_color_parsing(): + """Color parameters accept many formats (names, hex, tuples, lists, dicts), + and RGBA tuples are treated as single colors, not unpacked.""" + fig, ax = _make_fig() + try: + # Scalar colors: named, hex, grayscale, RGB tuple, RGBA tuple + for color in ["red", "#ff0000", "0.5", (0.2, 0.4, 0.6), (0.2, 0.4, 0.6, 0.8)]: + h, _ = ax.catlegend(["x", "y", "z"], color=color, add=False) + first = h[0].get_color() + assert all(hh.get_color() == first for hh in h), f"Failed for {color}" + + # List of colors: mixed formats + c_list = ["red", "#00ff00", (0.0, 0.0, 1.0)] + h, _ = ax.catlegend(["p", "q", "r"], color=c_list, add=False) + assert h[0].get_color() == c_list[0] + assert h[1].get_color() == c_list[1] + assert h[2].get_color() == c_list[2] + + # List of RGBA tuples + c_rgba = [(1.0, 0.0, 0.0, 1.0), (0.0, 1.0, 0.0, 1.0)] + h, _ = ax.catlegend(["X", "Y"], color=c_rgba, add=False) + assert h[0].get_color() == c_rgba[0] + assert h[1].get_color() == c_rgba[1] + + # Dict mapping labels to colors + color_dict = {"A": "red", "B": "green", "C": "blue"} + h, _ = ax.catlegend(["A", "B", "C"], color=color_dict, add=False) + assert h[0].get_color() == "red" + assert h[1].get_color() == "green" + assert h[2].get_color() == "blue" + + # markerfacecolor as single RGBA tuple + h, _ = ax.catlegend( + ["m1", "m2"], marker="o", markerfacecolor=(0.1, 0.2, 0.3, 1.0), add=False + ) + ref = h[0].get_markerfacecolor() + assert np.allclose(h[1].get_markerfacecolor(), ref) + + # markerfacecolor via alias (mfc) with list of colors + h, _ = ax.catlegend(["g", "l"], marker="o", mfc=["gold", "lime"], add=False) + assert np.allclose( + mcolors.to_rgba(h[0].get_markerfacecolor()), mcolors.to_rgba("gold") + ) + assert np.allclose( + mcolors.to_rgba(h[1].get_markerfacecolor()), mcolors.to_rgba("lime") + ) + + # numlegend facecolor as RGBA tuple + h, _ = ax.numlegend( + [1, 2, 3], vmin=0, vmax=4, facecolor=(0.8, 0.2, 0.3, 0.6), add=False + ) + ref_patch = np.array(h[0].get_facecolor()) + assert all(np.allclose(np.array(hh.get_facecolor()), ref_patch) for hh in h) + finally: + uplt.close(fig) + + +# ----------------------------------------------------------------------------- +# Advanced marker styles (capstyle, joinstyle, transform) +# ----------------------------------------------------------------------------- +def test_marker_advanced(): + """marker_capstyle, marker_joinstyle, marker_transform create MarkerStyle.""" + fig, ax = _make_fig() + try: + # cap & join + h, _ = ax.catlegend( + ["A", "B"], + marker_capstyle=[CapStyle.round, CapStyle.butt], + marker_joinstyle=[JoinStyle.miter, JoinStyle.bevel], + add=False, + ) + h[0]._marker.get_capstyle() == CapStyle.round + h[0]._marker.get_joinstyle() == JoinStyle.miter + h[1]._marker.get_capstyle() == CapStyle.butt + h[1]._marker.get_joinstyle() == JoinStyle.bevel + + # transform (rotation) + h, _ = ax.catlegend( + ["0°", "45°"], + marker_transform=[ + mtransforms.Affine2D().rotate_deg(0), + mtransforms.Affine2D().rotate_deg(45), + ], + add=False, + ) + h[0]._marker.get_transform().get_matrix()[ + :2, :2 + ] == mtransforms.Affine2D().rotate_deg(0).get_matrix()[:2, :2] + h[1]._marker.get_transform().get_matrix()[ + :2, :2 + ] == mtransforms.Affine2D().rotate_deg(45).get_matrix()[:2, :2] + + # combined with fillstyle and markerfacecoloralt + h, _ = ax.catlegend( + ["left", "right"], + marker="o", + markersize=25, + markerfacecolor="tab:blue", + markerfacecoloralt="lightsteelblue", + fillstyle=["left", "right"], + marker_capstyle=CapStyle.round, + marker_joinstyle="round", + add=False, + ) + assert len(h) == 2 + # Check each handle + for hh, expected_fillstyle in zip(h, ["left", "right"]): + # MarkerStyle creation + m = hh._marker + assert isinstance(m, MarkerStyle) + assert m.get_capstyle() == CapStyle.round + # 'round' string should be converted to JoinStyle.round by MarkerStyle + assert m.get_joinstyle() == JoinStyle.round + + # Check Line2D properties + assert hh.get_markersize() == 25 + assert np.allclose( + mcolors.to_rgba(hh.get_markerfacecolor()), mcolors.to_rgba("tab:blue") + ) + assert hh.get_markerfacecoloralt() == "lightsteelblue" + assert hh.get_fillstyle() == expected_fillstyle + finally: + uplt.close(fig) + + +# ----------------------------------------------------------------------------- +# Validation of forbidden legend kwargs +# ----------------------------------------------------------------------------- +def test_forbidden_legend_kwargs(): + """Passing 'label' or 'labels' to semantic helpers raises TypeError.""" + fig, ax = _make_fig() + try: + with pytest.raises(TypeError, match=r"Use title=\.\.\. for the legend title"): + ax.catlegend(["A"], label="Legend", add=True) + with pytest.raises( + TypeError, match="does not accept the legend kwarg 'labels'" + ): + ax.catlegend(["A"], labels=["x"], add=True) + finally: + uplt.close(fig) + + +# ----------------------------------------------------------------------------- +# Patch aliases and styles (numlegend, geolegend) +# ----------------------------------------------------------------------------- +def test_patch_aliases_and_styles(): + """numlegend and geolegend accept Patch aliases (fc, ec, ls, lw).""" + fig, ax = _make_fig() + try: + # numlegend with aliases + h, _ = ax.numlegend( + [1, 2], + vmin=0, + vmax=2, + fc=["red", "green"], + ec="black", + ls=":", + lw=1.5, + add=False, + ) + assert np.allclose(h[0].get_facecolor()[:3], mcolors.to_rgb("red")) + assert np.allclose(h[1].get_facecolor()[:3], mcolors.to_rgb("green")) + assert h[0].get_edgecolor()[:3] == (0, 0, 0) + assert h[0].get_linestyle() == ":" + assert h[0].get_linewidth() == 1.5 + + # geolegend shape existence + handles, labels = ax.geolegend( + [("Triangle", "triangle"), ("Hex", "hexagon")], add=False + ) + assert labels == ["Triangle", "Hex"] + assert all(isinstance(hh, mpatches.PathPatch) for hh in handles) + finally: + uplt.close(fig) + + +# ----------------------------------------------------------------------------- +# Linestyle auto-enables line +# ----------------------------------------------------------------------------- +def test_linestyle_auto_enable_line(): + """Providing a non-default linestyle automatically enables line=True.""" + fig, ax = _make_fig() + try: + h, _ = ax.catlegend(["A", "B"], ls="--", add=False) + for hh in h: + assert hh.get_linestyle() == "--" + # when line is enabled, marker becomes None + assert hh.get_marker() == uplt.rc["legend.cat.marker"] + finally: + uplt.close(fig) + + +# ----------------------------------------------------------------------------- +# geolegend: per‑entry lists +# ----------------------------------------------------------------------------- +def test_geolegend_per_entry_lists(): + """geolegend applies per-entry styles from lists (facecolor, edgecolor, linewidth, alpha, fill).""" + fig, ax = _make_fig() + try: + handles, labels = ax.geolegend( + ["box", "tri", "hex"], + facecolor=["tab:red", "tab:green", "tab:blue"], + edgecolor=["black", "gray", "white"], + linewidth=[1.0, 2.0, 3.0], + alpha=[0.5, 0.7, 1.0], + fill=[True, False, True], + add=False, + ) + assert len(handles) == 3 + assert labels == ["box", "tri", "hex"] + + # Check per-entry properties + expected_fc = ["tab:red", "tab:green", "tab:blue"] # None for fill=False + expected_ec = ["black", "gray", "white"] + expected_lw = [1.0, 2.0, 3.0] + expected_alpha = [0.5, 0.7, 1.0] + expected_fill = [True, False, True] + + for i, h in enumerate(handles): + assert isinstance(h, mpatches.PathPatch) + if expected_fill[i]: + assert np.allclose( + h.get_facecolor(), + mcolors.to_rgba(expected_fc[i], expected_alpha[i]), + ) + else: + # for fill=False, facecolor is preserved, and set alpha=0 + assert np.allclose( + mcolors.to_rgba(h.get_facecolor()[:3], 0), + mcolors.to_rgba(expected_fc[i], 0), + ) + assert np.allclose( + h.get_edgecolor(), mcolors.to_rgba(expected_ec[i], expected_alpha[i]) + ) + assert h.get_linewidth() == pytest.approx(expected_lw[i]) + assert h.get_alpha() == expected_alpha[i] + assert h.get_fill() == expected_fill[i] + finally: + uplt.close(fig) + + +# ----------------------------------------------------------------------------- +# geolegend: per‑entry dicts +# ----------------------------------------------------------------------------- +def test_geolegend_per_entry_dicts(): + """geolegend applies per-entry styles from dicts.""" + fig, ax = _make_fig() + try: + handles, labels = ax.geolegend( + ["box", "tri", "hex"], + facecolor={"box": "red", "tri": "green", "hex": "blue"}, + edgecolor={"box": "black", "tri": "gray", "hex": "white"}, + linewidth={"box": 1.0, "tri": 2.0, "hex": 3.0}, + alpha={"box": 0.5, "tri": 0.7, "hex": 1.0}, + fill={"box": True, "tri": False, "hex": True}, + add=False, + ) + assert len(handles) == 3 + assert labels == ["box", "tri", "hex"] + + expected = { + "box": ("red", "black", 1.0, 0.5, True), + "tri": ("green", "gray", 2.0, 0.7, False), + "hex": ("blue", "white", 3.0, 1.0, True), + } + for h, label in zip(handles, labels): + fc, ec, lw, alpha, fill = expected[label] + if fill: + assert np.allclose(h.get_facecolor(), mcolors.to_rgba(fc, alpha)) + else: + # for fill=False, facecolor is preserved, and set alpha=0 + assert np.allclose( + mcolors.to_rgba(h.get_facecolor()[:3], 0), mcolors.to_rgba(fc, 0) + ) + assert np.allclose(h.get_edgecolor(), mcolors.to_rgba(ec, alpha)) + assert h.get_linewidth() == pytest.approx(lw) + assert h.get_alpha() == alpha + assert h.get_fill() == fill + finally: + uplt.close(fig) + + +# ----------------------------------------------------------------------------- +# geolegend: alias support +# ----------------------------------------------------------------------------- +def test_geolegend_alias_support(): + """geolegend accepts aliases fc, ec, lw, ls, etc.""" + fig, ax = _make_fig() + try: + handles, _ = ax.geolegend( + ["box", "tri"], + fc=["red", "green"], # alias for facecolor + ec=["black", "blue"], # alias for edgecolor + lw=2.0, # alias for linewidth + ls="--", # alias for linestyle + add=False, + ) + assert len(handles) == 2 + # First geometry + h0 = handles[0] + assert np.allclose(h0.get_facecolor(), mcolors.to_rgba("red")) + assert np.allclose(h0.get_edgecolor(), mcolors.to_rgba("black")) + assert h0.get_linewidth() == 2.0 + assert h0.get_linestyle() == "--" + # Second geometry + h1 = handles[1] + assert np.allclose(h1.get_facecolor(), mcolors.to_rgba("green")) + assert np.allclose(h1.get_edgecolor(), mcolors.to_rgba("blue")) + finally: + uplt.close(fig) + + +# ----------------------------------------------------------------------------- +# geolegend: explicit parameter overrides alias (no conflict error) +# ----------------------------------------------------------------------------- +def test_geolegend_explicit_overrides_alias(): + """Explicit facecolor parameter overrides alias fc.""" + fig, ax = _make_fig() + try: + # facecolor='red' (explicit) vs fc='blue' (alias) → explicit wins + handles, _ = ax.geolegend( + ["box"], + facecolor="red", + fc="blue", + add=False, + ) + h = handles[0] + assert np.allclose(h.get_facecolor(), mcolors.to_rgba("red")) + # edgecolor explicit vs ec + handles, _ = ax.geolegend( + ["box"], + edgecolor="green", + ec="black", + add=False, + ) + h = handles[0] + assert np.allclose(h.get_edgecolor(), mcolors.to_rgba("green")) + finally: + uplt.close(fig) + + +# ----------------------------------------------------------------------------- +# geolegend: per-entry scalar applied to all +# ----------------------------------------------------------------------------- +def test_geolegend_scalar_applied_to_all(): + """Scalar styles are applied to all geometry entries.""" + fig, ax = _make_fig() + try: + handles, _ = ax.geolegend( + ["box", "tri", "hex"], + facecolor="cyan", + edgecolor="black", + linewidth=2.5, + alpha=0.6, + fill=True, + add=False, + ) + for h in handles: + assert np.allclose(h.get_facecolor(), mcolors.to_rgba("cyan", 0.6)) + assert np.allclose(h.get_edgecolor(), mcolors.to_rgba("black", 0.6)) + assert h.get_linewidth() == pytest.approx(2.5) + assert h.get_alpha() == 0.6 + assert h.get_fill() == True + finally: + uplt.close(fig) diff --git a/ultraplot/tests/test_subplots.py b/ultraplot/tests/test_subplots.py index 0f75cb125..458e9b902 100644 --- a/ultraplot/tests/test_subplots.py +++ b/ultraplot/tests/test_subplots.py @@ -284,6 +284,30 @@ def test_subset_share_xlabels_override(): uplt.close(fig) +def test_panel_subset_keeps_orthogonal_axis_labels_local(): + fig, axs = uplt.subplots(ncols=2, sharey=0) + bottom = axs.panel_axes("bottom") + right = axs.panel_axes("right") + + axs.format(xlabel="main x", ylabel="main y") + bottom.format(ylabel="bottom y") + right.format(xlabel="right x") + fig.canvas.draw() + + assert not fig._supylabel_dict + assert not fig._supxlabel_dict + for ax in axs: + assert ax.get_ylabel() == "main y" + for pax in bottom: + assert pax.get_ylabel() == "bottom y" + assert pax.yaxis.label.get_visible() + for pax in right: + assert pax.get_xlabel() == "right x" + assert pax.xaxis.label.get_visible() + + uplt.close(fig) + + def test_spanning_labels_excluded_from_tight_layout_bbox(): """ Spanning x/y labels should not be counted twice by tight layout. diff --git a/ultraplot/text.py b/ultraplot/text.py index bd123fce5..f00e18305 100644 --- a/ultraplot/text.py +++ b/ultraplot/text.py @@ -72,6 +72,11 @@ def __init__( if kwargs.get("transform") is None: kwargs["transform"] = axes.transData + # Split pseudo-properties (border/bbox*) from valid Text kwargs so + # mtext.Text(**self._text_kwargs) accepts them and pseudo-props can be + # re-applied via labels._update_label. + label_props, text_kwargs = labels._split_label_props(kwargs) + # Initialize storage before Text.__init__ triggers set_text() self._characters = [] self._curve_text = "" if text is None else str(text) @@ -82,11 +87,15 @@ def __init__( self._curvature_pad = float(curvature_pad) self._min_advance = float(min_advance) self._ellipsis_text = "..." - self._text_kwargs = kwargs.copy() + self._text_kwargs = text_kwargs + self._label_props = label_props self._initializing = True - super().__init__(x[0], y[0], " ", **kwargs) + super().__init__(x[0], y[0], " ", **text_kwargs) axes.add_artist(self) + # add_artist calls set_clip_path(self.patch), which sets clip_on=True + # and silently overrides any clip_on=False the caller passed. + self._restore_clip_on(self) self._curve_x = x self._curve_y = y @@ -95,18 +104,28 @@ def __init__( self._build_characters(self._curve_text) + def _restore_clip_on(self, t) -> None: + """Re-assert clip_on after add_artist/add_text resets it.""" + if "clip_on" in self._text_kwargs: + t.set_clip_on(self._text_kwargs["clip_on"]) + def _build_characters(self, text: str) -> None: # Remove previous character artists for _, artist in self._characters: artist.remove() self._characters = [] + # Initial position on the curve (not (0, 0)) so get_window_extent works + # under transforms whose inverse is undefined at (0, 0) — e.g. polar + # annular plots where r=0 is below rmin and inverts to NaN. + x0 = float(self._curve_x[0]) + y0 = float(self._curve_y[0]) for char in text: if char == " ": - t = mtext.Text(0, 0, " ", **self._text_kwargs) + t = mtext.Text(x0, y0, " ", **self._text_kwargs) t.set_alpha(0.0) else: - t = mtext.Text(0, 0, char, **self._text_kwargs) + t = mtext.Text(x0, y0, char, **self._text_kwargs) t.set_ha("center") t.set_va("center") @@ -117,6 +136,10 @@ def _build_characters(self, text: str) -> None: add_text(t) else: self.axes.add_artist(t) + self._restore_clip_on(t) + if self._label_props: + t.update = labels._update_label.__get__(t) + t.update(self._label_props) self._characters.append((char, t)) def set_text(self, s): @@ -143,6 +166,10 @@ def get_curve(self) -> Tuple[np.ndarray, np.ndarray]: return self._curve_x.copy(), self._curve_y.copy() def _apply_label_props(self, props) -> None: + new_label_props, new_text_kwargs = labels._split_label_props(props) + # Persist for future set_text() rebuilds. + self._text_kwargs.update(new_text_kwargs) + self._label_props.update(new_label_props) for _, t in self._characters: t.update = labels._update_label.__get__(t) t.update(props) @@ -153,6 +180,12 @@ def set_zorder(self, zorder): for _, t in self._characters: t.set_zorder(self._zorder + 1) + def set_transform(self, transform): + super().set_transform(transform) + self._text_kwargs["transform"] = transform + for _, t in self._characters: + t.set_transform(transform) + def draw(self, renderer, *args, **kwargs): """ Overload `Text.draw()` to update character positions and rotations. @@ -291,6 +324,16 @@ def _place_at(target, t): y_disp[idx] + fraction * dy_arr[idx], ] ) + # Pre-place at a valid data position before measuring bbox: on + # annular polar plots (rmin > 0) the default (0, 0) data coord + # falls below rmin and inverts to NaN, which propagates and + # locks the glyph at NaN forever. + try: + base_data = trans_inv.transform(base) + except Exception: + base_data = None + if base_data is not None and not np.any(np.isnan(base_data)): + t.set_position(base_data) t.set_va("center") bbox_center = t.get_window_extent(renderer=renderer) t.set_va(self.get_va())