Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 19 additions & 1 deletion ultraplot/axes/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -3191,13 +3191,31 @@ def get_tightbbox(self, renderer, *args, **kwargs):
# Perform extra post-processing steps
# NOTE: This should be updated alongside draw(). We also cache the resulting
# bounding box to speed up tight layout calculations (see _range_tightbbox).
include_subset_titles = kwargs.pop("include_subset_titles", True)
self._add_queued_guides()
self._apply_title_above()
if self._colorbar_fill:
self._colorbar_fill.update_ticks(manual_only=True) # only if needed
if self._inset_parent is not None and self._inset_zoom:
self.indicate_inset_zoom()
self._tight_bbox = super().get_tightbbox(renderer, *args, **kwargs)
bbox = super().get_tightbbox(renderer, *args, **kwargs)
fig = self.figure
if (
bbox is not None
and fig is not None
and self._panel_parent is None
and include_subset_titles
and hasattr(fig, "_get_subset_title_bbox")
):
title_bbox = fig._get_subset_title_bbox(self, renderer)
if title_bbox is not None:
bbox = mtransforms.Bbox.from_extents(
bbox.xmin,
min(bbox.ymin, title_bbox.ymin),
bbox.xmax,
max(bbox.ymax, title_bbox.ymax),
)
self._tight_bbox = bbox
return self._tight_bbox

def get_default_bbox_extra_artists(self):
Expand Down
192 changes: 188 additions & 4 deletions ultraplot/figure.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,9 @@
from packaging import version

try:
from typing import List, Optional, Tuple, Union
from typing import Any, Iterable, List, Optional, Tuple, Union
except ImportError:
from typing_extensions import List, Optional, Tuple, Union
from typing_extensions import Any, Iterable, List, Optional, Tuple, Union

import matplotlib.axes as maxes
import matplotlib.figure as mfigure
Expand Down Expand Up @@ -868,6 +868,7 @@ def _normalize_share(value):
self._supylabel_dict = {} # an axes: label mapping
self._suplabel_dict = {"left": {}, "right": {}, "bottom": {}, "top": {}}
self._share_label_groups = {"x": {}, "y": {}} # explicit label-sharing groups
self._subset_title_dict = {}
self._suptitle_pad = rc["suptitle.pad"]
d = self._suplabel_props = {} # store the super label props
d["left"] = {"va": "center", "ha": "right"}
Expand Down Expand Up @@ -1662,7 +1663,9 @@ def _get_align_coord(self, side, axs, align="center", includepanels=False):
ax = ax._panel_parent or ax # always use main subplot for spanning labels
return pos, ax

def _get_offset_coord(self, side, axs, renderer, *, pad=None, extra=None):
def _get_offset_coord(
self, side, axs, renderer, *, pad=None, extra=None, include_subset_titles=True
):
"""
Return the figure coordinate for offsetting super labels and super titles.
"""
Expand All @@ -1675,7 +1678,12 @@ def _get_offset_coord(self, side, axs, renderer, *, pad=None, extra=None):
) # noqa: E501
objs = objs + (extra or ()) # e.g. top super labels
for obj in objs:
bbox = obj.get_tightbbox(renderer) # cannot use cached bbox
if isinstance(obj, paxes.Axes):
bbox = obj.get_tightbbox(
renderer, include_subset_titles=include_subset_titles
)
else:
bbox = obj.get_tightbbox(renderer) # cannot use cached bbox
attr = s + "max" if side in ("top", "right") else s + "min"
c = getattr(bbox, attr)
c = (c, 0) if side in ("left", "right") else (0, c)
Expand Down Expand Up @@ -2523,6 +2531,12 @@ def _align_super_title(self, renderer):
if not axs:
return
labs = tuple(t for t in self._suplabel_dict["top"].values() if t.get_text())
subset_titles = tuple(
group["artist"]
for group in self._subset_title_dict.values()
if group["artist"].get_text()
)
labs = labs + subset_titles
pad = (self._suptitle_pad / 72) / self.get_size_inches()[1]

# Get current alignment settings from suptitle (may be set via suptitle_kw)
Expand All @@ -2548,6 +2562,175 @@ def _align_super_title(self, renderer):
y = y_target - y_bbox
self._suptitle.set_position((x, y))

def _update_subset_title(
self,
axes: Iterable[paxes.Axes],
title: str | None,
*,
fontdict: dict[str, Any] | None = None,
loc: str | None = None,
pad: float | str | None = None,
y: float | None = None,
**kwargs: Any,
) -> mtext.Text:
"""
Create or update a title spanning a subset of subplots.
"""
fontdict = _not_none(fontdict, kwargs.pop("fontdict", None))
loc = _not_none(
loc,
kwargs.pop("loc", None),
rc.find("title.loc", context=True),
rc["title.loc"],
)
pad = _not_none(
pad,
kwargs.pop("pad", None),
rc.find("title.pad", context=True),
rc["title.pad"],
)
y = _not_none(y, kwargs.pop("y", None))
axes = [ax for ax in axes if ax is not None and ax.figure is self]
if not axes:
raise ValueError("Need at least one axes to create a shared subplot title.")

seen = set()
unique_axes = []
for ax in axes:
ax = ax._panel_parent or ax
ax_id = id(ax)
if ax_id in seen:
continue
seen.add(ax_id)
unique_axes.append(ax)
axes = unique_axes
if len(axes) < 2:
return axes[0].set_title(
title, fontdict=fontdict, loc=loc, pad=pad, y=y, **kwargs
)

key = tuple(sorted(id(ax) for ax in axes))
group = self._subset_title_dict.get(key)
kw = rc.fill(
{
"size": "title.size",
"weight": "title.weight",
"color": "title.color",
"family": "font.family",
},
context=True,
)
if "color" in kw and kw["color"] == "auto":
del kw["color"]
if fontdict:
kw.update(fontdict)
kw.update(kwargs)
align = _translate_loc(loc, "text")
match align:
case "left" | "outer left" | "upper left" | "lower left":
align = "left"
case "center" | "upper center" | "lower center":
align = "center"
case "right" | "outer right" | "upper right" | "lower right":
align = "right"
case _:
raise ValueError(f"Invalid shared subplot title location {loc!r}.")
if group is None:
artist = self.text(
0.5,
0.0,
"",
transform=self.transFigure,
ha=align,
va="baseline",
zorder=3.5,
)
group = {"axes": axes, "artist": artist, "pad": None, "y": None}
self._subset_title_dict[key] = group
else:
artist = group["artist"]
group["axes"] = axes
group["pad"] = pad
group["y"] = y
artist.set_ha(align)
artist.set_va("baseline")
if title is not None:
artist.set_text(title)
if kw:
artist.update(kw)
return artist

def _get_subset_title_bbox(
self, ax: paxes.Axes, renderer
) -> mtransforms.Bbox | None:
"""
Return the union bbox for shared titles covering the given axes.
"""
ax = ax._panel_parent or ax
bboxes = []
for group in self._subset_title_dict.values():
artist = group["artist"]
if not artist.get_visible() or not artist.get_text():
continue
axs = [
group_ax._panel_parent or group_ax
for group_ax in group["axes"]
if group_ax is not None
and group_ax.figure is self
and group_ax.get_visible()
]
if ax in axs:
bboxes.append(artist.get_window_extent(renderer))
return mtransforms.Bbox.union(bboxes) if bboxes else None

def _align_subset_titles(self, renderer):
"""
Update the positions of titles spanning subplot subsets.
"""
for key in list(self._subset_title_dict):
group = self._subset_title_dict[key]
artist = group["artist"]
axs = [
ax
for ax in group["axes"]
if ax is not None and ax.figure is self and ax.get_visible()
]
if not axs:
artist.remove()
del self._subset_title_dict[key]
continue
if not artist.get_text():
continue
align = artist.get_ha()
x, _ = self._get_align_coord(
"top",
axs,
includepanels=self._includepanels,
align=align,
)
top_labels = tuple(
lab
for ax, lab in self._suplabel_dict["top"].items()
if lab.get_text() and ax in axs
)
artist.set_x(x)
manual_y = group["y"]
if manual_y is not None:
artist.set_y(manual_y)
continue
pad = group["pad"]
if pad is not None:
pad = units(pad, "pt") / (72 * self.get_size_inches()[1])
y_target = self._get_offset_coord(
"top",
axs,
renderer,
pad=pad,
extra=top_labels,
include_subset_titles=False,
)
artist.set_y(y_target)

def _update_axis_label(self, side, axs):
"""
Update the aligned axis label for the input axes.
Expand Down Expand Up @@ -2777,6 +2960,7 @@ def _align_content(): # noqa: E306
self._align_axis_label(axis)
for side in ("left", "right", "top", "bottom"):
self._align_super_labels(side, renderer)
self._align_subset_titles(renderer)
self._align_super_title(renderer)

# Update the layout
Expand Down
64 changes: 44 additions & 20 deletions ultraplot/gridspec.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from .config import rc
from .internals import (
_not_none,
_pop_rc,
docstring,
ic, # noqa: F401
warnings,
Expand Down Expand Up @@ -2083,29 +2084,52 @@ def format(self, **kwargs):
share_ylabels = kwargs.get("share_ylabels", None)
xlabel = kwargs.get("xlabel", None)
ylabel = kwargs.get("ylabel", None)
title = kwargs.get("title", None)
axes = [ax for ax in self if ax is not None]
all_axes = set(self.figure._subplot_dict.values())
is_subset = bool(axes) and all_axes and set(axes) != all_axes
if len(self) > 1:
if share_xlabels is False:
self.figure._clear_share_label_groups(self, target="x")
if share_ylabels is False:
self.figure._clear_share_label_groups(self, target="y")
if not is_subset and share_xlabels is None and xlabel is not None:
self.figure._clear_share_label_groups(self, target="x")
if not is_subset and share_ylabels is None and ylabel is not None:
self.figure._clear_share_label_groups(self, target="y")
if is_subset and share_xlabels is None and xlabel is not None:
self.figure._register_share_label_group(self, target="x")
if is_subset and share_ylabels is None and ylabel is not None:
self.figure._register_share_label_group(self, target="y")
self.figure.format(axs=self, **kwargs)
# Refresh groups after labels are set
if len(self) > 1:
if is_subset and share_xlabels is None and xlabel is not None:
self.figure._register_share_label_group(self, target="x")
if is_subset and share_ylabels is None and ylabel is not None:
self.figure._register_share_label_group(self, target="y")
shared_subset_title = len(self) > 1 and is_subset and isinstance(title, str)
shared_title_kw = (
dict(kwargs.pop("title_kw", None) or {}) if shared_subset_title else None
)
if shared_subset_title:
kwargs.pop("title", None)
shared_title_loc = kwargs.pop("titleloc", None)
shared_title_pad = kwargs.pop("titlepad", None)
kwargs.pop("titleabove", None)
else:
shared_title_loc = None
shared_title_pad = None
rc_kw, rc_mode = _pop_rc(kwargs)
with rc.context(rc_kw, mode=rc_mode):
if len(self) > 1:
if share_xlabels is False:
self.figure._clear_share_label_groups(self, target="x")
if share_ylabels is False:
self.figure._clear_share_label_groups(self, target="y")
if not is_subset and share_xlabels is None and xlabel is not None:
self.figure._clear_share_label_groups(self, target="x")
if not is_subset and share_ylabels is None and ylabel is not None:
self.figure._clear_share_label_groups(self, target="y")
if is_subset and share_xlabels is None and xlabel is not None:
self.figure._register_share_label_group(self, target="x")
if is_subset and share_ylabels is None and ylabel is not None:
self.figure._register_share_label_group(self, target="y")
self.figure.format(axs=self, **kwargs)
if shared_subset_title:
self.figure._update_subset_title(
self,
title,
loc=shared_title_loc,
pad=shared_title_pad,
**(shared_title_kw or {}),
)
# Refresh groups after labels are set
if len(self) > 1:
if is_subset and share_xlabels is None and xlabel is not None:
self.figure._register_share_label_group(self, target="x")
if is_subset and share_ylabels is None and ylabel is not None:
self.figure._register_share_label_group(self, target="y")

def share_labels(self, *, axis="x"):
"""
Expand Down
16 changes: 16 additions & 0 deletions ultraplot/tests/test_figure.py
Original file line number Diff line number Diff line change
Expand Up @@ -447,6 +447,22 @@ def test_suptitle_vertical_alignment_preserves_top_spacing(va):
uplt.close("all")


def test_suptitle_clears_shared_subset_titles():
fig, axs = uplt.subplots(nrows=2, ncols=2)
axs[0, :].format(title="Row title")
fig.format(suptitle="Figure title")
fig.canvas.draw()
renderer = fig.canvas.get_renderer()

subset_title = next(iter(fig._subset_title_dict.values()))["artist"]
subset_bbox = subset_title.get_window_extent(renderer)
suptitle_bbox = fig._suptitle.get_window_extent(renderer)

assert subset_bbox.y1 <= suptitle_bbox.y0

uplt.close("all")


def test_subplots_pixelsnap_aligns_axes_bounds():
with uplt.rc.context({"subplots.pixelsnap": True}):
fig, axs = uplt.subplots(ncols=2, nrows=2)
Expand Down
Loading
Loading