Skip to content
Merged
6 changes: 6 additions & 0 deletions pygmt/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,3 +44,9 @@ class GMTVersionError(GMTError):
"""
Raised when an incompatible version of GMT is being used.
"""


class GMTImageComparisonFailure(AssertionError):
"""
Raised when a comparison between two images fails.
"""
2 changes: 1 addition & 1 deletion pygmt/helpers/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
"""
Functions, classes, decorators, and context managers to help wrap GMT modules.
"""
from .decorators import fmt_docstring, use_alias, kwargs_to_strings
from .decorators import check_figures_equal, fmt_docstring, kwargs_to_strings, use_alias
from .tempfile import GMTTempFile, unique_name
from .utils import (
data_kind,
Expand Down
98 changes: 95 additions & 3 deletions pygmt/helpers/decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,16 @@
arguments, insert common text into docstrings, transform arguments to strings,
etc.
"""
import textwrap
import functools
import inspect
import os
import textwrap

import numpy as np
from matplotlib.testing.compare import compare_images
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

from matplotlib.testing.compare import compare_images

As I understand it, the code means that now matplotlib becomes a required dependency, even for users who never run the tests, right?

Although PyGMT already requires matplotlib for testings and most users usually have matplotlib installed. I still don't want to add one dependency to PyGMT.

When I wrote the first commit (8b78614), I put the codes in pygmt/helpers/testing.py. By doing that way, I think matplotlib is still optional, although I haven't tested it.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good point, I think you're right here. I also encountered issues with circular imports when moving the code to decorators.py, hence this line:

from ..figure import Figure # pylint: disable=import-outside-toplevel

Probably should move it back under pygmt/helpers/testing.py then. As an aside, I've opened up a feature request at matplotlib/pytest-mpl#94, and we might be able to do all this from pytest-mpl in the future.


from ..exceptions import GMTImageComparisonFailure, GMTInvalidInput
from .utils import is_nonstr_iter
from ..exceptions import GMTInvalidInput


COMMON_OPTIONS = {
"R": """\
Expand Down Expand Up @@ -404,3 +406,93 @@ def remove_bools(kwargs):
else:
new_kwargs[arg] = value
return new_kwargs


def check_figures_equal(*, result_dir="result_images", tol=0.0):
"""
Decorator for test cases that generate and compare two figures.

The decorated function must take two arguments, *fig_ref* and *fig_test*,
and draw the reference and test images on them. After the function
returns, the figures are saved and compared.

Parameters
----------
result_dir : str
The directory where the figures will be stored.
tol : float
The RMS threshold above which the test is considered failed.

Examples
--------

>>> import pytest
>>> @check_figures_equal()
... def test_check_figures_equal(fig_ref, fig_test):
... fig_ref.basemap(projection="X5c", region=[0, 5, 0, 5], frame=True)
... fig_test.basemap(projection="X5c", region=[0, 5, 0, 5], frame="af")
>>> test_check_figures_equal()

>>> import shutil
>>> @check_figures_equal(result_dir="tmp_result_images")
... def test_check_figures_unequal(fig_ref, fig_test):
... fig_ref.basemap(projection="X5c", region=[0, 5, 0, 5], frame=True)
... fig_test.basemap(projection="X5c", region=[0, 3, 0, 3], frame=True)
>>> with pytest.raises(GMTImageComparisonFailure):
... test_check_figures_unequal()
>>> shutil.rmtree(path="tmp_result_images")

"""

def decorator(func):

os.makedirs(result_dir, exist_ok=True)
old_sig = inspect.signature(func)

def wrapper(*args, **kwargs):
try:
from ..figure import Figure # pylint: disable=import-outside-toplevel

fig_ref = Figure()
fig_test = Figure()
func(*args, fig_ref=fig_ref, fig_test=fig_test, **kwargs)
ref_image_path = os.path.join(
result_dir, func.__name__ + "-expected.png"
)
test_image_path = os.path.join(result_dir, func.__name__ + ".png")
fig_ref.savefig(ref_image_path)
fig_test.savefig(test_image_path)

# Code below is adapted for PyGMT, and is originally based on
# matplotlib.testing.decorators._raise_on_image_difference
err = compare_images(
expected=ref_image_path,
actual=test_image_path,
tol=tol,
in_decorator=True,
)
if err is None: # Images are the same
os.remove(ref_image_path)
os.remove(test_image_path)
else: # Images are not the same
for key in ["actual", "expected", "diff"]:
err[key] = os.path.relpath(err[key])
raise GMTImageComparisonFailure(
"images not close (RMS %(rms).3f):\n\t%(actual)s\n\t%(expected)s "
% err
)
finally:
del fig_ref
del fig_test

parameters = [
param
for param in old_sig.parameters.values()
if param.name not in {"fig_test", "fig_ref"}
]
new_sig = old_sig.replace(parameters=parameters)
wrapper.__signature__ = new_sig
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Figured out how to make our PyGMT check_figures_equal decorator work with pytest fixtures (e.g. grid=xr.DataArray(...)) in 3e0d3fb. This is basically just copying what was done in matplotlib at matplotlib/matplotlib#16800.


return wrapper

return decorator
14 changes: 12 additions & 2 deletions pygmt/tests/test_grdimage.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,13 @@
Test Figure.grdimage
"""
import numpy as np
import xarray as xr
import pytest
import xarray as xr

from .. import Figure
from ..exceptions import GMTInvalidInput
from ..datasets import load_earth_relief
from ..exceptions import GMTInvalidInput
from ..helpers import check_figures_equal


@pytest.fixture(scope="module", name="grid")
Expand Down Expand Up @@ -93,3 +94,12 @@ def test_grdimage_over_dateline(xrgrid):
xrgrid.gmt.gtype = 1 # geographic coordinate system
fig.grdimage(grid=xrgrid, region="g", projection="A0/0/1c", V="i")
return fig


@check_figures_equal()
def test_grdimage_central_longitude(grid, fig_ref, fig_test):
"""
Test that plotting a grid centred at different longitudes/meridians work.
"""
fig_ref.grdimage("@earth_relief_01d_g", projection="W120/15c", cmap="geo")
fig_test.grdimage(grid, projection="W120/15c", cmap="geo")
Comment on lines +99 to +105
Copy link
Member

@weiji14 weiji14 Sep 3, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
@check_figures_equal()
def test_grdimage_central_longitude(grid, fig_ref, fig_test):
"""
Test that plotting a grid centred at different longitudes/meridians work.
"""
fig_ref.grdimage("@earth_relief_01d_g", projection="W120/15c", cmap="geo")
fig_test.grdimage(grid, projection="W120/15c", cmap="geo")
@pytest.mark.parametrize("meridian", [0, 33, 120, 180])
@check_figures_equal()
@pytest.mark.parametrize("proj_type", ["H", "Q", "W"])
def test_grdimage_different_central_meridians_and_projections(
grid, proj_type, meridian, fig_ref, fig_test
):
"""
Test that plotting a grid centred on different meridians using different
projection systems work.
"""
fig_ref.grdimage(
"@earth_relief_01d_g", projection=f"{proj_type}{meridian}/15c", cmap="geo"
)
fig_test.grdimage(grid, projection=f"{proj_type}{meridian}/15c", cmap="geo")

I'll update this test in #560 later 😄. Problem with using this fancy pytest.mark.parametrize is that it would complicate the check_figures_equal code (see matplotlib/matplotlib#15199 and matplotlib/matplotlib#16693), and make this PR even harder to review.