From e93cd6821d8accf69cbe95c5bb171838de844112 Mon Sep 17 00:00:00 2001 From: Varnika Budati Date: Mon, 16 Sep 2024 18:09:18 -0700 Subject: [PATCH] SNOW-1661142 Fix index name behavior (#2274) 1. Which Jira issue is this PR addressing? Make sure that there is an accompanying issue to your PR. Fixes SNOW-1661142 2. Fill out the following pre-review checklist: - [x] I am adding a new automated test(s) to verify correctness of my new code - [ ] If this test skips Local Testing mode, I'm requesting review from @snowflakedb/local-testing - [ ] I am adding new logging messages - [ ] I am adding a new telemetry message - [ ] I am adding new credentials - [ ] I am adding a new dependency - [ ] If this is a new feature/behavior, I'm adding the Local Testing parity changes. 3. Please describe how your code solves the related issue. Fixed a bug where updating an index's name updates the parent's index name when it is not supposed to. This is done by verifying that the query_compiler recorded during the index's creation matches that of its parent object when the parent object must be updated. ```py >>> df = pd.DataFrame( ... { ... "A": [0, 1, 2, 3, 4, 4], ... "B": ['a', 'b', 'c', 'd', 'e', 'f'], ... }, ... index = pd.Index([1, 2, 3, 4, 5, 6], name = "test"), ... 
) >>> index = df.index >>> df A B test 1 0 a 2 1 b 3 2 c 4 3 d 5 4 e 6 4 f >>> index.name = "test2" >>> >>> df A B test2 1 0 a 2 1 b 3 2 c 4 3 d 5 4 e 6 4 f >>> df.dropna(inplace=True) >>> index.name = "test3" >>> df A B test2 # <--- name should not update 1 0 a 2 1 b 3 2 c 4 3 d 5 4 e 6 4 f ``` For the full discussion, see thread: https://docs.google.com/document/d/1vdllzNgeUHMiffFNpm9SD1HOYUk8lkMVp14HQDoqr7s/edit?disco=AAABVbKjFJ0 --- CHANGELOG.md | 4 ++ .../snowpark/modin/plugin/extensions/index.py | 61 +++++++++++++---- .../index/test_datetime_index_methods.py | 4 +- tests/integ/modin/index/test_index_methods.py | 4 +- tests/integ/modin/index/test_name.py | 66 +++++++++++++++++++ 5 files changed, 122 insertions(+), 17 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 0bd719dcb8c..7048a3b728a 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -14,6 +14,10 @@ - Added support for some cases of aggregating `Timedelta` columns on `axis=0` with `agg` or `aggregate`. - Added support for `by`, `left_by`, and `right_by` for `pd.merge_asof`. +#### Bug Fixes + +- Fixed a bug where an `Index` object created from a `Series`/`DataFrame` incorrectly updates the `Series`/`DataFrame`'s index name after an inplace update has been applied to the original `Series`/`DataFrame`. + ## 1.22.1 (2024-09-11) This is a re-release of 1.22.0. Please refer to the 1.22.0 release notes for detailed release content. diff --git a/src/snowflake/snowpark/modin/plugin/extensions/index.py b/src/snowflake/snowpark/modin/plugin/extensions/index.py index 643f6f5038e..b25bb481dc0 100644 --- a/src/snowflake/snowpark/modin/plugin/extensions/index.py +++ b/src/snowflake/snowpark/modin/plugin/extensions/index.py @@ -71,6 +71,35 @@ } +class IndexParent: + def __init__(self, parent: DataFrame | Series) -> None: + """ + Initialize the IndexParent object. + + IndexParent is used to keep track of the parent object that the Index is a part of. 
+ It tracks the parent object and the parent object's query compiler at the time of creation. + + Parameters + ---------- + parent : DataFrame or Series + The parent object that the Index is a part of. + """ + assert isinstance(parent, (DataFrame, Series)) + self._parent = parent + self._parent_qc = parent._query_compiler + + def check_and_update_parent_qc_index_names(self, names: list) -> None: + """ + Update the parent's index names only if the parent's current query compiler is still the one recorded + here, i.e., no inplace update has replaced the parent's query compiler since it was recorded. + """ + if self._parent._query_compiler is self._parent_qc: + new_query_compiler = self._parent_qc.set_index_names(names) + self._parent._update_inplace(new_query_compiler=new_query_compiler) + # Update the query compiler after naming operation. + self._parent_qc = new_query_compiler + + class Index(metaclass=TelemetryMeta): # Equivalent index type in native pandas @@ -135,7 +164,7 @@ def __new__( index = object.__new__(cls) # Initialize the Index index._query_compiler = query_compiler - # `_parent` keeps track of any Series or DataFrame that this Index is a part of. + # `_parent` keeps track of the parent object that this Index is a part of. index._parent = None return index @@ -252,6 +281,17 @@ def __getattr__(self, key: str) -> Any: ErrorMessage.not_implemented(f"Index.{key} is not yet implemented") raise err + def _set_parent(self, parent: Series | DataFrame) -> None: + """ + Set the parent object and its query compiler. + + Parameters + ---------- + parent : Series or DataFrame + The parent object that the Index is a part of. 
+ """ + self._parent = IndexParent(parent) + def _binary_ops(self, method: str, other: Any) -> Index: if isinstance(other, Index): other = other.to_series().reset_index(drop=True) @@ -408,12 +448,6 @@ def __constructor__(self): """ return type(self) - def _set_parent(self, parent: Series | DataFrame): - """ - Set the parent object of the current Index to a given Series or DataFrame. - """ - self._parent = parent - @property def values(self) -> ArrayLike: """ @@ -726,10 +760,11 @@ def name(self, value: Hashable) -> None: if not is_hashable(value): raise TypeError(f"{type(self).__name__}.name must be a hashable type") self._query_compiler = self._query_compiler.set_index_names([value]) + # Update the name of the parent's index only if no inplace update has been performed on + # the parent object, i.e., the parent's current query compiler matches the originally + # recorded query compiler. if self._parent is not None: - self._parent._update_inplace( - new_query_compiler=self._parent._query_compiler.set_index_names([value]) - ) + self._parent.check_and_update_parent_qc_index_names([value]) def _get_names(self) -> list[Hashable]: """ @@ -755,10 +790,10 @@ def _set_names(self, values: list) -> None: if isinstance(values, Index): values = values.to_list() self._query_compiler = self._query_compiler.set_index_names(values) + # Update the name of the parent's index only if the parent's current query compiler + # matches the recorded query compiler. 
if self._parent is not None: - self._parent._update_inplace( - new_query_compiler=self._parent._query_compiler.set_index_names(values) - ) + self._parent.check_and_update_parent_qc_index_names(values) names = property(fset=_set_names, fget=_get_names) diff --git a/tests/integ/modin/index/test_datetime_index_methods.py b/tests/integ/modin/index/test_datetime_index_methods.py index 793485f97d6..98d1a041c3b 100644 --- a/tests/integ/modin/index/test_datetime_index_methods.py +++ b/tests/integ/modin/index/test_datetime_index_methods.py @@ -142,13 +142,13 @@ def test_index_parent(): # DataFrame case. df = pd.DataFrame({"A": [1]}, index=native_idx1) snow_idx = df.index - assert_frame_equal(snow_idx._parent, df) + assert_frame_equal(snow_idx._parent._parent, df) assert_index_equal(snow_idx, native_idx1) # Series case. s = pd.Series([1, 2], index=native_idx2, name="zyx") snow_idx = s.index - assert_series_equal(snow_idx._parent, s) + assert_series_equal(snow_idx._parent._parent, s) assert_index_equal(snow_idx, native_idx2) diff --git a/tests/integ/modin/index/test_index_methods.py b/tests/integ/modin/index/test_index_methods.py index 8d0434915ac..6b33eb89889 100644 --- a/tests/integ/modin/index/test_index_methods.py +++ b/tests/integ/modin/index/test_index_methods.py @@ -393,13 +393,13 @@ def test_index_parent(): # DataFrame case. df = pd.DataFrame([[1, 2], [3, 4]], index=native_idx1) snow_idx = df.index - assert_frame_equal(snow_idx._parent, df) + assert_frame_equal(snow_idx._parent._parent, df) assert_index_equal(snow_idx, native_idx1) # Series case. 
s = pd.Series([1, 2, 4, 5, 6, 7], index=native_idx2, name="zyx") snow_idx = s.index - assert_series_equal(snow_idx._parent, s) + assert_series_equal(snow_idx._parent._parent, s) assert_index_equal(snow_idx, native_idx2) diff --git a/tests/integ/modin/index/test_name.py b/tests/integ/modin/index/test_name.py index b916110f386..f915598c5f6 100644 --- a/tests/integ/modin/index/test_name.py +++ b/tests/integ/modin/index/test_name.py @@ -351,3 +351,69 @@ def test_index_names_with_lazy_index(): ), inplace=True, ) + + +@sql_count_checker(query_count=1) +def test_index_names_replace_behavior(): + """ + Check that renaming a previously obtained Index no longer renames the DataFrame's index once the DataFrame has been modified inplace. + """ + data = { + "A": [0, 1, 2, 3, 4, 4], + "B": ["a", "b", "c", "d", "e", "f"], + } + idx = [1, 2, 3, 4, 5, 6] + native_df = native_pd.DataFrame(data, native_pd.Index(idx, name="test")) + snow_df = pd.DataFrame(data, index=pd.Index(idx, name="test")) + + # Get a reference to the index of the DataFrames. + snow_index = snow_df.index + native_index = native_df.index + + # Change the names. + snow_index.name = "test2" + native_index.name = "test2" + + # Compare the names. + assert snow_index.name == native_index.name == "test2" + assert snow_df.index.name == native_df.index.name == "test2" + + # Replace the query compiler the DataFrame is referring to via an inplace update, then change the names. + snow_df.dropna(inplace=True) + native_df.dropna(inplace=True) + snow_index.name = "test3" + native_index.name = "test3" + + # Compare the names. Changing the index name should not change the DataFrame's index name. + assert snow_index.name == native_index.name == "test3" + assert snow_df.index.name == native_df.index.name == "test2" + + +@sql_count_checker(query_count=1) +def test_index_names_multiple_renames(): + """ + Check that the index name of a DataFrame can be renamed any number of times. 
+ """ + data = { + "A": [0, 1, 2, 3, 4, 4], + "B": ["a", "b", "c", "d", "e", "f"], + } + idx = [1, 2, 3, 4, 5, 6] + native_df = native_pd.DataFrame(data, native_pd.Index(idx, name="test")) + snow_df = pd.DataFrame(data, index=pd.Index(idx, name="test")) + + # Get a reference to the index of the DataFrames. + snow_index = snow_df.index + native_index = native_df.index + + # Change and compare the names. + snow_index.name = "test2" + native_index.name = "test2" + assert snow_index.name == native_index.name == "test2" + assert snow_df.index.name == native_df.index.name == "test2" + + # Change the names again and compare. + snow_index.name = "test3" + native_index.name = "test3" + assert snow_index.name == native_index.name == "test3" + assert snow_df.index.name == native_df.index.name == "test3"