SNOW-1732432 Improved SQL query for iloc with scalar key (#2455)
…s a scalar or a list which is convertible to a slice


1. Which Jira issue is this PR addressing? Make sure that there is an accompanying issue for your PR.


   Fixes SNOW-1732432

2. Fill out the following pre-review checklist:

   - [x] I am adding a new automated test(s) to verify correctness of my new code
   - [ ] If this test skips Local Testing mode, I'm requesting review from @snowflakedb/local-testing
   - [ ] I am adding new logging messages
   - [ ] I am adding a new telemetry message
   - [ ] I am adding new credentials
   - [ ] I am adding a new dependency
   - [ ] If this is a new feature/behavior, I'm adding the Local Testing parity changes.
   - [x] I acknowledge that I have ensured my changes to be thread-safe. Follow the link for more information: [Thread-safe Developer Guidelines](https://docs.google.com/document/d/162d_i4zZ2AfcGRXojj0jByt8EUq-DrSHPPnTa4QvwbA/edit#bookmark=id.e82u4nekq80k)

3. Please describe how your code solves the related issue.


Convert the scalar key to a slice to reuse the efficient query already generated for slice-key `iloc`. This avoids a whole-table scan for calls like `df.iloc[0]`.
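
The rewrite is easy to sanity-check against plain pandas semantics: a positional scalar key selects the same row as the corresponding one-element slice, so the frontend can swap the key before any SQL is compiled. A minimal sketch of that equivalence (plain pandas, for illustration only):

```python
import pandas as pd

df = pd.DataFrame({"a": range(5), "b": range(5, 10)})

key = 0
# The rewrite applied by this change: scalar -> one-element slice.
as_slice = slice(key, None if key == -1 else key + 1, 1)

row_via_scalar = df.iloc[key]      # row 0 as a Series
row_via_slice = df.iloc[as_slice]  # the same row as a one-row DataFrame

# Identical row either way; the slice form lets the compiler emit a
# LIMIT-style filter instead of joining the frame against a key table.
assert row_via_slice.squeeze().equals(row_via_scalar)
```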
sfc-gh-azhan authored Oct 18, 2024
1 parent b31b1d8 commit d203e69
Showing 18 changed files with 133 additions and 105 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
@@ -34,6 +34,8 @@
- Improved generated SQL query for `head` and `iloc` when the row key is a slice.
- Improved error message when passing an unknown timezone to `tz_convert` and `tz_localize` in `Series`, `DataFrame`, `Series.dt`, and `DatetimeIndex`.
- Improved documentation for `tz_convert` and `tz_localize` in `Series`, `DataFrame`, `Series.dt`, and `DatetimeIndex` to specify the supported timezone formats.
- Improved generated SQL query for `iloc` and `iat` when the row key is a scalar.
- Removed all joins in `iterrows`.

#### Bug Fixes

@@ -1192,9 +1192,10 @@ def __getitem__(
if not isinstance(col_loc, pd.Series) and is_range_like(col_loc):
col_loc = self._convert_range_to_valid_slice(col_loc)

# Convert all scalar, list-like, and indexer row_loc to a Series object to get a query compiler object.
# Convert scalar to slice to generate efficient SQL query
if is_scalar(row_loc):
row_loc = pd.Series([row_loc])
row_loc = slice(row_loc, None if row_loc == -1 else row_loc + 1, 1)
# Convert list-like, and indexer row_loc to a Series object to get a query compiler object.
elif isinstance(row_loc, pd.Index):
# Convert index row_loc to series
row_loc = row_loc.to_series().reset_index(drop=True)
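
The `None if row_loc == -1 else row_loc + 1` stop in the new branch handles the one edge case in this rewrite: for `row_loc == -1` the naive stop would be `0`, and `slice(-1, 0, 1)` selects nothing. Leaving the stop open preserves `iloc[-1]` semantics, as a quick check in plain Python shows:

```python
rows = list(range(5))

# Naive stop for key -1 would be -1 + 1 == 0, which selects nothing:
assert rows[slice(-1, 0, 1)] == []

# An open stop keeps exactly the last element, matching iloc[-1]:
assert rows[slice(-1, None, 1)] == [4]

# Every other key k, positive or negative, works with stop = k + 1:
assert rows[slice(2, 3, 1)] == [2]
assert rows[slice(-3, -2, 1)] == [2]
```
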
2 changes: 1 addition & 1 deletion tests/integ/modin/binary/test_binary_op.py
@@ -1945,7 +1945,7 @@ def test_binary_comparison_method_between_series_different_types(op):
@pytest.mark.parametrize(
"op", [operator.eq, operator.ne, operator.gt, operator.ge, operator.lt, operator.le]
)
@sql_count_checker(query_count=2, join_count=5)
@sql_count_checker(query_count=2, join_count=2)
def test_binary_comparison_method_between_series_variant(lhs, rhs, op):
snow_ans = op(pd.Series(lhs), pd.Series(rhs))
native_ans = op(native_pd.Series(lhs), native_pd.Series(rhs))
2 changes: 1 addition & 1 deletion tests/integ/modin/frame/test_dropna.py
@@ -142,7 +142,7 @@ def test_dropna_negative(test_dropna_df):
),
],
)
@sql_count_checker(query_count=1, join_count=4, union_count=1)
@sql_count_checker(query_count=1, union_count=1)
def test_dropna_iloc(df):
# Check that dropna() generates a new index correctly for iloc.
# 1 join for iloc, 2 joins generated by to_pandas methods during eval.
20 changes: 0 additions & 20 deletions tests/integ/modin/frame/test_head_tail.py
@@ -71,23 +71,3 @@ def test_empty_dataframe(n, empty_snowpark_pandas_df):
comparator=eval_result_and_query_with_no_join,
check_column_type=False,
)


@pytest.mark.parametrize(
"ops",
[
lambda df: df.head(),
lambda df: df.iloc[1:100],
lambda df: df.iloc[1000:100:-1],
],
)
@sql_count_checker(query_count=6)
def test_head_efficient_sql(session, ops):
df = DataFrame({"a": [1] * 10000})
with session.query_history() as query_listener:
ops(df).to_pandas()
eval_query = query_listener.queries[-2].sql_text.lower()
# check no row count
assert "count" not in eval_query
# check orderBy behinds limit
assert eval_query.index("limit") < eval_query.index("order by")
12 changes: 6 additions & 6 deletions tests/integ/modin/frame/test_iat.py
@@ -18,7 +18,7 @@
(-7, -7),
],
)
@sql_count_checker(query_count=1, join_count=2)
@sql_count_checker(query_count=1)
def test_iat_get_default_index_str_columns(
key,
default_index_snowpark_pandas_df,
@@ -62,7 +62,7 @@ def iat_set_helper(df):
(-7, -7),
],
)
@sql_count_checker(query_count=1, join_count=2)
@sql_count_checker(query_count=1)
def test_iat_get_str_index_str_columns(
key,
str_index_snowpark_pandas_df,
@@ -103,7 +103,7 @@ def iat_set_helper(df):
(-7, -7),
],
)
@sql_count_checker(query_count=1, join_count=2)
@sql_count_checker(query_count=1)
def test_iat_get_time_index_time_columns(
key,
time_index_snowpark_pandas_df,
@@ -147,7 +147,7 @@ def iat_set_helper(df):
(-7, -7),
],
)
@sql_count_checker(query_count=1, join_count=2)
@sql_count_checker(query_count=1)
def test_iat_get_multiindex_index_str_columns(
key,
default_index_native_df,
@@ -190,7 +190,7 @@ def at_set_helper(df):
(-7, -7),
],
)
@sql_count_checker(query_count=1, join_count=2)
@sql_count_checker(query_count=1)
def test_iat_get_default_index_multiindex_columns(
key,
native_df_with_multiindex_columns,
@@ -231,7 +231,7 @@ def at_set_helper(df):
(-7, -7),
],
)
@sql_count_checker(query_count=1, join_count=2)
@sql_count_checker(query_count=1)
def test_iat_get_multiindex_index_multiindex_columns(
key,
native_df_with_multiindex_columns,
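
The dropped `join_count` in the `iat` tests above follows from the same change: a scalar `iat` lookup routes through the scalar-`iloc` path and inherits the slice rewrite. A usage sketch (standard Snowpark pandas imports; running it assumes an active Snowflake session):

```python
import modin.pandas as pd
import snowflake.snowpark.modin.plugin  # noqa: F401 -- registers the Snowpark pandas backend

df = pd.DataFrame({"a": [10, 20, 30], "b": [1.0, 2.0, 3.0]})

# Both positional lookups now compile to LIMIT-style queries
# (join_count=0 in the counters above) rather than joining the
# frame against a one-row key table.
row = df.iloc[0]     # first row as a Series
cell = df.iat[2, 0]  # the scalar 30
```
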
88 changes: 66 additions & 22 deletions tests/integ/modin/frame/test_iloc.py
@@ -102,9 +102,9 @@

test_negative_bound_list_input = [([-AXIS_LEN - 0.9], 1, 2)]
test_int_inputs = [
(0, 1, 4),
(AXIS_LEN - 1, 1, 4),
(-AXIS_LEN, 1, 4),
(0, 1, 0),
(AXIS_LEN - 1, 1, 0),
(-AXIS_LEN, 1, 0),
]
test_inputs_on_df_for_dataframe_output = (
test_int_inputs + test_inputs_for_no_scalar_output
@@ -328,7 +328,7 @@ def eval_func(df):
(..., 1, 2), # leading ellipsis should be stripped
],
)
@sql_count_checker(query_count=1, join_count=2)
@sql_count_checker(query_count=1)
def test_df_iloc_get_scalar(
key, default_index_snowpark_pandas_df, default_index_native_df
):
@@ -423,7 +423,7 @@ def test_df_iloc_get_empty_key(
)


@sql_count_checker(query_count=2, join_count=2)
@sql_count_checker(query_count=2)
def test_df_iloc_get_empty(empty_snowpark_pandas_df):
_ = empty_snowpark_pandas_df.iloc[0]

@@ -1088,22 +1088,21 @@ def iloc_helper(df):
else:
return native_pd.Series([]) if axis == "row" else df.iloc[:, []]

def determine_query_and_join_count():
def determine_query_count():
# Multiple queries because of squeeze() - in range is 2, out-of-bounds is 1.
if axis == "col":
num_joins = 0
num_queries = 1
else:
if not -8 < key < 7: # key is out of bound
num_queries, num_joins = 2, 8
num_queries = 2
else:
num_queries, num_joins = 1, 4
return num_queries, num_joins
num_queries = 1
return num_queries

query_count, join_count = determine_query_and_join_count()
query_count = determine_query_count()
# test df with default index
num_cols = 7
with SqlCounter(query_count=query_count, join_count=join_count):
with SqlCounter(query_count=query_count):
eval_snowpark_pandas_result(
default_index_snowpark_pandas_df,
default_index_native_df,
@@ -1112,20 +1111,20 @@ def determine_query_and_join_count():

# test df with non-default index
num_cols = 6 # set_index() makes the number of columns 6
with SqlCounter(query_count=query_count, join_count=join_count):
with SqlCounter(query_count=query_count):
eval_snowpark_pandas_result(
default_index_snowpark_pandas_df.set_index("D"),
default_index_native_df.set_index("D"),
iloc_helper,
)

query_count, join_count = determine_query_and_join_count()
query_count = determine_query_count()
# test df with MultiIndex
# Index dtype is different between Snowpark and native pandas if key produces empty df.
num_cols = 7
native_df = default_index_native_df.set_index(multiindex_native)
snowpark_df = pd.DataFrame(native_df)
with SqlCounter(query_count=query_count, join_count=join_count):
with SqlCounter(query_count=query_count):
eval_snowpark_pandas_result(
snowpark_df,
native_df,
@@ -1138,7 +1137,7 @@
native_df_with_multiindex_columns
)
in_range = True if (-8 < key < 7) else False
with SqlCounter(query_count=query_count, join_count=join_count):
with SqlCounter(query_count=query_count):
if axis == "row" or in_range: # series result
eval_snowpark_pandas_result(
snowpark_df_with_multiindex_columns,
@@ -1158,7 +1157,7 @@
# test df with MultiIndex on both index and columns
native_df = native_df_with_multiindex_columns.set_index(multiindex_native)
snowpark_df = pd.DataFrame(native_df)
with SqlCounter(query_count=query_count, join_count=join_count):
with SqlCounter(query_count=query_count):
if axis == "row" or in_range: # series result
eval_snowpark_pandas_result(
snowpark_df,
@@ -2906,10 +2905,12 @@ def iloc_helper(df):
def determine_query_and_join_count():
# Initialize count values; query_count = row_count + col_count.
query_count = 1 # base query count
# All scalar and list-like row keys are treated like Series keys; a join is performed between the df and
# All list-like row keys are treated like Series keys; a join is performed between the df and
# key. For slice and range keys, a filter is used on the df instead.
join_count = 2
if not isinstance(row, list) or len(row) > 0:
if is_scalar(row):
join_count = 0
elif not isinstance(row, list) or len(row) > 0:
if is_range_like(row) or isinstance(row, slice):
join_count = 0
elif all(isinstance(i, bool) or isinstance(i, np.bool_) for i in row):
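
The comment above captures the dispatch this PR extends: keys that describe positions structurally (scalar, slice, range) can be compiled into a positional filter, while explicit list-like keys still have to be joined against the frame's row-position column. A toy illustration of the split (plain Python; the helper name is hypothetical):

```python
def compiles_to_filter(key) -> bool:
    """Keys that describe positions structurally avoid the join."""
    if isinstance(key, (slice, range)):
        return True
    if isinstance(key, (int, float)):  # scalar keys -- new in this PR
        return True
    return False  # list-like and Series keys are joined on row position

assert compiles_to_filter(0)              # df.iloc[0]     -> filter
assert compiles_to_filter(slice(1, 100))  # df.iloc[1:100] -> filter
assert not compiles_to_filter([0, 2, 4])  # still a join
```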
@@ -2934,7 +2935,7 @@ def determine_query_and_join_count():
(1, native_pd.Series([False, False, False, False, False, True, True])),
],
)
@sql_count_checker(query_count=2, join_count=4, union_count=1)
@sql_count_checker(query_count=2, union_count=1)
def test_df_iloc_get_array_col(
row,
col,
@@ -3124,7 +3125,7 @@ def iloc_helper(df):
col_in_range = True if col_lower_bound < col < col_upper_bound else False
if row_in_range and col_in_range:
# scalar value is returned
with SqlCounter(query_count=1, join_count=2):
with SqlCounter(query_count=1):
snowpark_res = (
snowpark_df.iloc[(row, col)]
if is_tuple
Expand All @@ -3137,7 +3138,7 @@ def iloc_helper(df):
for idx, val in enumerate(snowpark_res):
assert val == native_res[idx]
else:
with SqlCounter(query_count=1, join_count=2):
with SqlCounter(query_count=1):
with pytest.raises(IndexError):
iloc_helper(native_df)
assert len(iloc_helper(snowpark_df)) == 0
@@ -3252,3 +3253,46 @@ def ilocset(df):
native_df,
ilocset,
)


@pytest.mark.parametrize(
"ops",
[
lambda df: df.head(),
lambda df: df.iloc[1:100],
lambda df: df.iloc[1000:100:-1],
],
)
@sql_count_checker(query_count=6)
def test_df_iloc_efficient_sql(session, ops):
df = DataFrame({"a": [1] * 10000})
with session.query_history() as query_listener:
ops(df).to_pandas()
eval_query = query_listener.queries[
-2
].sql_text.lower() # query before drop temp table
# check no row count is in the sql query
assert "count" not in eval_query
# check orderBy is after limit in the sql query
assert eval_query.index("limit") < eval_query.index("order by")
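
The two assertions pin down the shape of the efficient plan: no `COUNT` to materialize row totals, and the `LIMIT` applied inside the query before the final `ORDER BY`, so only the surviving rows get sorted. A hypothetical skeleton with both properties (illustrative only, not the exact SQL Snowpark pandas generates):

```python
# Hypothetical skeleton -- not the exact generated SQL.
efficient_shape = """
select a from (
    select a, row_pos from t limit 100  -- prune to 100 rows first
) order by row_pos                      -- then sort the survivors only
""".lower()

# The same two properties the test asserts:
assert "count" not in efficient_shape
assert efficient_shape.index("limit") < efficient_shape.index("order by")
```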


@pytest.mark.parametrize(
"ops",
[
lambda df: df.iloc[0],
lambda df: df.iloc[100],
],
)
@sql_count_checker(query_count=8, union_count=1)
def test_df_iloc_scalar_efficient_sql(session, ops):
df = DataFrame({"a": [1] * 10000})
with session.query_history() as query_listener:
ops(df).to_pandas()
eval_query = query_listener.queries[
-3
].sql_text.lower() # query before drop temp table and transpose
# check no row count is in the sql query
assert "count" not in eval_query
# check limit is used in the sql query
assert "limit" in eval_query
6 changes: 3 additions & 3 deletions tests/integ/modin/frame/test_iterrows.py
@@ -61,7 +61,7 @@ def test_df_iterrows(native_df):
snowpark_df = pd.DataFrame(native_df)
# One query is used to get the number of rows. One query is used to retrieve each row - each query has 4 JOIN
# operations performed due to iloc.
with SqlCounter(query_count=len(native_df) + 1, join_count=4 * len(native_df)):
with SqlCounter(query_count=len(native_df) + 1):
eval_snowpark_pandas_result(
snowpark_df,
native_df,
@@ -70,7 +70,7 @@
)


@sql_count_checker(query_count=8, join_count=28, union_count=7)
@sql_count_checker(query_count=8, union_count=7)
def test_df_iterrows_mixed_types(default_index_native_df):
# Same test as above on bigger df with mixed types.
# One query is used to get the number of rows. One query is used to retrieve each row - each query has 4 JOIN
@@ -85,7 +85,7 @@ def test_df_iterrows_mixed_types(default_index_native_df):
)


@sql_count_checker(query_count=7, join_count=24, union_count=6)
@sql_count_checker(query_count=7, union_count=6)
def test_df_iterrows_multindex_df():
# Create df with a MultiIndex index.
# One query is used to get the number of rows. One query is used to retrieve each row - each query has 4 JOIN
6 changes: 1 addition & 5 deletions tests/integ/modin/frame/test_nunique.py
@@ -31,10 +31,6 @@
@pytest.mark.parametrize("axes_slices", TEST_SLICES)
@pytest.mark.parametrize("dropna", [True, False])
def test_dataframe_nunique(axes_slices, dropna):
expected_join_count = 0
if axes_slices == (0, slice(None)):
expected_join_count = 4

df = pd.DataFrame(
pd.DataFrame(TEST_DATA, columns=TEST_LABELS).iloc[
axes_slices[0], axes_slices[1]
@@ -46,7 +42,7 @@ def test_dataframe_nunique(axes_slices, dropna):
]
)

with SqlCounter(query_count=1, join_count=expected_join_count):
with SqlCounter(query_count=1):
eval_snowpark_pandas_result(
df,
native_df,
3 changes: 2 additions & 1 deletion tests/integ/modin/index/test_indexing.py
@@ -7,6 +7,7 @@
import numpy as np
import pandas as native_pd
import pytest
from modin.pandas.utils import is_scalar

import snowflake.snowpark.modin.plugin # noqa: F401
from tests.integ.modin.utils import assert_index_equal
@@ -31,7 +32,7 @@
],
)
def test_index_indexing(index, key):
if isinstance(key, slice) or key is ...:
if isinstance(key, slice) or key is ... or is_scalar(key):
join_count = 0 # because slice key uses filter not join
elif isinstance(key, list) and isinstance(key[0], bool):
join_count = 1 # because need to join key