|
| 1 | +# |
| 2 | +# Copyright (c) 2012-2024 Snowflake Computing Inc. All rights reserved. |
| 3 | +# |
| 4 | + |
| 5 | +import difflib |
| 6 | +import functools |
| 7 | +import math |
| 8 | +from typing import List |
| 9 | + |
| 10 | +from snowflake.snowpark._internal.utils import experimental |
| 11 | +from snowflake.snowpark.dataframe import DataFrame |
| 12 | +from snowflake.snowpark.row import Row |
| 13 | +from snowflake.snowpark.types import StructType, _FractionalType, _IntegralType |
| 14 | + |
| 15 | +ACTUAL_EXPECTED_STRING = "--- actual ---\n+++ expected +++" |
| 16 | + |
| 17 | + |
| 18 | +def _get_sorted_rows(rows: List[Row]) -> List[Row]: |
| 19 | + def compare_rows(row1, row2): |
| 20 | + for value1, value2 in zip(row1, row2): |
| 21 | + if value1 == value2: |
| 22 | + continue |
| 23 | + if value1 is None: |
| 24 | + return -1 |
| 25 | + elif value2 is None: |
| 26 | + return 1 |
| 27 | + elif value1 > value2: |
| 28 | + return 1 |
| 29 | + elif value1 < value2: |
| 30 | + return -1 |
| 31 | + return 0 |
| 32 | + |
| 33 | + sort_key = functools.cmp_to_key(compare_rows) |
| 34 | + return sorted(rows, key=sort_key) |
| 35 | + |
| 36 | + |
| 37 | +def _assert_schema_equal( |
| 38 | + actual: StructType, |
| 39 | + expected: StructType, |
| 40 | +): |
| 41 | + """ |
| 42 | + Asserts whether two :class:`types.StructType` objects are the same. |
| 43 | + """ |
| 44 | + assert len(actual.fields) == len( |
| 45 | + expected.fields |
| 46 | + ), f"Different number of columns: actual has {len(actual.fields)} columns, expected has {len(expected.fields)} columns" |
| 47 | + |
| 48 | + for column_index, (actual_field, expected_field) in enumerate( |
| 49 | + zip(actual.fields, expected.fields) |
| 50 | + ): |
| 51 | + error_message = None |
| 52 | + if actual_field.name != expected_field.name: |
| 53 | + error_message = f"Column name mismatch at column {column_index}: actual {actual_field.name}, expected {expected_field.name}" |
| 54 | + if actual_field.datatype != expected_field.datatype: |
| 55 | + if not ( |
| 56 | + ( |
| 57 | + isinstance(actual_field.datatype, _IntegralType) |
| 58 | + and isinstance(expected_field, _IntegralType) |
| 59 | + ) |
| 60 | + or ( |
| 61 | + isinstance(actual_field.datatype, _FractionalType) |
| 62 | + and isinstance(expected_field, _FractionalType) |
| 63 | + ) |
| 64 | + ): |
| 65 | + error_message = f"Column data type mismatch at column {column_index}: actual {actual_field.datatype}, expected {expected_field.datatype}" |
| 66 | + if actual_field.nullable != expected_field.nullable: |
| 67 | + error_message = f"Column nullable mismatch at column {column_index}: actual {actual_field.nullable}, expected {expected_field.nullable}" |
| 68 | + if error_message: |
| 69 | + actual_str = str(actual) |
| 70 | + expected_str = str(expected) |
| 71 | + if actual_str != expected_str: |
| 72 | + diff = difflib.ndiff(actual_str.splitlines(), expected_str.splitlines()) |
| 73 | + diff_str = "\n".join(diff) |
| 74 | + raise AssertionError( |
| 75 | + f"{error_message}\nDifferent schema:\n{ACTUAL_EXPECTED_STRING}\n{diff_str}" |
| 76 | + ) |
| 77 | + |
| 78 | + |
| 79 | +@experimental(version="1.21.0") |
| 80 | +def assert_dataframe_equal( |
| 81 | + actual: DataFrame, |
| 82 | + expected: DataFrame, |
| 83 | + rtol: float = 1e-5, |
| 84 | + atol: float = 1e-8, |
| 85 | +) -> None: |
| 86 | + """ |
| 87 | + Asserts that two Snowpark :class:`DataFrame` objects are equal. This function compares both the schema and the data |
| 88 | + of the DataFrames. If there are differences, an ``AssertionError`` is raised with a detailed message including differences. |
| 89 | + This function is useful for unit testing and validating data transformations and processing in Snowpark. |
| 90 | +
|
| 91 | + Args: |
| 92 | + actual: The actual DataFrame to be compared. |
| 93 | + expected: The expected DataFrame to compare against. |
| 94 | + rtol: The relative tolerance for comparing float values. Default is 1e-5. |
| 95 | + atol: The absolute tolerance for comparing float values. Default is 1e-8. |
| 96 | +
|
| 97 | + Examples:: |
| 98 | +
|
| 99 | + >>> from snowflake.snowpark.testing import assert_dataframe_equal |
| 100 | + >>> from snowflake.snowpark.types import StructType, StructField, IntegerType, StringType, DoubleType |
| 101 | + >>> schema1 = StructType([ |
| 102 | + ... StructField("id", IntegerType()), |
| 103 | + ... StructField("name", StringType()), |
| 104 | + ... StructField("value", DoubleType()) |
| 105 | + ... ]) |
| 106 | + >>> data1 = [[1, "Rice", 1.0], [2, "Saka", 2.0], [3, "White", 3.0]] |
| 107 | + >>> df1 = session.create_dataframe(data1, schema1) |
| 108 | + >>> df2 = session.create_dataframe(data1, schema1) |
| 109 | + >>> assert_dataframe_equal(df2, df1) # pass, DataFrames are identical |
| 110 | +
|
| 111 | + >>> data2 = [[2, "Saka", 2.0], [1, "Rice", 1.0], [3, "White", 3.0]] # change the order |
| 112 | + >>> df3 = session.create_dataframe(data2, schema1) |
| 113 | + >>> assert_dataframe_equal(df3, df1) # pass, DataFrames are identical |
| 114 | +
|
| 115 | + >>> data3 = [[1, "Rice", 1.0], [2, "Saka", 2.0], [4, "Rowe", 4.0]] |
| 116 | + >>> df4 = session.create_dataframe(data3, schema1) |
| 117 | + >>> assert_dataframe_equal(df4, df1) # doctest: +IGNORE_EXCEPTION_DETAIL |
| 118 | + Traceback (most recent call last): |
| 119 | + AssertionError: Value mismatch on row 2 at column 0: actual 4, expected 3 |
| 120 | + Different row: |
| 121 | + --- actual --- |
| 122 | + +++ expected +++ |
| 123 | + - Row(ID=4, NAME='Rowe', VALUE=4.0) |
| 124 | + ? ^ ^^^ ^ |
| 125 | +
|
| 126 | + + Row(ID=3, NAME='White', VALUE=3.0) |
| 127 | + ? ^ ^^^^ ^ |
| 128 | +
|
| 129 | + >>> data4 = [[1, "Rice", 1.0], [2, "Saka", 2.0], [3, "White", 3.0001]] |
| 130 | + >>> df5 = session.create_dataframe(data4, schema1) |
| 131 | + >>> assert_dataframe_equal(df5, df1, atol=1e-3) # pass, DataFrames are identical due to higher error tolerance |
| 132 | + >>> assert_dataframe_equal(df5, df1, atol=1e-5) # doctest: +IGNORE_EXCEPTION_DETAIL |
| 133 | + Traceback (most recent call last): |
| 134 | + AssertionError: Value mismatch on row 2 at column 2: actual 3.0001, expected 3.0 |
| 135 | + Different row: |
| 136 | + --- actual --- |
| 137 | + +++ expected +++ |
| 138 | + - Row(ID=3, NAME='White', VALUE=3.0001) |
| 139 | + ? --- |
| 140 | +
|
| 141 | + + Row(ID=3, NAME='White', VALUE=3.0) |
| 142 | +
|
| 143 | + >>> schema2 = StructType([ |
| 144 | + ... StructField("id", IntegerType()), |
| 145 | + ... StructField("key", StringType()), |
| 146 | + ... StructField("value", DoubleType()) |
| 147 | + ... ]) |
| 148 | + >>> df6 = session.create_dataframe(data1, schema2) |
| 149 | + >>> assert_dataframe_equal(df6, df1) # doctest: +IGNORE_EXCEPTION_DETAIL |
| 150 | + Traceback (most recent call last): |
| 151 | + AssertionError: Column name mismatch at column 1: actual KEY, expected NAME |
| 152 | + Different schema: |
| 153 | + --- actual --- |
| 154 | + +++ expected +++ |
| 155 | + - StructType([StructField('ID', LongType(), nullable=True), StructField('KEY', StringType(), nullable=True), StructField('VALUE', DoubleType(), nullable=True)]) |
| 156 | + ? ^ - |
| 157 | +
|
| 158 | + + StructType([StructField('ID', LongType(), nullable=True), StructField('NAME', StringType(), nullable=True), StructField('VALUE', DoubleType(), nullable=True)]) |
| 159 | + ? |
| 160 | +
|
| 161 | + >>> schema3 = StructType([ |
| 162 | + ... StructField("id", IntegerType()), |
| 163 | + ... StructField("name", StringType()), |
| 164 | + ... StructField("value", IntegerType()) |
| 165 | + ... ]) |
| 166 | + >>> df7 = session.create_dataframe(data1, schema3) |
| 167 | + >>> assert_dataframe_equal(df7, df1) # doctest: +IGNORE_EXCEPTION_DETAIL |
| 168 | + Traceback (most recent call last): |
| 169 | + AssertionError: Column data type mismatch at column 2: actual LongType(), expected DoubleType() |
| 170 | + Different schema: |
| 171 | + --- actual --- |
| 172 | + +++ expected +++ |
| 173 | + - StructType([StructField('ID', LongType(), nullable=True), StructField('NAME', StringType(), nullable=True), StructField('VALUE', LongType(), nullable=True)]) |
| 174 | + ? ^ ^^ |
| 175 | +
|
| 176 | + + StructType([StructField('ID', LongType(), nullable=True), StructField('NAME', StringType(), nullable=True), StructField('VALUE', DoubleType(), nullable=True)]) |
| 177 | + ? |
| 178 | +
|
| 179 | + Note: |
| 180 | + 1. Data in a Snowpark DataFrame is unordered, so when comparing two DataFrames, this function |
| 181 | + sorts rows based on their values first. |
| 182 | +
|
| 183 | + 2. When comparing schemas, :class:`types.IntegerType` and :class:`types.DoubleType` are considered different, |
| 184 | + even if the underlying values are equal (e.g., 2 vs 2.0). |
| 185 | +
|
| 186 | + """ |
| 187 | + if not isinstance(actual, DataFrame): |
| 188 | + raise TypeError("actual must be a Snowpark DataFrame") |
| 189 | + if not isinstance(expected, DataFrame): |
| 190 | + raise TypeError("expected must be a Snowpark DataFrame") |
| 191 | + |
| 192 | + actual_schema = actual.schema |
| 193 | + expected_schema = expected.schema |
| 194 | + _assert_schema_equal(actual_schema, expected_schema) |
| 195 | + |
| 196 | + actual_rows = _get_sorted_rows(actual.collect()) |
| 197 | + expected_rows = _get_sorted_rows(expected.collect()) |
| 198 | + assert len(actual_rows) == len( |
| 199 | + expected_rows |
| 200 | + ), f"Different number of rows: actual has {len(actual_rows)} rows, expected has {len(expected_rows)} rows" |
| 201 | + |
| 202 | + for row_index, (actual_row, expected_row) in enumerate( |
| 203 | + zip(actual_rows, expected_rows) |
| 204 | + ): |
| 205 | + for column_index, (actual_value, expected_value) in enumerate( |
| 206 | + zip(actual_row, expected_row) |
| 207 | + ): |
| 208 | + error_message = f"Value mismatch on row {row_index} at column {column_index}: actual {actual_value}, expected {expected_value}" |
| 209 | + failed = False |
| 210 | + if isinstance(expected_value, float): |
| 211 | + if math.isnan(actual_value) != math.isnan(expected_value): |
| 212 | + failed = True |
| 213 | + if not math.isclose( |
| 214 | + actual_value, expected_value, rel_tol=rtol, abs_tol=atol |
| 215 | + ): |
| 216 | + failed = True |
| 217 | + else: |
| 218 | + failed = bool(actual_value != expected_value) |
| 219 | + if failed: |
| 220 | + actual_row_str = str(actual_row) |
| 221 | + expected_row_str = str(expected_row) |
| 222 | + if actual_row_str != expected_row_str: |
| 223 | + diff = difflib.ndiff( |
| 224 | + actual_row_str.splitlines(), expected_row_str.splitlines() |
| 225 | + ) |
| 226 | + diff_str = "\n".join(diff) |
| 227 | + raise AssertionError( |
| 228 | + f"{error_message}\nDifferent row:\n{ACTUAL_EXPECTED_STRING}\n{diff_str}" |
| 229 | + ) |
| 230 | + |
| 231 | + |
| 232 | +assertDataFrameEqual = assert_dataframe_equal |
0 commit comments