Skip to content

Commit

Permalink
fix(generator): Add NULL FILTER on ARRAY_AGG only for columns (#4301)
Browse files Browse the repository at this point in the history
  • Loading branch information
VaggelisD authored Oct 28, 2024
1 parent 551afff commit a66e721
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 7 deletions.
19 changes: 12 additions & 7 deletions sqlglot/generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -4362,20 +4362,25 @@ def jsonexists_sql(self, expression: exp.JSONExists) -> str:
def arrayagg_sql(self, expression: exp.ArrayAgg) -> str:
array_agg = self.function_fallback_sql(expression)

# Add a NULL FILTER on the column to mimic the results going from a dialect that excludes nulls
# on ARRAY_AGG (e.g. Spark) to one that doesn't (e.g. DuckDB)
if self.dialect.ARRAY_AGG_INCLUDES_NULLS and expression.args.get("nulls_excluded"):
parent = expression.parent
if isinstance(parent, exp.Filter):
parent_cond = parent.expression.this
parent_cond.replace(parent_cond.and_(expression.this.is_(exp.null()).not_()))
else:
# DISTINCT is already present in the agg function, do not propagate it to FILTER as well
this = expression.this
this_sql = (
self.expressions(this)
if isinstance(this, exp.Distinct)
else self.sql(expression, "this")
)
array_agg = f"{array_agg} FILTER(WHERE {this_sql} IS NOT NULL)"
# Do not add the filter if the input is not a column (e.g. literal, struct, etc.)
if this.find(exp.Column):
# DISTINCT is already present in the agg function, do not propagate it to FILTER as well
this_sql = (
self.expressions(this)
if isinstance(this, exp.Distinct)
else self.sql(expression, "this")
)

array_agg = f"{array_agg} FILTER(WHERE {this_sql} IS NOT NULL)"

return array_agg

Expand Down
14 changes: 14 additions & 0 deletions tests/dialects/test_spark.py
Original file line number Diff line number Diff line change
Expand Up @@ -314,6 +314,20 @@ def test_spark(self):
"spark": "SELECT COLLECT_LIST(x) FILTER(WHERE x = 5) FROM (SELECT 1 UNION ALL SELECT NULL) AS t(x)",
},
)
self.validate_all(
"SELECT ARRAY_AGG(1)",
write={
"duckdb": "SELECT ARRAY_AGG(1)",
"spark": "SELECT COLLECT_LIST(1)",
},
)
self.validate_all(
"SELECT ARRAY_AGG(DISTINCT STRUCT('a'))",
write={
"duckdb": "SELECT ARRAY_AGG(DISTINCT {'col1': 'a'})",
"spark": "SELECT COLLECT_LIST(DISTINCT STRUCT('a' AS col1))",
},
)
self.validate_all(
"SELECT DATE_FORMAT(DATE '2020-01-01', 'EEEE') AS weekday",
write={
Expand Down

0 comments on commit a66e721

Please sign in to comment.