Skip to content

Commit

Permalink
SNOW-1344938: Add support for DataFrame.corr (#1857)
Browse files Browse the repository at this point in the history
  • Loading branch information
sfc-gh-helmeleegy authored Jul 1, 2024
1 parent 990c47e commit 8a42763
Show file tree
Hide file tree
Showing 7 changed files with 204 additions and 19 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@

#### New Features
- Added partial support for `Series.str.translate` where the values in the `table` are single-codepoint strings.
- Added support for `DataFrame.corr`.

## 1.19.0 (2024-06-25)

Expand Down
2 changes: 1 addition & 1 deletion docs/source/modin/supported/dataframe_supported.rst
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ Methods
+-----------------------------+---------------------------------+----------------------------------+----------------------------------------------------+
| ``copy`` | Y | | |
+-----------------------------+---------------------------------+----------------------------------+----------------------------------------------------+
| ``corr`` | N | | |
| ``corr`` | P | | ``N`` if ``method`` is not 'pearson' |
+-----------------------------+---------------------------------+----------------------------------+----------------------------------------------------+
| ``corrwith`` | N | | |
+-----------------------------+---------------------------------+----------------------------------+----------------------------------------------------+
Expand Down
20 changes: 11 additions & 9 deletions src/snowflake/snowpark/modin/pandas/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -749,23 +749,25 @@ def compare(
)
)

@dataframe_not_implemented()
def corr(
self, method="pearson", min_periods=1, numeric_only=False
self,
method: str | Callable = "pearson",
min_periods: int | None = None,
numeric_only: bool = False,
): # noqa: PR01, RT01, D200
"""
Compute pairwise correlation of columns, excluding NA/null values.
"""
# TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions
if not numeric_only:
return self._default_to_pandas(
pandas.DataFrame.corr,
method=method,
min_periods=min_periods,
numeric_only=numeric_only,
corr_df = self
if numeric_only:
corr_df = self.drop(
columns=[
i for i in self.dtypes.index if not is_numeric_dtype(self.dtypes[i])
]
)
return self.__constructor__(
query_compiler=self._query_compiler.corr(
query_compiler=corr_df._query_compiler.corr(
method=method,
min_periods=min_periods,
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@
coalesce,
col,
concat,
corr,
count,
count_distinct,
date_part,
Expand Down Expand Up @@ -15028,3 +15029,100 @@ def stack(
return qc.dropna(axis=0, how="any", thresh=None)
else:
return qc

def corr(
    self,
    method: Union[str, Callable] = "pearson",
    min_periods: Optional[int] = 1,
) -> "SnowflakeQueryCompiler":
    """
    Compute pairwise correlation of columns, excluding NA/null values.

    Only ``method="pearson"`` is currently implemented; any other string,
    or a callable, raises ``NotImplementedError``.

    Parameters
    ----------
    method : {'pearson', 'kendall', 'spearman'} or callable
        Method of correlation:

        pearson : standard correlation coefficient
        kendall : Kendall Tau correlation coefficient
        spearman : Spearman rank correlation
        callable : callable with input two 1d ndarrays and returning a
            float. Note that the returned matrix from corr will have 1
            along the diagonals and will be symmetric regardless of the
            callable's behavior.
    min_periods : int, optional
        Minimum number of observations required per pair of columns to
        have a valid result. ``None`` is treated as 1. Currently only
        available for Pearson and Spearman correlation.

    Returns
    -------
    SnowflakeQueryCompiler
        Query compiler holding the correlation matrix: one row and one
        data column per data column of this frame.
    """
    if not isinstance(method, str):
        ErrorMessage.not_implemented(
            "Snowpark pandas DataFrame.corr does not yet support non-string 'method'"
        )

    if method != "pearson":
        ErrorMessage.not_implemented(
            f"Snowpark pandas DataFrame.corr does not yet support 'method={method}'"
        )

    # pandas treats min_periods=None the same as min_periods=1.
    if min_periods is None:
        min_periods = 1

    frame = self._modin_frame

    # Build one single-row frame per data column: row i contains
    # CORR(col_i, col_j) for every column j, plus an index column holding
    # col_i's pandas label. The rows are concatenated at the end to form
    # the square correlation matrix.
    query_compilers = []
    for outer_pandas_label, outer_quoted_identifier in zip(
        frame.data_column_pandas_labels,
        frame.data_column_snowflake_quoted_identifiers,
    ):
        index_quoted_identifier = (
            frame.ordered_dataframe.generate_snowflake_quoted_identifiers(
                pandas_labels=[INDEX_LABEL],
            )[0]
        )

        # Apply a "min" function to the index column to make sure it's also an aggregate.
        index_col = min_(pandas_lit(outer_pandas_label)).as_(
            index_quoted_identifier
        )

        new_columns = [index_col]
        for (
            inner_quoted_identifier
        ) in frame.data_column_snowflake_quoted_identifiers:
            new_col = corr(outer_quoted_identifier, inner_quoted_identifier)
            if min_periods > 1:
                # NOTE(review): validity is gated on each column's own
                # non-null count, while pandas documents min_periods as a
                # per-*pair* (pairwise-complete) observation count. These
                # differ when several columns contain nulls at different
                # row positions — confirm the intended semantics.
                outer_col_is_valid = builtin("count_if")(
                    col(outer_quoted_identifier).is_not_null()
                ) >= pandas_lit(min_periods)
                inner_col_is_valid = builtin("count_if")(
                    col(inner_quoted_identifier).is_not_null()
                ) >= pandas_lit(min_periods)
                # Null out the correlation when either column has too few
                # observations.
                new_col = iff(
                    outer_col_is_valid & inner_col_is_valid,
                    new_col,
                    pandas_lit(None),
                )
            new_col = new_col.as_(inner_quoted_identifier)
            new_columns.append(new_col)

        # Evaluate every expression in a single aggregation pass,
        # producing exactly one row for this "outer" column.
        new_ordered_data_frame = OrderedDataFrame(
            dataframe_ref=DataFrameReference(
                frame.ordered_dataframe._dataframe_ref.snowpark_dataframe.agg(
                    new_columns
                )
            )
        )

        # Reuse the original data-column labels/identifiers; the
        # synthesized index column carries the outer column's label.
        new_frame = InternalFrame.create(
            ordered_dataframe=new_ordered_data_frame,
            data_column_pandas_labels=frame.data_column_pandas_labels,
            data_column_pandas_index_names=[None],
            data_column_snowflake_quoted_identifiers=frame.data_column_snowflake_quoted_identifiers,
            index_column_pandas_labels=[None],
            index_column_snowflake_quoted_identifiers=[index_quoted_identifier],
        )

        query_compilers.append(SnowflakeQueryCompiler(new_frame))

    # Stack the per-column rows vertically into the final matrix; a
    # single-column frame needs no union.
    if len(query_compilers) == 1:
        result = query_compilers[0]
    else:
        result = query_compilers[0].concat(axis=0, other=query_compilers[1:])
    return result
55 changes: 55 additions & 0 deletions src/snowflake/snowpark/modin/plugin/docstrings/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -1203,6 +1203,61 @@ def compare():
def corr():
    """
    Compute pairwise correlation of columns, excluding NA/null values.

    Parameters
    ----------
    method : {'pearson', 'kendall', 'spearman'} or callable
        Method of correlation:

        * pearson : standard correlation coefficient
        * kendall : Kendall Tau correlation coefficient
        * spearman : Spearman rank correlation
        * callable : callable with input two 1d ndarrays and returning a
          float. Note that the returned matrix from corr will have 1 along
          the diagonals and will be symmetric regardless of the callable's
          behavior.
    min_periods : int, optional
        Minimum number of observations required per pair of columns to have
        a valid result. Currently only available for Pearson and Spearman
        correlation.
    numeric_only : bool, default False
        Include only float, int or boolean data.

    Returns
    -------
    DataFrame
        Correlation matrix.

    See also
    --------
    DataFrame.corrwith : Compute pairwise correlation with another
        DataFrame or Series.
    Series.corr : Compute the correlation between two Series.

    Notes
    -----
    Pearson, Kendall and Spearman correlation are currently computed using
    pairwise complete observations.

    * Pearson correlation coefficient
    * Kendall rank correlation coefficient
    * Spearman's rank correlation coefficient

    Examples
    --------
    >>> def histogram_intersection(a, b):
    ...     v = np.minimum(a, b).sum().round(decimals=1)
    ...     return v
    >>> df = pd.DataFrame([(.2, .3), (.0, .6), (.6, .0), (.2, .1)],
    ...                   columns=['dogs', 'cats'])
    >>> df.corr(method=histogram_intersection)  # doctest: +SKIP
          dogs  cats
    dogs   1.0   0.3
    cats   0.3   1.0
    >>> df = pd.DataFrame([(1, 1), (2, np.nan), (np.nan, 3), (4, 4)],
    ...                   columns=['dogs', 'cats'])
    >>> df.corr(min_periods=3)
          dogs  cats
    dogs   1.0   1.0
    cats   1.0   1.0
    """

def corrwith():
Expand Down
46 changes: 38 additions & 8 deletions tests/integ/modin/frame/test_aggregate.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,7 @@ def native_df_multiindex() -> native_pd.DataFrame:
(lambda df: df.std(), 0),
(lambda df: df.var(), 0),
(lambda df: df.quantile(), 0),
(lambda df: df.corr(), 2),
(lambda df: df.aggregate(["idxmin"]), 0),
(
lambda df: df.aggregate(
Expand All @@ -102,6 +103,32 @@ def test_agg_basic(numeric_native_df, func, expected_union_count):
eval_snowpark_pandas_result(snow_df, numeric_native_df, func)


@pytest.mark.parametrize("min_periods", [None, -1, 0, 1, 2, 3, 4])
@sql_count_checker(query_count=1, union_count=2)
def test_corr_min_periods(min_periods):
    # One column contains a null so that the min_periods threshold is
    # actually exercised; the other two are fully populated.
    data = {"a": [None, 1, 2], "b": [3, 4, 5], "c": [6, 7, 8]}
    snow_df, pandas_df = create_test_dfs(data)
    eval_snowpark_pandas_result(
        snow_df,
        pandas_df,
        lambda frame: frame.corr(min_periods=min_periods),
    )


@pytest.mark.parametrize(
    "method",
    [
        "kendall",
        "spearman",
        lambda x, y: np.minimum(x, y).sum().round(decimals=1),
    ],
)
@sql_count_checker(query_count=0)
def test_corr_negative(numeric_native_df, method):
    # Unsupported correlation methods must fail eagerly, before any SQL
    # is issued (hence query_count=0).
    df = pd.DataFrame(numeric_native_df)
    with pytest.raises(NotImplementedError):
        df.corr(method=method)


@pytest.mark.parametrize(
"data",
[
Expand Down Expand Up @@ -307,6 +334,7 @@ def test_multiple_agg_on_only_dup_columns(self, numeric_native_df):
(lambda df: df.median(numeric_only=True), 0),
(lambda df: df.std(numeric_only=True), 0),
(lambda df: df.var(numeric_only=True), 0),
(lambda df: df.corr(numeric_only=True), 1),
(lambda df: df.aggregate("max"), 0),
],
)
Expand Down Expand Up @@ -553,6 +581,7 @@ def test_agg_with_no_column_raises(pandas_df):
lambda df: df.aggregate(min),
lambda df: df.max(),
lambda df: df.count(),
lambda df: df.corr(),
lambda df: df.aggregate(x=("A", "min")),
],
)
Expand All @@ -564,19 +593,20 @@ def test_agg_with_single_col(func):


@pytest.mark.parametrize(
"func",
"func, expected_union_count",
[
lambda df: df.aggregate(min),
lambda df: df.max(),
lambda df: df.count(),
lambda df: df.aggregate(x=("A", "min")),
(lambda df: df.aggregate(min), 0),
(lambda df: df.max(), 0),
(lambda df: df.count(), 0),
(lambda df: df.corr(), 1),
(lambda df: df.aggregate(x=("A", "min")), 0),
],
)
@sql_count_checker(query_count=1)
def test_agg_with_multi_col(func):
def test_agg_with_multi_col(func, expected_union_count):
native_df = native_pd.DataFrame({"A": [1, 2, 3], "B": [4, 5, 6]})
snow_df = pd.DataFrame(native_df)
eval_snowpark_pandas_result(snow_df, native_df, func)
with SqlCounter(query_count=1, union_count=expected_union_count):
eval_snowpark_pandas_result(snow_df, native_df, func)


@pytest.mark.parametrize(
Expand Down
1 change: 0 additions & 1 deletion tests/unit/modin/test_unsupported.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,6 @@ def test_unsupported_general(general_method, kwargs):
["combine", {"other": "", "func": ""}],
["combine_first", {"other": ""}],
["compare", {"other": ""}],
["corr", {}],
["corrwith", {"other": ""}],
["cov", {}],
["dot", {"other": ""}],
Expand Down

0 comments on commit 8a42763

Please sign in to comment.