diff --git a/ultraplot/axes/base.py b/ultraplot/axes/base.py index 8ad5753d8..e875cd563 100644 --- a/ultraplot/axes/base.py +++ b/ultraplot/axes/base.py @@ -3191,13 +3191,31 @@ def get_tightbbox(self, renderer, *args, **kwargs): # Perform extra post-processing steps # NOTE: This should be updated alongside draw(). We also cache the resulting # bounding box to speed up tight layout calculations (see _range_tightbbox). + include_subset_titles = kwargs.pop("include_subset_titles", True) self._add_queued_guides() self._apply_title_above() if self._colorbar_fill: self._colorbar_fill.update_ticks(manual_only=True) # only if needed if self._inset_parent is not None and self._inset_zoom: self.indicate_inset_zoom() - self._tight_bbox = super().get_tightbbox(renderer, *args, **kwargs) + bbox = super().get_tightbbox(renderer, *args, **kwargs) + fig = self.figure + if ( + bbox is not None + and fig is not None + and self._panel_parent is None + and include_subset_titles + and hasattr(fig, "_get_subset_title_bbox") + ): + title_bbox = fig._get_subset_title_bbox(self, renderer) + if title_bbox is not None: + bbox = mtransforms.Bbox.from_extents( + bbox.xmin, + min(bbox.ymin, title_bbox.ymin), + bbox.xmax, + max(bbox.ymax, title_bbox.ymax), + ) + self._tight_bbox = bbox return self._tight_bbox def get_default_bbox_extra_artists(self): diff --git a/ultraplot/figure.py b/ultraplot/figure.py index 4edab717d..4f1bedb92 100644 --- a/ultraplot/figure.py +++ b/ultraplot/figure.py @@ -11,9 +11,9 @@ from packaging import version try: - from typing import List, Optional, Tuple, Union + from typing import Any, Iterable, List, Optional, Tuple, Union except ImportError: - from typing_extensions import List, Optional, Tuple, Union + from typing_extensions import Any, Iterable, List, Optional, Tuple, Union import matplotlib.axes as maxes import matplotlib.figure as mfigure @@ -868,6 +868,7 @@ def _normalize_share(value): self._supylabel_dict = {} # an axes: label mapping self._suplabel_dict = {"left": {}, "right": {}, "bottom": {}, "top": {}} self._share_label_groups = {"x": {}, "y": {}} # explicit label-sharing groups + self._subset_title_dict = {} self._suptitle_pad = rc["suptitle.pad"] d = self._suplabel_props = {} # store the super label props d["left"] = {"va": "center", "ha": "right"} @@ -1662,7 +1663,9 @@ def _get_align_coord(self, side, axs, align="center", includepanels=False): ax = ax._panel_parent or ax # always use main subplot for spanning labels return pos, ax - def _get_offset_coord(self, side, axs, renderer, *, pad=None, extra=None): + def _get_offset_coord( + self, side, axs, renderer, *, pad=None, extra=None, include_subset_titles=True + ): """ Return the figure coordinate for offsetting super labels and super titles. """ @@ -1675,7 +1678,12 @@ def _get_offset_coord(self, side, axs, renderer, *, pad=None, extra=None): ) # noqa: E501 objs = objs + (extra or ()) # e.g. top super labels for obj in objs: - bbox = obj.get_tightbbox(renderer) # cannot use cached bbox + if isinstance(obj, paxes.Axes): + bbox = obj.get_tightbbox( + renderer, include_subset_titles=include_subset_titles + ) + else: + bbox = obj.get_tightbbox(renderer) # cannot use cached bbox attr = s + "max" if side in ("top", "right") else s + "min" c = getattr(bbox, attr) c = (c, 0) if side in ("left", "right") else (0, c) @@ -2523,6 +2531,12 @@ def _align_super_title(self, renderer): if not axs: return labs = tuple(t for t in self._suplabel_dict["top"].values() if t.get_text()) + subset_titles = tuple( + group["artist"] + for group in self._subset_title_dict.values() + if group["artist"].get_text() + ) + labs = labs + subset_titles pad = (self._suptitle_pad / 72) / self.get_size_inches()[1] # Get current alignment settings from suptitle (may be set via suptitle_kw) @@ -2548,6 +2562,183 @@ def _align_super_title(self, renderer): y = y_target - y_bbox self._suptitle.set_position((x, y)) + def _update_subset_title( + self, + axes: Iterable[paxes.Axes], + title: str | None, + *, + fontdict: dict[str, Any] | None = None, + loc: str | None = None, + pad: float | str | None = None, + y: float | None = None, + **kwargs: Any, + ) -> mtext.Text: + """ + Create or update a title spanning a subset of subplots. + """ + fontdict = _not_none(fontdict, kwargs.pop("fontdict", None)) + loc = _not_none( + loc, + kwargs.pop("loc", None), + rc.find("title.loc", context=True), + rc["title.loc"], + ) + pad = _not_none( + pad, + kwargs.pop("pad", None), + rc.find("title.pad", context=True), + rc["title.pad"], + ) + y = _not_none(y, kwargs.pop("y", None)) + axes = [ax for ax in axes if ax is not None and ax.figure is self] + if not axes: + raise ValueError("Need at least one axes to create a shared subplot title.") + + seen = set() + unique_axes = [] + for ax in axes: + ax = ax._panel_parent or ax + ax_id = id(ax) + if ax_id in seen: + continue + seen.add(ax_id) + unique_axes.append(ax) + axes = unique_axes + if len(axes) < 2: + return axes[0].set_title( + title, fontdict=fontdict, loc=loc, pad=pad, y=y, **kwargs + ) + + key = tuple(sorted(id(ax) for ax in axes)) + group = self._subset_title_dict.get(key) + kw = rc.fill( + { + "size": "title.size", + "weight": "title.weight", + "color": "title.color", + "family": "font.family", + }, + context=True, + ) + if "color" in kw and kw["color"] == "auto": + del kw["color"] + if fontdict: + kw.update(fontdict) + kw.update(kwargs) + align = _translate_loc(loc, "text") + match align: + case "left" | "outer left" | "upper left" | "lower left": + align = "left" + case "center" | "upper center" | "lower center": + align = "center" + case "right" | "outer right" | "upper right" | "lower right": + align = "right" + case _: + raise ValueError(f"Invalid shared subplot title location {loc!r}.") + if group is None: + artist = self.text( + 0.5, + 0.0, + "", + transform=self.transFigure, + ha=align, + va="baseline", + zorder=3.5, + ) + group = {"axes": axes, "artist": artist, "pad": None, "y": None} + self._subset_title_dict[key] = group + else: + artist = group["artist"] + group["axes"] = axes + group["pad"] = pad + group["y"] = y + artist.set_ha(align) + artist.set_va("baseline") + if title is not None: + artist.set_text(title) + if kw: + artist.update(kw) + return artist + + def _get_subset_title_bbox( + self, ax: paxes.Axes, renderer + ) -> mtransforms.Bbox | None: + """ + Return the union bbox for shared titles covering the given axes. + + Shared subset titles live above the subset's top edge, so they should + only contribute to the tight bounding boxes for axes that actually touch + that top boundary. Otherwise, multi-row subsets can incorrectly claim + the title as extra inter-row spacing. + """ + ax = ax._panel_parent or ax + bboxes = [] + for group in self._subset_title_dict.values(): + artist = group["artist"] + if not artist.get_visible() or not artist.get_text(): + continue + axs = [ + group_ax._panel_parent or group_ax + for group_ax in group["axes"] + if group_ax is not None + and group_ax.figure is self + and group_ax.get_visible() + ] + if not axs or ax not in axs: + continue + top = min(group_ax._range_subplotspec("y")[0] for group_ax in axs) + if ax._range_subplotspec("y")[0] == top: + bboxes.append(artist.get_window_extent(renderer)) + return mtransforms.Bbox.union(bboxes) if bboxes else None + + def _align_subset_titles(self, renderer): + """ + Update the positions of titles spanning subplot subsets. + """ + for key in list(self._subset_title_dict): + group = self._subset_title_dict[key] + artist = group["artist"] + axs = [ + ax + for ax in group["axes"] + if ax is not None and ax.figure is self and ax.get_visible() + ] + if not axs: + artist.remove() + del self._subset_title_dict[key] + continue + if not artist.get_text(): + continue + align = artist.get_ha() + x, _ = self._get_align_coord( + "top", + axs, + includepanels=self._includepanels, + align=align, + ) + top_labels = tuple( + lab + for ax, lab in self._suplabel_dict["top"].items() + if lab.get_text() and ax in axs + ) + artist.set_x(x) + manual_y = group["y"] + if manual_y is not None: + artist.set_y(manual_y) + continue + pad = group["pad"] + if pad is not None: + pad = units(pad, "pt") / (72 * self.get_size_inches()[1]) + y_target = self._get_offset_coord( + "top", + axs, + renderer, + pad=pad, + extra=top_labels, + include_subset_titles=False, + ) + artist.set_y(y_target) + def _update_axis_label(self, side, axs): """ Update the aligned axis label for the input axes. @@ -2777,6 +2968,7 @@ def _align_content(): # noqa: E306 self._align_axis_label(axis) for side in ("left", "right", "top", "bottom"): self._align_super_labels(side, renderer) + self._align_subset_titles(renderer) self._align_super_title(renderer) # Update the layout diff --git a/ultraplot/gridspec.py b/ultraplot/gridspec.py index 5c4ac4066..bc583b563 100644 --- a/ultraplot/gridspec.py +++ b/ultraplot/gridspec.py @@ -20,6 +20,7 @@ from .config import rc from .internals import ( _not_none, + _pop_rc, docstring, ic, # noqa: F401 warnings, @@ -2083,29 +2084,52 @@ def format(self, **kwargs): share_ylabels = kwargs.get("share_ylabels", None) xlabel = kwargs.get("xlabel", None) ylabel = kwargs.get("ylabel", None) + title = kwargs.get("title", None) axes = [ax for ax in self if ax is not None] all_axes = set(self.figure._subplot_dict.values()) is_subset = bool(axes) and all_axes and set(axes) != all_axes - if len(self) > 1: - if share_xlabels is False: - self.figure._clear_share_label_groups(self, target="x") - if share_ylabels is False: - self.figure._clear_share_label_groups(self, target="y") - if not is_subset and share_xlabels is None and xlabel is not None: - self.figure._clear_share_label_groups(self, target="x") - if not is_subset and share_ylabels is None and ylabel is not None: - self.figure._clear_share_label_groups(self, target="y") - if is_subset and share_xlabels is None and xlabel is not None: - self.figure._register_share_label_group(self, target="x") - if is_subset and share_ylabels is None and ylabel is not None: - self.figure._register_share_label_group(self, target="y") - self.figure.format(axs=self, **kwargs) - # Refresh groups after labels are set - if len(self) > 1: - if is_subset and share_xlabels is None and xlabel is not None: - self.figure._register_share_label_group(self, target="x") - if is_subset and share_ylabels is None and ylabel is not None: - self.figure._register_share_label_group(self, target="y") + shared_subset_title = len(self) > 1 and is_subset and isinstance(title, str) + shared_title_kw = ( + dict(kwargs.pop("title_kw", None) or {}) if shared_subset_title else None + ) + if shared_subset_title: + kwargs.pop("title", None) + shared_title_loc = kwargs.pop("titleloc", None) + shared_title_pad = kwargs.pop("titlepad", None) + kwargs.pop("titleabove", None) + else: + shared_title_loc = None + shared_title_pad = None + rc_kw, rc_mode = _pop_rc(kwargs) + with rc.context(rc_kw, mode=rc_mode): + if len(self) > 1: + if share_xlabels is False: + self.figure._clear_share_label_groups(self, target="x") + if share_ylabels is False: + self.figure._clear_share_label_groups(self, target="y") + if not is_subset and share_xlabels is None and xlabel is not None: + self.figure._clear_share_label_groups(self, target="x") + if not is_subset and share_ylabels is None and ylabel is not None: + self.figure._clear_share_label_groups(self, target="y") + if is_subset and share_xlabels is None and xlabel is not None: + self.figure._register_share_label_group(self, target="x") + if is_subset and share_ylabels is None and ylabel is not None: + self.figure._register_share_label_group(self, target="y") + self.figure.format(axs=self, **kwargs) + if shared_subset_title: + self.figure._update_subset_title( + self, + title, + loc=shared_title_loc, + pad=shared_title_pad, + **(shared_title_kw or {}), + ) + # Refresh groups after labels are set + if len(self) > 1: + if is_subset and share_xlabels is None and xlabel is not None: + self.figure._register_share_label_group(self, target="x") + if is_subset and share_ylabels is None and ylabel is not None: + self.figure._register_share_label_group(self, target="y") def share_labels(self, *, axis="x"): """ diff --git a/ultraplot/tests/test_figure.py b/ultraplot/tests/test_figure.py index 066f3dd2a..53a297399 100644 --- a/ultraplot/tests/test_figure.py +++ b/ultraplot/tests/test_figure.py @@ -447,6 +447,22 @@ def test_suptitle_vertical_alignment_preserves_top_spacing(va): uplt.close("all") +def test_suptitle_clears_shared_subset_titles(): + fig, axs = uplt.subplots(nrows=2, ncols=2) + axs[0, :].format(title="Row title") + fig.format(suptitle="Figure title") + fig.canvas.draw() + renderer = fig.canvas.get_renderer() + + subset_title = next(iter(fig._subset_title_dict.values()))["artist"] + subset_bbox = subset_title.get_window_extent(renderer) + suptitle_bbox = fig._suptitle.get_window_extent(renderer) + + assert subset_bbox.y1 <= suptitle_bbox.y0 + + uplt.close("all") + + def test_subplots_pixelsnap_aligns_axes_bounds(): with uplt.rc.context({"subplots.pixelsnap": True}): fig, axs = uplt.subplots(ncols=2, nrows=2) diff --git a/ultraplot/tests/test_gridspec.py b/ultraplot/tests/test_gridspec.py index 3c8e8250b..e23b25e0f 100644 --- a/ultraplot/tests/test_gridspec.py +++ b/ultraplot/tests/test_gridspec.py @@ -1,4 +1,6 @@ import pytest +import numpy as np +import matplotlib.colors as mcolors import ultraplot as uplt from ultraplot.gridspec import SubplotGrid @@ -145,3 +147,146 @@ def test_gridspec_spanning_slice_deduplicates_axes(): legend = ax.get_legend() assert legend is not None assert [t.get_text() for t in legend.texts] == ["data"] + + +def test_subplotgrid_format_title_creates_shared_subset_title(): + fig, axs = uplt.subplots(nrows=2, ncols=2) + + subset = axs[:, 0] + subset.format(title="Shared title") + fig.canvas.draw() + + title = next(iter(fig._subset_title_dict.values()))["artist"] + assert title.get_text() == "Shared title" + assert all(not ax.get_title() for ax in subset) + + x_expected, _ = fig._get_align_coord("top", list(subset), align="center") + bbox = title.get_window_extent(fig._get_renderer()).transformed( + fig.transFigure.inverted() + ) + top = max(ax.get_position().y1 for ax in subset) + assert np.isclose(title.get_position()[0], x_expected) + assert bbox.y0 > top + + +def test_subplotgrid_format_title_uses_rc_defaults(): + with uplt.rc.context({"title.loc": "left"}): + fig, axs = uplt.subplots(nrows=2, ncols=2) + subset = axs[:, 0] + subset.format(title="Shared title") + fig.canvas.draw() + + title = next(iter(fig._subset_title_dict.values()))["artist"] + x_expected, _ = fig._get_align_coord("top", list(subset), align="left") + assert title.get_ha() == "left" + assert np.isclose(title.get_position()[0], x_expected) + + +def test_subplotgrid_format_title_uses_format_rc_settings(): + fig, axs = uplt.subplots(nrows=2, ncols=2) + subset = axs[:, 0] + subset.format(title="Shared title", titlesize=22, titlecolor="red") + fig.canvas.draw() + + title = next(iter(fig._subset_title_dict.values()))["artist"] + assert title.get_fontsize() == 22 + assert np.allclose(mcolors.to_rgba(title.get_color()), mcolors.to_rgba("red")) + + +@pytest.mark.parametrize( + ("loc", "ha"), + [ + ("upper left", "left"), + ("lower center", "center"), + ("outer right", "right"), + ], +) +def test_subplotgrid_format_title_accepts_standard_title_locations(loc, ha): + fig, axs = uplt.subplots(nrows=2, ncols=2) + subset = axs[:, 0] + subset.format(title="Shared title", titleloc=loc) + fig.canvas.draw() + + title = next(iter(fig._subset_title_dict.values()))["artist"] + x_expected, _ = fig._get_align_coord("top", list(subset), align=ha) + assert title.get_ha() == ha + assert np.isclose(title.get_position()[0], x_expected) + + +def test_subplotgrid_format_title_matches_axes_title_top_gap(): + fig, axs = uplt.subplots(ncols=3) + axs[0].format(title="Single") + subset = axs[1:] + subset.format(title="Shared") + fig.canvas.draw() + + renderer = fig._get_renderer() + single = axs[0]._title_dict["center"] + shared = next(iter(fig._subset_title_dict.values()))["artist"] + single_top = fig.transFigure.transform((0, axs[0].get_position().y1))[1] + shared_top = fig.transFigure.transform((0, axs[1].get_position().y1))[1] + single_gap = single.get_window_extent(renderer).y0 - single_top + shared_gap = shared.get_window_extent(renderer).y0 - shared_top + + assert np.isclose(single_gap, shared_gap) + + +def test_subplotgrid_format_title_across_rows_does_not_inflate_hspace(): + fig_plain, axs_plain = uplt.subplots(ncols=4, nrows=2, refwidth=1) + fig_plain.canvas.draw() + plain_hspace = fig_plain.gridspec.hspace_total[0] + + fig, axs = uplt.subplots(ncols=4, nrows=2, refwidth=1) + subset = axs[:, :3] + subset.format(title="A test title") + fig.canvas.draw() + + title = next(iter(fig._subset_title_dict.values()))["artist"] + bbox = title.get_window_extent(fig._get_renderer()).transformed( + fig.transFigure.inverted() + ) + top = max(ax.get_position().y1 for ax in subset[:3]) + + assert bbox.y0 > top + assert fig.gridspec.hspace_total[0] < plain_hspace + 0.2 + + +def test_subplotgrid_format_title_allows_vertical_alignment_override(): + fig, axs = uplt.subplots(nrows=2, ncols=2) + subset = axs[:, 0] + subset.format(title="Shared title", title_kw={"va": "bottom"}) + fig.canvas.draw() + + title = next(iter(fig._subset_title_dict.values()))["artist"] + assert title.get_va() == "bottom" + + +def test_subplotgrid_format_title_clears_bottom_colorbar_panels(): + fig, axs = uplt.subplots(nrows=2, ncols=2, refwidth=2.5, share=False) + data = np.random.random((10, 10)) + top = axs[:2] + top[0].contourf(data, colorbar="b") + top[1].pcolormesh(data, colorbar="b") + bottom = axs[2:] + bottom.format(title="Shared title") + fig.canvas.draw() + + renderer = fig._get_renderer() + title = next(iter(fig._subset_title_dict.values()))["artist"] + title_bbox = title.get_window_extent(renderer) + panel_y0 = min( + panel.get_tightbbox(renderer).y0 + for ax in top + for panel in ax._panel_dict["bottom"] + ) + assert title_bbox.y1 <= panel_y0 + + +def test_subplotgrid_set_title_still_applies_per_axes(): + fig, axs = uplt.subplots(nrows=1, ncols=2) + + titles = axs[:].set_title("Shared title") + + assert isinstance(titles, tuple) + assert len(titles) == 2 + assert [ax.get_title() for ax in axs] == ["Shared title", "Shared title"]