Skip to content

Commit

Permalink
SNOW-1661142 Fix index name behavior (#2274)
Browse files Browse the repository at this point in the history
1. Which Jira issue is this PR addressing? Make sure that there is an
accompanying issue to your PR.
   Fixes SNOW-1661142

2. Fill out the following pre-review checklist:

- [x] I am adding a new automated test(s) to verify correctness of my
new code
- [ ] If this test skips Local Testing mode, I'm requesting review from
@snowflakedb/local-testing
   - [ ] I am adding new logging messages
   - [ ] I am adding a new telemetry message
   - [ ] I am adding new credentials
   - [ ] I am adding a new dependency
- [ ] If this is a new feature/behavior, I'm adding the Local Testing
parity changes.

3. Please describe how your code solves the related issue.
Fixed a bug where updating an index's name updates the parent's index
name when it is not supposed to. This is done by verifying that the
query_compiler recorded during the index's creation matches that of its
parent object when the parent object must be updated.
```py
>>> df = pd.DataFrame(
...     {
...         "A": [0, 1, 2, 3, 4, 4],
...         "B": ['a', 'b', 'c', 'd', 'e', 'f'],
...     },
...     index = pd.Index([1, 2, 3, 4, 5, 6], name = "test"),
... )
>>> index = df.index
>>> df
      A  B
test      
1     0  a
2     1  b
3     2  c
4     3  d
5     4  e
6     4  f
>>> index.name = "test2"
>>> 
>>> df
       A  B
test2      
1      0  a
2      1  b
3      2  c
4      3  d
5      4  e
6      4  f
>>> df.dropna(inplace=True)
>>> index.name = "test3"
>>> df
       A  B
test2        # <--- name should not update      
1      0  a
2      1  b
3      2  c
4      3  d
5      4  e
6      4  f
```
For the full discussion, see thread:
https://docs.google.com/document/d/1vdllzNgeUHMiffFNpm9SD1HOYUk8lkMVp14HQDoqr7s/edit?disco=AAABVbKjFJ0
  • Loading branch information
sfc-gh-vbudati authored Sep 17, 2024
1 parent 0ee3033 commit e93cd68
Show file tree
Hide file tree
Showing 5 changed files with 122 additions and 17 deletions.
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,10 @@
- Added support for some cases of aggregating `Timedelta` columns on `axis=0` with `agg` or `aggregate`.
- Added support for `by`, `left_by`, and `right_by` for `pd.merge_asof`.

#### Bug Fixes

- Fixed a bug where an `Index` object created from a `Series`/`DataFrame` incorrectly updates the `Series`/`DataFrame`'s index name after an inplace update has been applied to the original `Series`/`DataFrame`.


## 1.22.1 (2024-09-11)
This is a re-release of 1.22.0. Please refer to the 1.22.0 release notes for detailed release content.
Expand Down
61 changes: 48 additions & 13 deletions src/snowflake/snowpark/modin/plugin/extensions/index.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,35 @@
}


class IndexParent:
def __init__(self, parent: DataFrame | Series) -> None:
"""
Initialize the IndexParent object.
IndexParent is used to keep track of the parent object that the Index is a part of.
It tracks the parent object and the parent object's query compiler at the time of creation.
Parameters
----------
parent : DataFrame or Series
The parent object that the Index is a part of.
"""
assert isinstance(parent, (DataFrame, Series))
self._parent = parent
self._parent_qc = parent._query_compiler

def check_and_update_parent_qc_index_names(self, names: list) -> None:
"""
Update the Index and its parent's index names if the query compiler associated with the parent is
different from the original query compiler recorded, i.e., an inplace update has been applied to the parent.
"""
if self._parent._query_compiler is self._parent_qc:
new_query_compiler = self._parent_qc.set_index_names(names)
self._parent._update_inplace(new_query_compiler=new_query_compiler)
# Update the query compiler after naming operation.
self._parent_qc = new_query_compiler


class Index(metaclass=TelemetryMeta):

# Equivalent index type in native pandas
Expand Down Expand Up @@ -135,7 +164,7 @@ def __new__(
index = object.__new__(cls)
# Initialize the Index
index._query_compiler = query_compiler
# `_parent` keeps track of any Series or DataFrame that this Index is a part of.
# `_parent` keeps track of the parent object that this Index is a part of.
index._parent = None
return index

Expand Down Expand Up @@ -252,6 +281,17 @@ def __getattr__(self, key: str) -> Any:
ErrorMessage.not_implemented(f"Index.{key} is not yet implemented")
raise err

def _set_parent(self, parent: Series | DataFrame) -> None:
"""
Set the parent object and its query compiler.
Parameters
----------
parent : Series or DataFrame
The parent object that the Index is a part of.
"""
self._parent = IndexParent(parent)

def _binary_ops(self, method: str, other: Any) -> Index:
if isinstance(other, Index):
other = other.to_series().reset_index(drop=True)
Expand Down Expand Up @@ -408,12 +448,6 @@ def __constructor__(self):
"""
return type(self)

def _set_parent(self, parent: Series | DataFrame):
"""
Set the parent object of the current Index to a given Series or DataFrame.
"""
self._parent = parent

@property
def values(self) -> ArrayLike:
"""
Expand Down Expand Up @@ -726,10 +760,11 @@ def name(self, value: Hashable) -> None:
if not is_hashable(value):
raise TypeError(f"{type(self).__name__}.name must be a hashable type")
self._query_compiler = self._query_compiler.set_index_names([value])
# Update the name of the parent's index only if an inplace update is performed on
# the parent object, i.e., the parent's current query compiler matches the originally
# recorded query compiler.
if self._parent is not None:
self._parent._update_inplace(
new_query_compiler=self._parent._query_compiler.set_index_names([value])
)
self._parent.check_and_update_parent_qc_index_names([value])

def _get_names(self) -> list[Hashable]:
"""
Expand All @@ -755,10 +790,10 @@ def _set_names(self, values: list) -> None:
if isinstance(values, Index):
values = values.to_list()
self._query_compiler = self._query_compiler.set_index_names(values)
# Update the name of the parent's index only if the parent's current query compiler
# matches the recorded query compiler.
if self._parent is not None:
self._parent._update_inplace(
new_query_compiler=self._parent._query_compiler.set_index_names(values)
)
self._parent.check_and_update_parent_qc_index_names(values)

names = property(fset=_set_names, fget=_get_names)

Expand Down
4 changes: 2 additions & 2 deletions tests/integ/modin/index/test_datetime_index_methods.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,13 +142,13 @@ def test_index_parent():
# DataFrame case.
df = pd.DataFrame({"A": [1]}, index=native_idx1)
snow_idx = df.index
assert_frame_equal(snow_idx._parent, df)
assert_frame_equal(snow_idx._parent._parent, df)
assert_index_equal(snow_idx, native_idx1)

# Series case.
s = pd.Series([1, 2], index=native_idx2, name="zyx")
snow_idx = s.index
assert_series_equal(snow_idx._parent, s)
assert_series_equal(snow_idx._parent._parent, s)
assert_index_equal(snow_idx, native_idx2)


Expand Down
4 changes: 2 additions & 2 deletions tests/integ/modin/index/test_index_methods.py
Original file line number Diff line number Diff line change
Expand Up @@ -393,13 +393,13 @@ def test_index_parent():
# DataFrame case.
df = pd.DataFrame([[1, 2], [3, 4]], index=native_idx1)
snow_idx = df.index
assert_frame_equal(snow_idx._parent, df)
assert_frame_equal(snow_idx._parent._parent, df)
assert_index_equal(snow_idx, native_idx1)

# Series case.
s = pd.Series([1, 2, 4, 5, 6, 7], index=native_idx2, name="zyx")
snow_idx = s.index
assert_series_equal(snow_idx._parent, s)
assert_series_equal(snow_idx._parent._parent, s)
assert_index_equal(snow_idx, native_idx2)


Expand Down
66 changes: 66 additions & 0 deletions tests/integ/modin/index/test_name.py
Original file line number Diff line number Diff line change
Expand Up @@ -351,3 +351,69 @@ def test_index_names_with_lazy_index():
),
inplace=True,
)


@sql_count_checker(query_count=1)
def test_index_names_replace_behavior():
"""
Check that the index name of a DataFrame cannot be updated after the DataFrame has been modified.
"""
data = {
"A": [0, 1, 2, 3, 4, 4],
"B": ["a", "b", "c", "d", "e", "f"],
}
idx = [1, 2, 3, 4, 5, 6]
native_df = native_pd.DataFrame(data, native_pd.Index(idx, name="test"))
snow_df = pd.DataFrame(data, index=pd.Index(idx, name="test"))

# Get a reference to the index of the DataFrames.
snow_index = snow_df.index
native_index = native_df.index

# Change the names.
snow_index.name = "test2"
native_index.name = "test2"

# Compare the names.
assert snow_index.name == native_index.name == "test2"
assert snow_df.index.name == native_df.index.name == "test2"

# Change the query compiler the DataFrame is referring to, change the names.
snow_df.dropna(inplace=True)
native_df.dropna(inplace=True)
snow_index.name = "test3"
native_index.name = "test3"

# Compare the names. Changing the index name should not change the DataFrame's index name.
assert snow_index.name == native_index.name == "test3"
assert snow_df.index.name == native_df.index.name == "test2"


@sql_count_checker(query_count=1)
def test_index_names_multiple_renames():
"""
Check that the index name of a DataFrame can be renamed any number of times.
"""
data = {
"A": [0, 1, 2, 3, 4, 4],
"B": ["a", "b", "c", "d", "e", "f"],
}
idx = [1, 2, 3, 4, 5, 6]
native_df = native_pd.DataFrame(data, native_pd.Index(idx, name="test"))
snow_df = pd.DataFrame(data, index=pd.Index(idx, name="test"))

# Get a reference to the index of the DataFrames.
snow_index = snow_df.index
native_index = native_df.index

# Change and compare the names.
snow_index.name = "test2"
native_index.name = "test2"
assert snow_index.name == native_index.name == "test2"
assert snow_df.index.name == native_df.index.name == "test2"

# Change the names again and compare.
snow_index.name = "test3"
native_index.name = "test3"
assert snow_index.name == native_index.name == "test3"
assert snow_df.index.name == native_df.index.name == "test3"

0 comments on commit e93cd68

Please sign in to comment.