Skip to content
Merged
6 changes: 6 additions & 0 deletions pygmt/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,3 +44,9 @@ class GMTVersionError(GMTError):
"""
Raised when an incompatible version of GMT is being used.
"""


class GMTImageComparisonFailure(AssertionError):
"""
Raised when a comparison between two images fails.
"""
37 changes: 37 additions & 0 deletions pygmt/helpers/testing.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
import os
import sys
from pathlib import Path

from matplotlib.testing.compare import compare_images

from ..exceptions import GMTImageComparisonFailure, GMTInvalidInput


def check_figures_equal(fig_ref, fig_test, fig_prefix=None, tol=0.0):
result_dir = "result_images"

if not fig_prefix:
try:
fig_prefix = sys._getframe(1).f_code.co_name
except ValueError:
raise GMTInvalidInput("fig_prefix is required.")

os.makedirs(result_dir, exist_ok=True)

ref_image_path = os.path.join(result_dir, fig_prefix + "-expected.png")
test_image_path = os.path.join(result_dir, fig_prefix + ".png")

fig_ref.savefig(ref_image_path)
fig_test.savefig(test_image_path)

err = compare_images(ref_image_path, test_image_path, tol, in_decorator=True)

if err is None: # Images are the same
os.remove(ref_image_path)
os.remove(test_image_path)
else:
for key in ["actual", "expected"]:
err[key] = os.path.relpath(err[key])
raise GMTImageComparisonFailure(
"images not close (RMS %(rms).3f):\n\t%(actual)s\n\t%(expected)s " % err
)
9 changes: 9 additions & 0 deletions pygmt/tests/test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
from matplotlib.testing.decorators import check_figures_equal

import pygmt


@check_figures_equal(extensions=["png"])
def test_plot(fig_test, fig_ref):
fig_test.subplots().plot([1, 3, 5])
fig_ref.subplots().plot([0, 1, 2], [1, 3, 5])
11 changes: 11 additions & 0 deletions pygmt/tests/test_grdimage.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from .. import Figure
from ..exceptions import GMTInvalidInput
from ..datasets import load_earth_relief
from ..helpers.testing import check_figures_equal


@pytest.fixture(scope="module", name="grid")
Expand Down Expand Up @@ -93,3 +94,13 @@ def test_grdimage_over_dateline(xrgrid):
xrgrid.gmt.gtype = 1 # geographic coordinate system
fig.grdimage(grid=xrgrid, region="g", projection="A0/0/1c", V="i")
return fig


def test_grdimage_central_longitude(grid):
fig1 = Figure()
fig1.grdimage("@earth_relief_01d_g", projection="W120/15c", cmap="geo")

fig2 = Figure()
fig2.grdimage(grid, projection="W120/15c", cmap="geo")

check_figures_equal(fig1, fig2)
28 changes: 28 additions & 0 deletions pygmt/tests/test_testing.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
"""
Test the testing functions for PyGMT
"""
import pytest

from .. import Figure
from ..exceptions import GMTImageComparisonFailure
from ..helpers.testing import check_figures_equal


def test_check_figures_equal():
fig_ref = Figure()
fig_ref.basemap(projection="X10c", region=[0, 10, 0, 10], frame=True)

fig_test = Figure()
fig_test.basemap(projection="X10c", region=[0, 10, 0, 10], frame=True)
check_figures_equal(fig_ref, fig_test)


def test_check_figures_unequal():
fig_ref = Figure()
fig_ref.basemap(projection="X10c", region=[0, 10, 0, 10], frame=True)

fig_test = Figure()
fig_test.basemap(projection="X10c", region=[0, 15, 0, 15], frame=True)

with pytest.raises(GMTImageComparisonFailure):
check_figures_equal(fig_ref, fig_test)