|
5 | 5 | arguments, insert common text into docstrings, transform arguments to strings,
|
6 | 6 | etc.
|
7 | 7 | """
|
8 |
| -import textwrap |
9 | 8 | import functools
|
| 9 | +import os |
| 10 | +import textwrap |
10 | 11 |
|
11 | 12 | import numpy as np
|
| 13 | +from matplotlib.testing.compare import compare_images |
12 | 14 |
|
| 15 | +from ..exceptions import GMTImageComparisonFailure, GMTInvalidInput |
13 | 16 | from .utils import is_nonstr_iter
|
14 |
| -from ..exceptions import GMTInvalidInput |
15 |
| - |
16 | 17 |
|
17 | 18 | COMMON_OPTIONS = {
|
18 | 19 | "R": """\
|
@@ -404,3 +405,84 @@ def remove_bools(kwargs):
|
404 | 405 | else:
|
405 | 406 | new_kwargs[arg] = value
|
406 | 407 | return new_kwargs
|
| 408 | + |
| 409 | + |
| 410 | +def check_figures_equal(*, result_dir="result_images", tol=0.0): |
| 411 | + """ |
| 412 | + Decorator for test cases that generate and compare two figures. |
| 413 | +
|
| 414 | + The decorated function must take two arguments, *fig_ref* and *fig_test*, |
| 415 | + and draw the reference and test images on them. After the function |
| 416 | + returns, the figures are saved and compared. |
| 417 | +
|
| 418 | + Parameters |
| 419 | + ---------- |
| 420 | + result_dir : str |
| 421 | + The directory where the figures will be stored. |
| 422 | + tol : float |
| 423 | + The RMS threshold above which the test is considered failed. |
| 424 | +
|
| 425 | + Examples |
| 426 | + -------- |
| 427 | +
|
| 428 | + >>> import pytest |
| 429 | + >>> @check_figures_equal() |
| 430 | + ... def test_check_figures_equal(fig_ref, fig_test): |
| 431 | + ... fig_ref.basemap(projection="X5c", region=[0, 5, 0, 5], frame=True) |
| 432 | + ... fig_test.basemap(projection="X5c", region=[0, 5, 0, 5], frame="af") |
| 433 | + >>> test_check_figures_equal() |
| 434 | +
|
| 435 | + >>> import shutil |
| 436 | + >>> @check_figures_equal(result_dir="tmp_result_images") |
| 437 | + ... def test_check_figures_unequal(fig_ref, fig_test): |
| 438 | + ... fig_ref.basemap(projection="X5c", region=[0, 5, 0, 5], frame=True) |
| 439 | + ... fig_test.basemap(projection="X5c", region=[0, 3, 0, 3], frame=True) |
| 440 | + >>> with pytest.raises(GMTImageComparisonFailure): |
| 441 | + ... test_check_figures_unequal() |
| 442 | + >>> shutil.rmtree(path="tmp_result_images") |
| 443 | +
|
| 444 | + """ |
| 445 | + |
| 446 | + def decorator(func): |
| 447 | + |
| 448 | + os.makedirs(result_dir, exist_ok=True) |
| 449 | + |
| 450 | + def wrapper(): |
| 451 | + try: |
| 452 | + from ..figure import Figure # pylint: disable=import-outside-toplevel |
| 453 | + |
| 454 | + fig_ref = Figure() |
| 455 | + fig_test = Figure() |
| 456 | + func(fig_ref, fig_test) |
| 457 | + ref_image_path = os.path.join( |
| 458 | + result_dir, func.__name__ + "-expected.png" |
| 459 | + ) |
| 460 | + test_image_path = os.path.join(result_dir, func.__name__ + ".png") |
| 461 | + fig_ref.savefig(ref_image_path) |
| 462 | + fig_test.savefig(test_image_path) |
| 463 | + |
| 464 | + # Code below is adapted for PyGMT, and is originally based on |
| 465 | + # matplotlib.testing.decorators._raise_on_image_difference |
| 466 | + err = compare_images( |
| 467 | + expected=ref_image_path, |
| 468 | + actual=test_image_path, |
| 469 | + tol=tol, |
| 470 | + in_decorator=True, |
| 471 | + ) |
| 472 | + if err is None: # Images are the same |
| 473 | + os.remove(ref_image_path) |
| 474 | + os.remove(test_image_path) |
| 475 | + else: # Images are not the same |
| 476 | + for key in ["actual", "expected", "diff"]: |
| 477 | + err[key] = os.path.relpath(err[key]) |
| 478 | + raise GMTImageComparisonFailure( |
| 479 | + "images not close (RMS %(rms).3f):\n\t%(actual)s\n\t%(expected)s " |
| 480 | + % err |
| 481 | + ) |
| 482 | + finally: |
| 483 | + del fig_ref |
| 484 | + del fig_test |
| 485 | + |
| 486 | + return wrapper |
| 487 | + |
| 488 | + return decorator |
0 commit comments