feat(spark)!: Transpile ANY to EXISTS (#4305)
* feat(spark): Transpile ANY operator to EXISTS

* Support EXISTS as a function
VaggelisD authored Oct 29, 2024
1 parent efd9b4e commit e92904e
Showing 9 changed files with 80 additions and 5 deletions.
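
Taken end to end, the change lets the Spark-family generators rewrite a comparison against ANY(array) into the higher-order EXISTS function. A minimal sketch of how this might be exercised through the public API; the expected output string is inferred from the Postgres test case added in this commit:

import sqlglot

sql = "WITH t AS (SELECT ARRAY[1, 2, 3] AS col) SELECT * FROM t WHERE 1 <= ANY(col)"
print(sqlglot.transpile(sql, read="postgres", write="spark")[0])
# Expected, per tests/dialects/test_postgres.py:
# WITH t AS (SELECT ARRAY(1, 2, 3) AS col) SELECT * FROM t WHERE EXISTS(col, x -> 1 <= x)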
1 change: 1 addition & 0 deletions sqlglot/dialects/databricks.py
@@ -87,6 +87,7 @@ class Generator(Spark.Generator):
[
transforms.eliminate_distinct_on,
transforms.unnest_to_explode,
transforms.any_to_exists,
]
),
exp.JSONExtract: _jsonextract_sql,
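
Registering transforms.any_to_exists in the Databricks generator's preprocess list (and, below, in Hive and Spark2) means every Spark-family target applies the rewrite automatically. A hedged sketch, with the expected string taken from the Postgres test further down:

import sqlglot

print(sqlglot.transpile("SELECT * FROM t WHERE 2 = ANY(col)", read="postgres", write="databricks")[0])
# Expected: SELECT * FROM t WHERE EXISTS(col, x -> 2 = x)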
7 changes: 7 additions & 0 deletions sqlglot/dialects/hive.py
@@ -556,6 +556,7 @@ class Generator(generator.Generator):
transforms.eliminate_qualify,
transforms.eliminate_distinct_on,
partial(transforms.unnest_to_explode, unnest_using_arrays_zip=False),
transforms.any_to_exists,
]
),
exp.StrPosition: strposition_to_locate_sql,
@@ -709,3 +710,9 @@ def serdeproperties_sql(self, expression: exp.SerdeProperties) -> str:
exprs = self.expressions(expression, flat=True)

return f"{prefix}SERDEPROPERTIES ({exprs})"

def exists_sql(self, expression: exp.Exists) -> str:
    # EXISTS(array, lambda) is a higher-order function in Hive/Spark, so render it
    # as a regular function call whenever a lambda argument is present.
    if expression.expression:
        return self.function_fallback_sql(expression)

    # Without a lambda, fall back to the standard EXISTS(subquery) predicate.
    return super().exists_sql(expression)
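
With exists_sql distinguishing the two shapes, the Hive dialect can round-trip the higher-order form of EXISTS. A quick sketch mirroring the Hive test case added below; parse_one and selects are existing sqlglot APIs, and the output assumes the identity behaviour the tests assert:

from sqlglot import exp, parse_one

node = parse_one("SELECT EXISTS(ARRAY(2, 3), x -> x % 2 = 0)", read="hive")
assert isinstance(node.selects[0], exp.Exists)
print(node.sql(dialect="hive"))
# SELECT EXISTS(ARRAY(2, 3), x -> x % 2 = 0)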
1 change: 1 addition & 0 deletions sqlglot/dialects/spark2.py
@@ -285,6 +285,7 @@ class Generator(Hive.Generator):
transforms.eliminate_qualify,
transforms.eliminate_distinct_on,
transforms.unnest_to_explode,
transforms.any_to_exists,
]
),
exp.StrToDate: _str_to_date,
8 changes: 4 additions & 4 deletions sqlglot/expressions.py
@@ -4605,10 +4605,6 @@ class Any(SubqueryPredicate):
pass


class Exists(SubqueryPredicate):
pass


# Commands to interact with the databases or engines. For most of the command
# expressions we parse whatever comes after the command's name as a string.
class Command(Expression):
@@ -5583,6 +5579,10 @@ class Extract(Func):
arg_types = {"this": True, "expression": True}


class Exists(Func, SubqueryPredicate):
arg_types = {"this": True, "expression": False}


class Timestamp(Func):
arg_types = {"this": False, "zone": False, "with_tz": False}

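
Moving Exists under Func with an optional second argument means both shapes of the node can be built programmatically. A hedged sketch of roughly what the transform below constructs; the rendered output assumes the function-call path added to the Hive generator above:

from sqlglot import exp

# EXISTS(col, x -> 2 = x): an array operand plus a lambda over the bound identifier "x".
lam = exp.Lambda(
    this=exp.EQ(this=exp.Literal.number(2), expression=exp.to_identifier("x")),
    expressions=[exp.to_identifier("x")],
)
node = exp.Exists(this=exp.column("col"), expression=lam)
print(node.sql(dialect="spark"))
# Expected: EXISTS(col, x -> 2 = x)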
27 changes: 27 additions & 0 deletions sqlglot/transforms.py
@@ -914,3 +914,30 @@ def eliminate_join_marks(expression: exp.Expression) -> exp.Expression:
where.pop()

return expression


def any_to_exists(expression: exp.Expression) -> exp.Expression:
    """
    Transform the ANY operator into Spark's higher-order EXISTS function.

    For example:
        - Postgres: SELECT * FROM tbl WHERE 5 > ANY(tbl.col)
        - Spark:    SELECT * FROM tbl WHERE EXISTS(tbl.col, x -> 5 > x)

    Both ANY and EXISTS accept queries, but currently only array expressions are supported
    for this transformation.
    """
    if isinstance(expression, exp.Select):
        for any in expression.find_all(exp.Any):
            this = any.this
            if isinstance(this, exp.Query):
                # ANY over a subquery is left untouched; only array operands are rewritten.
                continue

            binop = any.parent
            if isinstance(binop, exp.Binary):
                # Replace ANY(arr) with a lambda parameter and wrap the surrounding
                # comparison in EXISTS(arr, x -> <comparison>).
                lambda_arg = exp.to_identifier("x")
                any.replace(lambda_arg)
                lambda_expr = exp.Lambda(this=binop.copy(), expressions=[lambda_arg])
                binop.replace(exp.Exists(this=this.unnest(), expression=lambda_expr))

    return expression
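
The Spark-family generators pick up any_to_exists via transforms.preprocess, but the function can also be applied by hand. A small sketch; the expected output follows the operand orientation used in the tests below:

from sqlglot import parse_one
from sqlglot.transforms import any_to_exists

tree = parse_one("SELECT * FROM tbl WHERE 5 > ANY(tbl.col)", read="postgres")
print(any_to_exists(tree).sql(dialect="spark"))
# Expected: SELECT * FROM tbl WHERE EXISTS(tbl.col, x -> 5 > x)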
11 changes: 11 additions & 0 deletions tests/dialects/test_databricks.py
@@ -116,6 +116,17 @@ def test_databricks(self):
},
)

self.validate_all(
"SELECT ANY(col) FROM VALUES (TRUE), (FALSE) AS tab(col)",
read={
"databricks": "SELECT ANY(col) FROM VALUES (TRUE), (FALSE) AS tab(col)",
"spark": "SELECT ANY(col) FROM VALUES (TRUE), (FALSE) AS tab(col)",
},
write={
"spark": "SELECT ANY(col) FROM VALUES (TRUE), (FALSE) AS tab(col)",
},
)

# https://docs.databricks.com/sql/language-manual/functions/colonsign.html
def test_json(self):
self.validate_identity("SELECT c1:price, c1:price.foo, c1:price.bar[1]")
19 changes: 19 additions & 0 deletions tests/dialects/test_hive.py
@@ -1,5 +1,7 @@
from tests.dialects.test_dialect import Validator

from sqlglot import exp


class TestHive(Validator):
dialect = "hive"
@@ -787,6 +789,23 @@ def test_hive(self):
},
)

self.validate_identity("EXISTS(col, x -> x % 2 = 0)").assert_is(exp.Exists)

self.validate_all(
"SELECT EXISTS(ARRAY(2, 3), x -> x % 2 = 0)",
read={
"hive": "SELECT EXISTS(ARRAY(2, 3), x -> x % 2 = 0)",
"spark2": "SELECT EXISTS(ARRAY(2, 3), x -> x % 2 = 0)",
"spark": "SELECT EXISTS(ARRAY(2, 3), x -> x % 2 = 0)",
"databricks": "SELECT EXISTS(ARRAY(2, 3), x -> x % 2 = 0)",
},
write={
"spark2": "SELECT EXISTS(ARRAY(2, 3), x -> x % 2 = 0)",
"spark": "SELECT EXISTS(ARRAY(2, 3), x -> x % 2 = 0)",
"databricks": "SELECT EXISTS(ARRAY(2, 3), x -> x % 2 = 0)",
},
)

def test_escapes(self) -> None:
self.validate_identity("'\n'", "'\\n'")
self.validate_identity("'\\n'")
10 changes: 10 additions & 0 deletions tests/dialects/test_postgres.py
@@ -804,6 +804,16 @@ def test_postgres(self):
"duckdb": """SELECT JSON_EXISTS('{"a": [1,2,3]}', '$.a')""",
},
)
self.validate_all(
"WITH t AS (SELECT ARRAY[1, 2, 3] AS col) SELECT * FROM t WHERE 1 <= ANY(col) AND 2 = ANY(col)",
write={
"postgres": "WITH t AS (SELECT ARRAY[1, 2, 3] AS col) SELECT * FROM t WHERE 1 <= ANY(col) AND 2 = ANY(col)",
"hive": "WITH t AS (SELECT ARRAY(1, 2, 3) AS col) SELECT * FROM t WHERE EXISTS(col, x -> 1 <= x) AND EXISTS(col, x -> 2 = x)",
"spark2": "WITH t AS (SELECT ARRAY(1, 2, 3) AS col) SELECT * FROM t WHERE EXISTS(col, x -> 1 <= x) AND EXISTS(col, x -> 2 = x)",
"spark": "WITH t AS (SELECT ARRAY(1, 2, 3) AS col) SELECT * FROM t WHERE EXISTS(col, x -> 1 <= x) AND EXISTS(col, x -> 2 = x)",
"databricks": "WITH t AS (SELECT ARRAY(1, 2, 3) AS col) SELECT * FROM t WHERE EXISTS(col, x -> 1 <= x) AND EXISTS(col, x -> 2 = x)",
},
)

def test_ddl(self):
# Checks that user-defined types are parsed into DataType instead of Identifier
1 change: 0 additions & 1 deletion tests/fixtures/identity.sql
@@ -250,7 +250,6 @@ SELECT LEAD(a, 1) OVER (PARTITION BY a ORDER BY a) AS x
SELECT LEAD(a, 1, b) OVER (PARTITION BY a ORDER BY a) AS x
SELECT X((a, b) -> a + b, z -> z) AS x
SELECT X(a -> a + ("z" - 1))
SELECT EXISTS(ARRAY(2, 3), x -> x % 2 = 0)
SELECT test.* FROM test
SELECT a AS b FROM test
SELECT "a"."b" FROM "a"
