[SNOW-1764119] Eliminate unnecessary join for np.where with scalar for x (#2568)

<!---
Please answer these questions before creating your pull request. Thanks!
--->

1. Which Jira issue is this PR addressing? Make sure that there is an
accompanying issue to your PR.

   <!---
   In this section, please add a Snowflake Jira issue number.

   Note that if a corresponding GitHub issue exists, you should still include
   the Snowflake Jira issue number. For example, for GitHub issue #1400, you
   should add "SNOW-1335071" here.
   --->

   Fixes SNOW-1764119

2. Fill out the following pre-review checklist:

   - [x] I am adding a new automated test(s) to verify correctness of my new code
   - [ ] If this test skips Local Testing mode, I'm requesting review from @snowflakedb/local-testing
   - [ ] I am adding new logging messages
   - [ ] I am adding a new telemetry message
   - [ ] I am adding new credentials
   - [ ] I am adding a new dependency
   - [ ] If this is a new feature/behavior, I'm adding the Local Testing parity changes.
   - [x] I acknowledge that I have ensured my changes to be thread-safe. Follow the link for more information: [Thread-safe Developer Guidelines](https://docs.google.com/document/d/162d_i4zZ2AfcGRXojj0jByt8EUq-DrSHPPnTa4QvwbA/edit#bookmark=id.e82u4nekq80k)

3. Please describe how your code solves the related issue.

np.where with a scalar x today requires:

1) a temp table creation to broadcast the scalar x to the shape of cond, and
2) a join when performing the pandas where.

However, neither the extra temp table creation nor the join is needed in this case. This change removes both by broadcasting the scalar through indexing instead.
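
For illustration, a minimal sketch of the user-facing call this change optimizes (assumptions: an active Snowpark session is already configured, and the example frames and the -99 scalar are made up):

```python
import modin.pandas as pd
import numpy as np
import snowflake.snowpark.modin.plugin  # noqa: F401  (registers the Snowpark pandas backend)

# cond is a boolean Snowpark pandas DataFrame, x is a plain Python scalar,
# and y is another Snowpark pandas DataFrame of the same shape.
cond = pd.DataFrame([[True, False], [False, True]])
other = pd.DataFrame([[1, 2], [3, 4]])

# np.where dispatches to the Snowpark pandas mapper touched by this PR; with
# this change the scalar -99 is broadcast via indexing rather than through a
# temp table plus a join.
result = np.where(cond, -99, other)
```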
sfc-gh-yzou authored Nov 6, 2024
1 parent 32822dc commit de18ffa
Showing 3 changed files with 33 additions and 14 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -54,6 +54,9 @@
 - Raise a `TypeError` for a scalar `subset` instead of filtering on just that column.
 - Raise a `ValueError` for a `subset` of type `pandas.Index` instead of filtering on the columns in the index.
 
+#### Improvements
+- Improve np.where with scalar x value by eliminating unnecessary join and temp table creation.
+
 ### Snowpark Local Testing Updates
 
 #### New Features
31 changes: 23 additions & 8 deletions src/snowflake/snowpark/modin/plugin/utils/numpy_to_pandas.py
@@ -88,18 +88,33 @@ def is_same_shape(
         return x.where(cond, y)  # type: ignore
 
     if is_scalar(x):
-        # broadcast scalar x to size of cond
-        object_shape = cond.shape
-        if len(object_shape) == 1:
-            df_scalar = pd.Series(x, index=range(object_shape[0]))
-        elif len(object_shape) == 2:
-            df_scalar = pd.DataFrame(
-                x, index=range(object_shape[0]), columns=range(object_shape[1])
-            )
+        if cond.ndim == 1:
+            df_cond = cond.to_frame()
+        else:
+            df_cond = cond.copy()
+
+        origin_columns = df_cond.columns
+        # rename the columns of df_cond to ensure no conflict happens when
+        # appending new columns
+        renamed_columns = [f"col_{i}" for i in range(len(origin_columns))]
+        df_cond.columns = renamed_columns
+        # broadcast scalar x to the size of cond through indexing
+        new_columns = [f"new_col_{i}" for i in range(len(origin_columns))]
+        df_cond[new_columns] = x
+
+        if cond.ndim == 1:
+            df_scalar = df_cond[new_columns[0]]
+            df_scalar.name = cond.name
+        else:
+            df_scalar = df_cond[new_columns]
+            # use the same column names as the cond dataframe to make sure
+            # pandas where aligns correctly
+            df_scalar.columns = origin_columns
+
         # handles np.where(df, scalar1, scalar2)
         # handles np.where(df1, scalar, df2)
         return df_scalar.where(cond, y)
 
     # return the sentinel NotImplemented if we do not support this function
     return NotImplemented
 
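The core idea of the new code path above is to broadcast the scalar by assigning it to fresh columns of a renamed copy of cond and then slicing those columns back out under cond's original labels, so DataFrame.where no longer needs a join against a separately materialized scalar frame. A plain-pandas sketch of the same idea (illustrative only, not the Snowpark implementation; the data, column labels, and the -99/0 scalars are made up):

```python
import pandas as pd

cond = pd.DataFrame([[True, False], [False, True]], columns=["A", "B"])

# Rename cond's columns so the appended helper columns cannot collide with them.
work = cond.copy()
work.columns = [f"col_{i}" for i in range(len(cond.columns))]

# Assigning the scalar creates new columns with cond's row count; this is the
# broadcast step, and it needs no join against a separately created frame.
new_columns = [f"new_col_{i}" for i in range(len(cond.columns))]
for col in new_columns:
    work[col] = -99

# Slice the broadcast columns back out and restore the original labels so that
# DataFrame.where aligns them with cond.
scalar_frame = work[new_columns].copy()
scalar_frame.columns = cond.columns

result = scalar_frame.where(cond, other=0)  # 0 plays the role of y here
print(result)
```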
13 changes: 7 additions & 6 deletions tests/integ/modin/test_numpy.py
@@ -256,7 +256,7 @@ def test_np_where_notimplemented():
     )
 
 
-@sql_count_checker(query_count=5, join_count=4)
+@sql_count_checker(query_count=4, join_count=1)
 def test_scalar():
     pdf_scalar = native_pd.DataFrame([[99, 99], [99, 99]])
     sdf_scalar = pd.DataFrame([[99, 99], [99, 99]])
@@ -315,18 +315,19 @@ def test_different_inputs(cond, x, y):
     assert_array_equal(sp_result, np_orig_result)
 
 
-@sql_count_checker(query_count=2, join_count=2)
-def test_broadcast_scalar_x_df():
-    input_df = native_pd.DataFrame([[False, True], [False, True]])
-    input_df2 = native_pd.DataFrame([[1, 0], [0, 1]])
+@sql_count_checker(query_count=1, join_count=1)
+@pytest.mark.parametrize("column_names", [None, ["A", "B"]])
+def test_broadcast_scalar_x_df(column_names):
+    input_df = native_pd.DataFrame([[False, True], [False, True]], columns=column_names)
+    input_df2 = native_pd.DataFrame([[1, 0], [0, 1]], columns=column_names)
     snow_df = pd.DataFrame(input_df)
     snow_df2 = pd.DataFrame(input_df2)
     snow_result = np.where(snow_df, -99, snow_df2)
     np_result = np.where(input_df, -99, input_df2)
     assert_array_equal(snow_result, np_result)
 
 
-@sql_count_checker(query_count=2, join_count=2)
+@sql_count_checker(query_count=1, join_count=1)
 def test_broadcast_scalar_x_ser():
     input_ser = native_pd.Series([False, True])
     input_ser2 = native_pd.Series([1, 0])