Skip to content

Commit d2ad3f5

Browse files
committed
Turn check_figures_equal into a decorator function
Also moved test_check_figures_* to a doctest under check_figures_equal.
1 parent 27e03ed commit d2ad3f5

File tree

6 files changed

+96
-89
lines changed

6 files changed

+96
-89
lines changed

Diff for: pygmt/helpers/__init__.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
"""
22
Functions, classes, decorators, and context managers to help wrap GMT modules.
33
"""
4-
from .decorators import fmt_docstring, use_alias, kwargs_to_strings
4+
from .decorators import check_figures_equal, fmt_docstring, kwargs_to_strings, use_alias
55
from .tempfile import GMTTempFile, unique_name
66
from .utils import (
77
data_kind,

Diff for: pygmt/helpers/decorators.py

+85-3
Original file line numberDiff line numberDiff line change
@@ -5,14 +5,15 @@
55
arguments, insert common text into docstrings, transform arguments to strings,
66
etc.
77
"""
8-
import textwrap
98
import functools
9+
import os
10+
import textwrap
1011

1112
import numpy as np
13+
from matplotlib.testing.compare import compare_images
1214

15+
from ..exceptions import GMTImageComparisonFailure, GMTInvalidInput
1316
from .utils import is_nonstr_iter
14-
from ..exceptions import GMTInvalidInput
15-
1617

1718
COMMON_OPTIONS = {
1819
"R": """\
@@ -404,3 +405,84 @@ def remove_bools(kwargs):
404405
else:
405406
new_kwargs[arg] = value
406407
return new_kwargs
408+
409+
410+
def check_figures_equal(*, result_dir="result_images", tol=0.0):
411+
"""
412+
Decorator for test cases that generate and compare two figures.
413+
414+
The decorated function must take two arguments, *fig_ref* and *fig_test*,
415+
and draw the reference and test images on them. After the function
416+
returns, the figures are saved and compared.
417+
418+
Parameters
419+
----------
420+
result_dir : str
421+
The directory where the figures will be stored.
422+
tol : float
423+
The RMS threshold above which the test is considered failed.
424+
425+
Examples
426+
--------
427+
428+
>>> import pytest
429+
>>> @check_figures_equal()
430+
... def test_check_figures_equal(fig_ref, fig_test):
431+
... fig_ref.basemap(projection="X5c", region=[0, 5, 0, 5], frame=True)
432+
... fig_test.basemap(projection="X5c", region=[0, 5, 0, 5], frame="af")
433+
>>> test_check_figures_equal()
434+
435+
>>> import shutil
436+
>>> @check_figures_equal(result_dir="tmp_result_images")
437+
... def test_check_figures_unequal(fig_ref, fig_test):
438+
... fig_ref.basemap(projection="X5c", region=[0, 5, 0, 5], frame=True)
439+
... fig_test.basemap(projection="X5c", region=[0, 3, 0, 3], frame=True)
440+
>>> with pytest.raises(GMTImageComparisonFailure):
441+
... test_check_figures_unequal()
442+
>>> shutil.rmtree(path="tmp_result_images")
443+
444+
"""
445+
446+
def decorator(func):
447+
448+
os.makedirs(result_dir, exist_ok=True)
449+
450+
def wrapper():
451+
try:
452+
from ..figure import Figure # pylint: disable=import-outside-toplevel
453+
454+
fig_ref = Figure()
455+
fig_test = Figure()
456+
func(fig_ref, fig_test)
457+
ref_image_path = os.path.join(
458+
result_dir, func.__name__ + "-expected.png"
459+
)
460+
test_image_path = os.path.join(result_dir, func.__name__ + ".png")
461+
fig_ref.savefig(ref_image_path)
462+
fig_test.savefig(test_image_path)
463+
464+
# Code below is adapted for PyGMT, and is originally based on
465+
# matplotlib.testing.decorators._raise_on_image_difference
466+
err = compare_images(
467+
expected=ref_image_path,
468+
actual=test_image_path,
469+
tol=tol,
470+
in_decorator=True,
471+
)
472+
if err is None: # Images are the same
473+
os.remove(ref_image_path)
474+
os.remove(test_image_path)
475+
else: # Images are not the same
476+
for key in ["actual", "expected", "diff"]:
477+
err[key] = os.path.relpath(err[key])
478+
raise GMTImageComparisonFailure(
479+
"images not close (RMS %(rms).3f):\n\t%(actual)s\n\t%(expected)s "
480+
% err
481+
)
482+
finally:
483+
del fig_ref
484+
del fig_test
485+
486+
return wrapper
487+
488+
return decorator

Diff for: pygmt/helpers/testing.py

-37
This file was deleted.

Diff for: pygmt/tests/test.py

-9
This file was deleted.

Diff for: pygmt/tests/test_grdimage.py

+10-11
Original file line numberDiff line numberDiff line change
@@ -2,13 +2,13 @@
22
Test Figure.grdimage
33
"""
44
import numpy as np
5-
import xarray as xr
65
import pytest
6+
import xarray as xr
77

88
from .. import Figure
9-
from ..exceptions import GMTInvalidInput
109
from ..datasets import load_earth_relief
11-
from ..helpers.testing import check_figures_equal
10+
from ..exceptions import GMTInvalidInput
11+
from ..helpers import check_figures_equal
1212

1313

1414
@pytest.fixture(scope="module", name="grid")
@@ -96,11 +96,10 @@ def test_grdimage_over_dateline(xrgrid):
9696
return fig
9797

9898

99-
def test_grdimage_central_longitude(grid):
100-
fig1 = Figure()
101-
fig1.grdimage("@earth_relief_01d_g", projection="W120/15c", cmap="geo")
102-
103-
fig2 = Figure()
104-
fig2.grdimage(grid, projection="W120/15c", cmap="geo")
105-
106-
check_figures_equal(fig1, fig2)
99+
@check_figures_equal()
100+
def test_grdimage_central_longitude(grid, fig_ref, fig_test):
101+
"""
102+
Test that plotting a grid centred at different longitudes/meridians work.
103+
"""
104+
fig_ref.grdimage("@earth_relief_01d_g", projection="W120/15c", cmap="geo")
105+
fig_test.grdimage(grid, projection="W120/15c", cmap="geo")

Diff for: pygmt/tests/test_testing.py

-28
This file was deleted.

0 commit comments

Comments
 (0)