-
Notifications
You must be signed in to change notification settings - Fork 142
SNOW-1051741: df.apply(axis=1) should preserve the original index #3955
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
sfc-gh-helmeleegy
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM, just had one question.
src/snowflake/snowpark/modin/plugin/compiler/snowflake_query_compiler.py
Outdated
Show resolved
Hide resolved
…b/snowpark-python into jkew/apply.axis.1.row.index.0
| if num_index_columns > 0: | ||
| # Columns after row position are index columns, then data columns | ||
| index_cols = df.iloc[:, 1 : 1 + num_index_columns] | ||
| data_cols = df.iloc[:, 1 + num_index_columns :] | ||
|
|
||
| # Set the index using the index columns | ||
| if num_index_columns == 1: | ||
| index = index_cols.iloc[:, 0] | ||
| if index_column_pandas_labels: | ||
| index.name = index_column_pandas_labels[0] | ||
| else: | ||
| # Multi-index case | ||
| index = native_pd.MultiIndex.from_arrays( | ||
| [index_cols.iloc[:, i] for i in range(num_index_columns)], | ||
| names=index_column_pandas_labels | ||
| if index_column_pandas_labels | ||
| else None, | ||
| ) | ||
| data_cols.index = index | ||
| df = data_cols | ||
| else: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
can't we use set_index() in both cases?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I meant that you can replace most of the code here with set_index(). See #3979.
src/snowflake/snowpark/modin/plugin/compiler/snowflake_query_compiler.py
Outdated
Show resolved
Hide resolved
src/snowflake/snowpark/modin/plugin/compiler/snowflake_query_compiler.py
Outdated
Show resolved
Hide resolved
…b/snowpark-python into jkew/apply.axis.1.row.index.0
…b/snowpark-python into jkew/apply.axis.1.row.index.0
| input_types: Snowpark column types of the input data columns (including index columns). | ||
| index_column_pandas_labels: The pandas labels for the index columns, if any. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
| input_types: Snowpark column types of the input data columns (including index columns). | |
| index_column_pandas_labels: The pandas labels for the index columns, if any. | |
| input_types: Snowpark column types of the input data columns (including index columns). |
|
|
||
|
|
||
| @sql_count_checker(query_count=5, join_count=2, udtf_count=1) | ||
| def test_apply_axis_1_multiindex_preservation(): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Could we also test
funcwith return type annotations. We'll use vectorized UDFs instead of UDTFs.funcreturning a series- apply() on series (with func typed, untyped, or returning a series)
| if num_index_columns > 0: | ||
| # Columns after row position are index columns, then data columns | ||
| index_cols = df.iloc[:, 1 : 1 + num_index_columns] | ||
| data_cols = df.iloc[:, 1 + num_index_columns :] | ||
|
|
||
| # Set the index using the index columns | ||
| if num_index_columns == 1: | ||
| index = index_cols.iloc[:, 0] | ||
| if index_column_pandas_labels: | ||
| index.name = index_column_pandas_labels[0] | ||
| else: | ||
| # Multi-index case | ||
| index = native_pd.MultiIndex.from_arrays( | ||
| [index_cols.iloc[:, i] for i in range(num_index_columns)], | ||
| names=index_column_pandas_labels | ||
| if index_column_pandas_labels | ||
| else None, | ||
| ) | ||
| data_cols.index = index | ||
| df = data_cols | ||
| else: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I meant that you can replace most of the code here with set_index(). See #3979.
| # Determine if we should pass index columns to the UDTF | ||
| # We pass index columns when the index is not the row position itself |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We always pass the index column names here. We can keep doing that, but we should update the comment and make the parameter required, since there don't seem to be any other invocations of that function.
| column_index: native_pd.Index, | ||
| input_types: list[DataType], | ||
| session: Session, | ||
| index_column_labels: list[Hashable] | None = None, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It turns out that just passing the number of index columns is enough:
| # columns. We don't care about the index names because `func` |
df.apply(axis=1)should preserve the original index. Previously we would return a RangeIndex regardless of the original index. This approach passes the index data into the underlying UDTF.Mostly AI written approach, but with original tests for verification.
Fixes SNOW-1051741
Fill out the following pre-review checklist: