diff --git a/inline_snapshot/_pandas.py b/inline_snapshot/_pandas.py index 76149590..4cf05ade 100644 --- a/inline_snapshot/_pandas.py +++ b/inline_snapshot/_pandas.py @@ -1,15 +1,8 @@ from functools import wraps from typing import Optional -from pandas import DataFrame -from pandas import Index -from pandas import Series -from pandas.testing import assert_frame_equal as real_assert_frame_equal -from pandas.testing import assert_index_equal as real_assert_index_equal -from pandas.testing import assert_series_equal as real_assert_series_equal - -def make_assert_equals(data_type, assert_equals, repr_function): +def make_assert_equal(data_type, assert_equal, repr_function): class Wrapper: def __init__(self, df, cmp): @@ -24,14 +17,14 @@ def __eq__(self, other): return NotImplemented return self.cmp(self.df, other) - @wraps(assert_equals) + @wraps(assert_equal) def result(df, df_snapshot, *args, **kargs): error: Optional[AssertionError] = None def cmp(a, b): nonlocal error try: - assert_equals(a, b, *args, **kargs) + assert_equal(a, b, *args, **kargs) except AssertionError as e: error = e return False @@ -44,12 +37,21 @@ def cmp(a, b): return result -assert_frame_equal = make_assert_equals( - DataFrame, real_assert_frame_equal, lambda df: df.to_dict("records") -) -assert_series_equal = make_assert_equals( - Series, real_assert_series_equal, lambda df: df.to_dict() -) -assert_index_equal = make_assert_equals( - Index, real_assert_index_equal, lambda df: df.to_list() -) +try: + import pandas +except: + pass +else: + from pandas.testing import assert_frame_equal + from pandas.testing import assert_index_equal + from pandas.testing import assert_series_equal + + assert_frame_equal = make_assert_equal( + pandas.DataFrame, assert_frame_equal, lambda df: df.to_dict("records") + ) + assert_series_equal = make_assert_equal( + pandas.Series, assert_series_equal, lambda df: df.to_dict() + ) + assert_index_equal = make_assert_equal( + pandas.Index, assert_index_equal, lambda df: df.to_list() + ) diff --git a/tests/test_pandas.py b/tests/test_pandas.py index 3659b7ed..acb963f2 100644 --- a/tests/test_pandas.py +++ b/tests/test_pandas.py @@ -1,16 +1,16 @@ import sys import pytest -from pandas import DataFrame -from pandas import Index -from pandas import Series -from inline_snapshot import snapshot -from inline_snapshot._pandas import assert_frame_equal -from inline_snapshot._pandas import assert_index_equal -from inline_snapshot._pandas import assert_series_equal +if sys.version_info >= (3, 9): + from pandas import DataFrame + from pandas import Index + from pandas import Series -nan = float("nan") + from inline_snapshot import snapshot + from inline_snapshot._pandas import assert_frame_equal + from inline_snapshot._pandas import assert_index_equal + from inline_snapshot._pandas import assert_series_equal @pytest.mark.skipif(sys.version_info < (3, 9), reason="no pandas for 3.9")