|
5 | 5 | arguments, insert common text into docstrings, transform arguments to strings, |
6 | 6 | etc. |
7 | 7 | """ |
8 | | -import functools |
9 | | -import inspect |
10 | | -import os |
11 | 8 | import textwrap |
| 9 | +import functools |
12 | 10 |
|
13 | 11 | import numpy as np |
14 | | -from matplotlib.testing.compare import compare_images |
15 | 12 |
|
16 | | -from ..exceptions import GMTImageComparisonFailure, GMTInvalidInput |
17 | 13 | from .utils import is_nonstr_iter |
| 14 | +from ..exceptions import GMTInvalidInput |
18 | 15 |
|
19 | 16 | COMMON_OPTIONS = { |
20 | 17 | "R": """\ |
@@ -406,98 +403,3 @@ def remove_bools(kwargs): |
406 | 403 | else: |
407 | 404 | new_kwargs[arg] = value |
408 | 405 | return new_kwargs |
409 | | - |
410 | | - |
411 | | -def check_figures_equal(*, tol=0.0, result_dir="result_images"): |
412 | | - """ |
413 | | - Decorator for test cases that generate and compare two figures. |
414 | | -
|
415 | | - The decorated function must take two arguments, *fig_ref* and *fig_test*, |
416 | | - and draw the reference and test images on them. After the function |
417 | | - returns, the figures are saved and compared. |
418 | | -
|
419 | | - This decorator is practically identical to matplotlib's check_figures_equal |
420 | | - function, but adapted for PyGMT figures. See also the original code at |
421 | | - https://matplotlib.org/3.3.1/api/testing_api.html# |
422 | | - matplotlib.testing.decorators.check_figures_equal |
423 | | -
|
424 | | - Parameters |
425 | | - ---------- |
426 | | - tol : float |
427 | | - The RMS threshold above which the test is considered failed. |
428 | | - result_dir : str |
429 | | - The directory where the figures will be stored. |
430 | | -
|
431 | | - Examples |
432 | | - -------- |
433 | | -
|
434 | | - >>> import pytest |
435 | | - >>> @check_figures_equal() |
436 | | - ... def test_check_figures_equal(fig_ref, fig_test): |
437 | | - ... fig_ref.basemap(projection="X5c", region=[0, 5, 0, 5], frame=True) |
438 | | - ... fig_test.basemap(projection="X5c", region=[0, 5, 0, 5], frame="af") |
439 | | - >>> test_check_figures_equal() |
440 | | -
|
441 | | - >>> import shutil |
442 | | - >>> @check_figures_equal(result_dir="tmp_result_images") |
443 | | - ... def test_check_figures_unequal(fig_ref, fig_test): |
444 | | - ... fig_ref.basemap(projection="X5c", region=[0, 5, 0, 5], frame=True) |
445 | | - ... fig_test.basemap(projection="X5c", region=[0, 3, 0, 3], frame=True) |
446 | | - >>> with pytest.raises(GMTImageComparisonFailure): |
447 | | - ... test_check_figures_unequal() |
448 | | - >>> shutil.rmtree(path="tmp_result_images") |
449 | | -
|
450 | | - """ |
451 | | - |
452 | | - def decorator(func): |
453 | | - |
454 | | - os.makedirs(result_dir, exist_ok=True) |
455 | | - old_sig = inspect.signature(func) |
456 | | - |
457 | | - def wrapper(*args, **kwargs): |
458 | | - try: |
459 | | - from ..figure import Figure # pylint: disable=import-outside-toplevel |
460 | | - |
461 | | - fig_ref = Figure() |
462 | | - fig_test = Figure() |
463 | | - func(*args, fig_ref=fig_ref, fig_test=fig_test, **kwargs) |
464 | | - ref_image_path = os.path.join( |
465 | | - result_dir, func.__name__ + "-expected.png" |
466 | | - ) |
467 | | - test_image_path = os.path.join(result_dir, func.__name__ + ".png") |
468 | | - fig_ref.savefig(ref_image_path) |
469 | | - fig_test.savefig(test_image_path) |
470 | | - |
471 | | - # Code below is adapted for PyGMT, and is originally based on |
472 | | - # matplotlib.testing.decorators._raise_on_image_difference |
473 | | - err = compare_images( |
474 | | - expected=ref_image_path, |
475 | | - actual=test_image_path, |
476 | | - tol=tol, |
477 | | - in_decorator=True, |
478 | | - ) |
479 | | - if err is None: # Images are the same |
480 | | - os.remove(ref_image_path) |
481 | | - os.remove(test_image_path) |
482 | | - else: # Images are not the same |
483 | | - for key in ["actual", "expected", "diff"]: |
484 | | - err[key] = os.path.relpath(err[key]) |
485 | | - raise GMTImageComparisonFailure( |
486 | | - "images not close (RMS %(rms).3f):\n\t%(actual)s\n\t%(expected)s " |
487 | | - % err |
488 | | - ) |
489 | | - finally: |
490 | | - del fig_ref |
491 | | - del fig_test |
492 | | - |
493 | | - parameters = [ |
494 | | - param |
495 | | - for param in old_sig.parameters.values() |
496 | | - if param.name not in {"fig_test", "fig_ref"} |
497 | | - ] |
498 | | - new_sig = old_sig.replace(parameters=parameters) |
499 | | - wrapper.__signature__ = new_sig |
500 | | - |
501 | | - return wrapper |
502 | | - |
503 | | - return decorator |
0 commit comments