Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 42 additions & 9 deletions mplexporter/exporter.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,8 +51,8 @@ def run(self, fig):
self.crawl_fig(fig)

@staticmethod
def process_transform(transform, ax=None, data=None, return_trans=False,
force_trans=None):
def process_transform(transform, ax=None, fig=None, data=None,
return_trans=False, force_trans=None):
"""Process the transform and convert data to figure or data coordinates

Parameters
Expand All @@ -61,6 +61,8 @@ def process_transform(transform, ax=None, data=None, return_trans=False,
The transform applied to the data
ax : matplotlib Axes object (optional)
The axes the data is associated with
fig : matplotlib Figure object (optional)
The figure the data is associated with
data : ndarray (optional)
The array of data to be transformed.
return_trans : bool (optional)
Expand Down Expand Up @@ -90,6 +92,7 @@ def process_transform(transform, ax=None, data=None, return_trans=False,
transform = force_trans

code = "display"
fig_ref = ax.figure if ax is not None else fig
if ax is not None:
for (c, trans) in [("data", ax.transData),
("axes", ax.transAxes),
Expand All @@ -98,6 +101,12 @@ def process_transform(transform, ax=None, data=None, return_trans=False,
if transform.contains_branch(trans):
code, transform = (c, transform - trans)
break
elif fig_ref is not None:
for (c, trans) in [("figure", fig_ref.transFigure),
("display", transforms.IdentityTransform())]:
if transform.contains_branch(trans):
code, transform = (c, transform - trans)
break

if data is not None:
if return_trans:
Expand All @@ -114,6 +123,12 @@ def crawl_fig(self, fig):
"""Crawl the figure and process all axes"""
with self.renderer.draw_figure(fig=fig,
props=utils.get_figure_properties(fig)):
if getattr(fig, "_suptitle", None) is not None:
self.draw_figure_text(fig, fig._suptitle, text_type="suptitle")
for text in fig.texts:
if text is not getattr(fig, "_suptitle", None):
self.draw_figure_text(fig, text)

for ax in fig.axes:
self.crawl_ax(ax)

Expand Down Expand Up @@ -148,6 +163,20 @@ def crawl_ax(self, ax):
if props['visible']:
self.crawl_legend(ax, legend)

def draw_figure_text(self, fig, text, text_type=None):
"""Process a figure-level matplotlib text object"""
content = text.get_text()
if content:
transform = text.get_transform()
position = text.get_position()
coords, position = self.process_transform(transform, None, fig,
position)
style = utils.get_text_style(text)
self.renderer.draw_figure_text(text=content, position=position,
coordinates=coords,
text_type=text_type,
style=style, mplobj=text)

def crawl_legend(self, ax, legend):
"""
Recursively look through objects in legend children
Expand Down Expand Up @@ -183,7 +212,8 @@ def crawl_legend(self, ax, legend):
def draw_line(self, ax, line, force_trans=None):
"""Process a matplotlib line and call renderer.draw_line"""
coordinates, data = self.process_transform(line.get_transform(),
ax, line.get_xydata(),
ax=ax,
data=line.get_xydata(),
force_trans=force_trans)
linestyle = utils.get_line_style(line)
if (linestyle['dasharray'] is None
Expand All @@ -207,8 +237,9 @@ def draw_text(self, ax, text, force_trans=None, text_type=None):
if content:
transform = text.get_transform()
position = text.get_position()
coords, position = self.process_transform(transform, ax,
position,
coords, position = self.process_transform(transform,
ax=ax,
data=position,
force_trans=force_trans)
style = utils.get_text_style(text)
self.renderer.draw_text(text=content, position=position,
Expand All @@ -221,7 +252,8 @@ def draw_patch(self, ax, patch, force_trans=None):
vertices, pathcodes = utils.SVG_path(patch.get_path())
transform = patch.get_transform()
coordinates, vertices = self.process_transform(transform,
ax, vertices,
ax=ax,
data=vertices,
force_trans=force_trans)
linestyle = utils.get_path_style(patch, fill=patch.get_fill())
self.renderer.draw_path(data=vertices,
Expand All @@ -238,13 +270,14 @@ def draw_collection(self, ax, collection,
offsets, paths) = collection._prepare_points()

offset_coords, offsets = self.process_transform(
transOffset, ax, offsets, force_trans=force_offsettrans)
transOffset, ax=ax, data=offsets, force_trans=force_offsettrans)
path_coords = self.process_transform(
transform, ax, force_trans=force_pathtrans)
transform, ax=ax, force_trans=force_pathtrans)

processed_paths = [utils.SVG_path(path) for path in paths]
processed_paths = [(self.process_transform(
transform, ax, path[0], force_trans=force_pathtrans)[1], path[1])
transform, ax=ax, data=path[0],
force_trans=force_pathtrans)[1], path[1])
for path in processed_paths]

path_transforms = collection.get_transforms()
Expand Down
5 changes: 5 additions & 0 deletions mplexporter/renderers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,11 @@ def draw_marked_line(self, data, coordinates, linestyle, markerstyle,
if markerstyle is not None:
self.draw_markers(data, coordinates, markerstyle, label, mplobj)

def draw_figure_text(self, text, position, coordinates, style,
text_type=None, mplobj=None):
"""Figure-level text; renderers that care can override."""
pass

def draw_line(self, data, coordinates, style, label, mplobj=None):
"""
Draw a line. By default, draw the line via the draw_path() command.
Expand Down
4 changes: 4 additions & 0 deletions mplexporter/renderers/fake_renderer.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,10 @@ def open_legend(self, legend, props):
def close_legend(self, legend):
self.output += " closing legend\n"

def draw_figure_text(self, text, position, coordinates, style,
text_type=None, mplobj=None):
self.output += " draw figure text '{0}' {1}\n".format(text, text_type)

def draw_text(self, text, position, coordinates, style,
text_type=None, mplobj=None):
self.output += " draw text '{0}' {1}\n".format(text, text_type)
Expand Down