Skip to content

Commit 5cc40a9

Browse files
authored
SNOW-1247349: Add snowflake.snowpark.testing.assert_dataframe_equal (#2010)
1 parent 73b07d1 commit 5cc40a9

File tree

5 files changed

+508
-0
lines changed

5 files changed

+508
-0
lines changed

CHANGELOG.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,9 @@
44

55
### Snowpark Python API Updates
66

7+
#### New Features
8+
- Added support for `snowflake.snowpark.testing.assert_dataframe_equal` that is a util function to check the equality of two Snowpark DataFrames.
9+
710
#### Improvements
811
- Added support server side string size limitations.
912
- Added support for column lineage in the DataFrame.lineage.trace API.

docs/source/snowpark/index.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,5 +24,6 @@ Snowpark APIs
2424
lineage
2525
context
2626
exceptions
27+
testing
2728

2829
:ref:`genindex`

docs/source/snowpark/testing.rst

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
=============
2+
Testing
3+
=============
4+
Testing module for Snowpark.
5+
6+
.. currentmodule:: snowflake.snowpark.testing
7+
8+
9+
.. autosummary::
10+
:toctree: api/
11+
12+
assert_dataframe_equal
13+
assertDataFrameEqual

src/snowflake/snowpark/testing.py

Lines changed: 232 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,232 @@
1+
#
2+
# Copyright (c) 2012-2024 Snowflake Computing Inc. All rights reserved.
3+
#
4+
5+
import difflib
6+
import functools
7+
import math
8+
from typing import List
9+
10+
from snowflake.snowpark._internal.utils import experimental
11+
from snowflake.snowpark.dataframe import DataFrame
12+
from snowflake.snowpark.row import Row
13+
from snowflake.snowpark.types import StructType, _FractionalType, _IntegralType
14+
15+
ACTUAL_EXPECTED_STRING = "--- actual ---\n+++ expected +++"
16+
17+
18+
def _get_sorted_rows(rows: List[Row]) -> List[Row]:
19+
def compare_rows(row1, row2):
20+
for value1, value2 in zip(row1, row2):
21+
if value1 == value2:
22+
continue
23+
if value1 is None:
24+
return -1
25+
elif value2 is None:
26+
return 1
27+
elif value1 > value2:
28+
return 1
29+
elif value1 < value2:
30+
return -1
31+
return 0
32+
33+
sort_key = functools.cmp_to_key(compare_rows)
34+
return sorted(rows, key=sort_key)
35+
36+
37+
def _assert_schema_equal(
38+
actual: StructType,
39+
expected: StructType,
40+
):
41+
"""
42+
Asserts whether two :class:`types.StructType` objects are the same.
43+
"""
44+
assert len(actual.fields) == len(
45+
expected.fields
46+
), f"Different number of columns: actual has {len(actual.fields)} columns, expected has {len(expected.fields)} columns"
47+
48+
for column_index, (actual_field, expected_field) in enumerate(
49+
zip(actual.fields, expected.fields)
50+
):
51+
error_message = None
52+
if actual_field.name != expected_field.name:
53+
error_message = f"Column name mismatch at column {column_index}: actual {actual_field.name}, expected {expected_field.name}"
54+
if actual_field.datatype != expected_field.datatype:
55+
if not (
56+
(
57+
isinstance(actual_field.datatype, _IntegralType)
58+
and isinstance(expected_field, _IntegralType)
59+
)
60+
or (
61+
isinstance(actual_field.datatype, _FractionalType)
62+
and isinstance(expected_field, _FractionalType)
63+
)
64+
):
65+
error_message = f"Column data type mismatch at column {column_index}: actual {actual_field.datatype}, expected {expected_field.datatype}"
66+
if actual_field.nullable != expected_field.nullable:
67+
error_message = f"Column nullable mismatch at column {column_index}: actual {actual_field.nullable}, expected {expected_field.nullable}"
68+
if error_message:
69+
actual_str = str(actual)
70+
expected_str = str(expected)
71+
if actual_str != expected_str:
72+
diff = difflib.ndiff(actual_str.splitlines(), expected_str.splitlines())
73+
diff_str = "\n".join(diff)
74+
raise AssertionError(
75+
f"{error_message}\nDifferent schema:\n{ACTUAL_EXPECTED_STRING}\n{diff_str}"
76+
)
77+
78+
79+
@experimental(version="1.21.0")
80+
def assert_dataframe_equal(
81+
actual: DataFrame,
82+
expected: DataFrame,
83+
rtol: float = 1e-5,
84+
atol: float = 1e-8,
85+
) -> None:
86+
"""
87+
Asserts that two Snowpark :class:`DataFrame` objects are equal. This function compares both the schema and the data
88+
of the DataFrames. If there are differences, an ``AssertionError`` is raised with a detailed message including differences.
89+
This function is useful for unit testing and validating data transformations and processing in Snowpark.
90+
91+
Args:
92+
actual: The actual DataFrame to be compared.
93+
expected: The expected DataFrame to compare against.
94+
rtol: The relative tolerance for comparing float values. Default is 1e-5.
95+
atol: The absolute tolerance for comparing float values. Default is 1e-8.
96+
97+
Examples::
98+
99+
>>> from snowflake.snowpark.testing import assert_dataframe_equal
100+
>>> from snowflake.snowpark.types import StructType, StructField, IntegerType, StringType, DoubleType
101+
>>> schema1 = StructType([
102+
... StructField("id", IntegerType()),
103+
... StructField("name", StringType()),
104+
... StructField("value", DoubleType())
105+
... ])
106+
>>> data1 = [[1, "Rice", 1.0], [2, "Saka", 2.0], [3, "White", 3.0]]
107+
>>> df1 = session.create_dataframe(data1, schema1)
108+
>>> df2 = session.create_dataframe(data1, schema1)
109+
>>> assert_dataframe_equal(df2, df1) # pass, DataFrames are identical
110+
111+
>>> data2 = [[2, "Saka", 2.0], [1, "Rice", 1.0], [3, "White", 3.0]] # change the order
112+
>>> df3 = session.create_dataframe(data2, schema1)
113+
>>> assert_dataframe_equal(df3, df1) # pass, DataFrames are identical
114+
115+
>>> data3 = [[1, "Rice", 1.0], [2, "Saka", 2.0], [4, "Rowe", 4.0]]
116+
>>> df4 = session.create_dataframe(data3, schema1)
117+
>>> assert_dataframe_equal(df4, df1) # doctest: +IGNORE_EXCEPTION_DETAIL
118+
Traceback (most recent call last):
119+
AssertionError: Value mismatch on row 2 at column 0: actual 4, expected 3
120+
Different row:
121+
--- actual ---
122+
+++ expected +++
123+
- Row(ID=4, NAME='Rowe', VALUE=4.0)
124+
? ^ ^^^ ^
125+
126+
+ Row(ID=3, NAME='White', VALUE=3.0)
127+
? ^ ^^^^ ^
128+
129+
>>> data4 = [[1, "Rice", 1.0], [2, "Saka", 2.0], [3, "White", 3.0001]]
130+
>>> df5 = session.create_dataframe(data4, schema1)
131+
>>> assert_dataframe_equal(df5, df1, atol=1e-3) # pass, DataFrames are identical due to higher error tolerance
132+
>>> assert_dataframe_equal(df5, df1, atol=1e-5) # doctest: +IGNORE_EXCEPTION_DETAIL
133+
Traceback (most recent call last):
134+
AssertionError: Value mismatch on row 2 at column 2: actual 3.0001, expected 3.0
135+
Different row:
136+
--- actual ---
137+
+++ expected +++
138+
- Row(ID=3, NAME='White', VALUE=3.0001)
139+
? ---
140+
141+
+ Row(ID=3, NAME='White', VALUE=3.0)
142+
143+
>>> schema2 = StructType([
144+
... StructField("id", IntegerType()),
145+
... StructField("key", StringType()),
146+
... StructField("value", DoubleType())
147+
... ])
148+
>>> df6 = session.create_dataframe(data1, schema2)
149+
>>> assert_dataframe_equal(df6, df1) # doctest: +IGNORE_EXCEPTION_DETAIL
150+
Traceback (most recent call last):
151+
AssertionError: Column name mismatch at column 1: actual KEY, expected NAME
152+
Different schema:
153+
--- actual ---
154+
+++ expected +++
155+
- StructType([StructField('ID', LongType(), nullable=True), StructField('KEY', StringType(), nullable=True), StructField('VALUE', DoubleType(), nullable=True)])
156+
? ^ -
157+
158+
+ StructType([StructField('ID', LongType(), nullable=True), StructField('NAME', StringType(), nullable=True), StructField('VALUE', DoubleType(), nullable=True)])
159+
?
160+
161+
>>> schema3 = StructType([
162+
... StructField("id", IntegerType()),
163+
... StructField("name", StringType()),
164+
... StructField("value", IntegerType())
165+
... ])
166+
>>> df7 = session.create_dataframe(data1, schema3)
167+
>>> assert_dataframe_equal(df7, df1) # doctest: +IGNORE_EXCEPTION_DETAIL
168+
Traceback (most recent call last):
169+
AssertionError: Column data type mismatch at column 2: actual LongType(), expected DoubleType()
170+
Different schema:
171+
--- actual ---
172+
+++ expected +++
173+
- StructType([StructField('ID', LongType(), nullable=True), StructField('NAME', StringType(), nullable=True), StructField('VALUE', LongType(), nullable=True)])
174+
? ^ ^^
175+
176+
+ StructType([StructField('ID', LongType(), nullable=True), StructField('NAME', StringType(), nullable=True), StructField('VALUE', DoubleType(), nullable=True)])
177+
?
178+
179+
Note:
180+
1. Data in a Snowpark DataFrame is unordered, so when comparing two DataFrames, this function
181+
sorts rows based on their values first.
182+
183+
2. When comparing schemas, :class:`types.IntegerType` and :class:`types.DoubleType` are considered different,
184+
even if the underlying values are equal (e.g., 2 vs 2.0).
185+
186+
"""
187+
if not isinstance(actual, DataFrame):
188+
raise TypeError("actual must be a Snowpark DataFrame")
189+
if not isinstance(expected, DataFrame):
190+
raise TypeError("expected must be a Snowpark DataFrame")
191+
192+
actual_schema = actual.schema
193+
expected_schema = expected.schema
194+
_assert_schema_equal(actual_schema, expected_schema)
195+
196+
actual_rows = _get_sorted_rows(actual.collect())
197+
expected_rows = _get_sorted_rows(expected.collect())
198+
assert len(actual_rows) == len(
199+
expected_rows
200+
), f"Different number of rows: actual has {len(actual_rows)} rows, expected has {len(expected_rows)} rows"
201+
202+
for row_index, (actual_row, expected_row) in enumerate(
203+
zip(actual_rows, expected_rows)
204+
):
205+
for column_index, (actual_value, expected_value) in enumerate(
206+
zip(actual_row, expected_row)
207+
):
208+
error_message = f"Value mismatch on row {row_index} at column {column_index}: actual {actual_value}, expected {expected_value}"
209+
failed = False
210+
if isinstance(expected_value, float):
211+
if math.isnan(actual_value) != math.isnan(expected_value):
212+
failed = True
213+
if not math.isclose(
214+
actual_value, expected_value, rel_tol=rtol, abs_tol=atol
215+
):
216+
failed = True
217+
else:
218+
failed = bool(actual_value != expected_value)
219+
if failed:
220+
actual_row_str = str(actual_row)
221+
expected_row_str = str(expected_row)
222+
if actual_row_str != expected_row_str:
223+
diff = difflib.ndiff(
224+
actual_row_str.splitlines(), expected_row_str.splitlines()
225+
)
226+
diff_str = "\n".join(diff)
227+
raise AssertionError(
228+
f"{error_message}\nDifferent row:\n{ACTUAL_EXPECTED_STRING}\n{diff_str}"
229+
)
230+
231+
232+
assertDataFrameEqual = assert_dataframe_equal

0 commit comments

Comments
 (0)