Skip to content

Commit 29e1125

Browse files
committed
Support text elements in figure container.
This is mainly for mpld3, exporting figure-level text objects (such as fig.suptitle) would never get exported. It looks like figure-level text objects (and objects in general) was completely missing here. This adds text objects, I might add more in the future. The alternative was to put figure text objects into the first axis object, but that's extremely hacky, so I went for the larger but proper fix instead. I did this together with gpt-5.1-codex, not alone. Here is what it has to say: - Exporter now emits figure-level text (suptitle + fig.text) via a dedicated draw_figure_text call before crawling axes; figure transforms passed directly to process_transform instead of shoving text into the first axes. - Renderer API gains a draw_figure_text hook (no-op default in base) so non-mpld3 renderers don’t break; FakeRenderer already implements it. - Figure JSON now carries a texts array and MPLD3Renderer serializes figure-level text entries with proper coordinates/attrs; tests cover presence/positions of exported figure texts.
1 parent 0bea50b commit 29e1125

File tree

3 files changed

+51
-9
lines changed

3 files changed

+51
-9
lines changed

mplexporter/exporter.py

Lines changed: 42 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -52,8 +52,8 @@ def run(self, fig):
5252
self.crawl_fig(fig)
5353

5454
@staticmethod
55-
def process_transform(transform, ax=None, data=None, return_trans=False,
56-
force_trans=None):
55+
def process_transform(transform, ax=None, fig=None, data=None,
56+
return_trans=False, force_trans=None):
5757
"""Process the transform and convert data to figure or data coordinates
5858
5959
Parameters
@@ -62,6 +62,8 @@ def process_transform(transform, ax=None, data=None, return_trans=False,
6262
The transform applied to the data
6363
ax : matplotlib Axes object (optional)
6464
The axes the data is associated with
65+
fig : matplotlib Figure object (optional)
66+
The figure the data is associated with
6567
data : ndarray (optional)
6668
The array of data to be transformed.
6769
return_trans : bool (optional)
@@ -91,6 +93,7 @@ def process_transform(transform, ax=None, data=None, return_trans=False,
9193
transform = force_trans
9294

9395
code = "display"
96+
fig_ref = ax.figure if ax is not None else fig
9497
if ax is not None:
9598
for (c, trans) in [("data", ax.transData),
9699
("axes", ax.transAxes),
@@ -99,6 +102,12 @@ def process_transform(transform, ax=None, data=None, return_trans=False,
99102
if transform.contains_branch(trans):
100103
code, transform = (c, transform - trans)
101104
break
105+
elif fig_ref is not None:
106+
for (c, trans) in [("figure", fig_ref.transFigure),
107+
("display", transforms.IdentityTransform())]:
108+
if transform.contains_branch(trans):
109+
code, transform = (c, transform - trans)
110+
break
102111

103112
if data is not None:
104113
if return_trans:
@@ -115,6 +124,12 @@ def crawl_fig(self, fig):
115124
"""Crawl the figure and process all axes"""
116125
with self.renderer.draw_figure(fig=fig,
117126
props=utils.get_figure_properties(fig)):
127+
if getattr(fig, "_suptitle", None) is not None:
128+
self.draw_figure_text(fig, fig._suptitle, text_type="suptitle")
129+
for text in fig.texts:
130+
if text is not getattr(fig, "_suptitle", None):
131+
self.draw_figure_text(fig, text)
132+
118133
for ax in fig.axes:
119134
self.crawl_ax(ax)
120135

@@ -149,6 +164,20 @@ def crawl_ax(self, ax):
149164
if props['visible']:
150165
self.crawl_legend(ax, legend)
151166

167+
def draw_figure_text(self, fig, text, text_type=None):
168+
"""Process a figure-level matplotlib text object"""
169+
content = text.get_text()
170+
if content:
171+
transform = text.get_transform()
172+
position = text.get_position()
173+
coords, position = self.process_transform(transform, None, fig,
174+
position)
175+
style = utils.get_text_style(text)
176+
self.renderer.draw_figure_text(text=content, position=position,
177+
coordinates=coords,
178+
text_type=text_type,
179+
style=style, mplobj=text)
180+
152181
def crawl_legend(self, ax, legend):
153182
"""
154183
Recursively look through objects in legend children
@@ -184,7 +213,8 @@ def crawl_legend(self, ax, legend):
184213
def draw_line(self, ax, line, force_trans=None):
185214
"""Process a matplotlib line and call renderer.draw_line"""
186215
coordinates, data = self.process_transform(line.get_transform(),
187-
ax, line.get_xydata(),
216+
ax=ax,
217+
data=line.get_xydata(),
188218
force_trans=force_trans)
189219
linestyle = utils.get_line_style(line)
190220
if (linestyle['dasharray'] is None
@@ -208,8 +238,9 @@ def draw_text(self, ax, text, force_trans=None, text_type=None):
208238
if content:
209239
transform = text.get_transform()
210240
position = text.get_position()
211-
coords, position = self.process_transform(transform, ax,
212-
position,
241+
coords, position = self.process_transform(transform,
242+
ax=ax,
243+
data=position,
213244
force_trans=force_trans)
214245
style = utils.get_text_style(text)
215246
self.renderer.draw_text(text=content, position=position,
@@ -222,7 +253,8 @@ def draw_patch(self, ax, patch, force_trans=None):
222253
vertices, pathcodes = utils.SVG_path(patch.get_path())
223254
transform = patch.get_transform()
224255
coordinates, vertices = self.process_transform(transform,
225-
ax, vertices,
256+
ax=ax,
257+
data=vertices,
226258
force_trans=force_trans)
227259
linestyle = utils.get_path_style(patch, fill=patch.get_fill())
228260
self.renderer.draw_path(data=vertices,
@@ -239,13 +271,14 @@ def draw_collection(self, ax, collection,
239271
offsets, paths) = _collections_prepare_points(collection, ax)
240272

241273
offset_coords, offsets = self.process_transform(
242-
transOffset, ax, offsets, force_trans=force_offsettrans)
274+
transOffset, ax=ax, data=offsets, force_trans=force_offsettrans)
243275
path_coords = self.process_transform(
244-
transform, ax, force_trans=force_pathtrans)
276+
transform, ax=ax, force_trans=force_pathtrans)
245277

246278
processed_paths = [utils.SVG_path(path) for path in paths]
247279
processed_paths = [(self.process_transform(
248-
transform, ax, path[0], force_trans=force_pathtrans)[1], path[1])
280+
transform, ax=ax, data=path[0],
281+
force_trans=force_pathtrans)[1], path[1])
249282
for path in processed_paths]
250283

251284
path_transforms = collection.get_transforms()

mplexporter/renderers/base.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -158,6 +158,11 @@ def draw_marked_line(self, data, coordinates, linestyle, markerstyle,
158158
if markerstyle is not None:
159159
self.draw_markers(data, coordinates, markerstyle, label, mplobj)
160160

161+
def draw_figure_text(self, text, position, coordinates, style,
162+
text_type=None, mplobj=None):
163+
"""Figure-level text; renderers that care can override."""
164+
pass
165+
161166
def draw_line(self, data, coordinates, style, label, mplobj=None):
162167
"""
163168
Draw a line. By default, draw the line via the draw_path() command.

mplexporter/renderers/fake_renderer.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,10 @@ def open_legend(self, legend, props):
3535
def close_legend(self, legend):
3636
self.output += " closing legend\n"
3737

38+
def draw_figure_text(self, text, position, coordinates, style,
39+
text_type=None, mplobj=None):
40+
self.output += " draw figure text '{0}' {1}\n".format(text, text_type)
41+
3842
def draw_text(self, text, position, coordinates, style,
3943
text_type=None, mplobj=None):
4044
self.output += " draw text '{0}' {1}\n".format(text, text_type)

0 commit comments

Comments
 (0)