From e92904e61ab3b14fe18d472df19311f9b014f6cc Mon Sep 17 00:00:00 2001 From: Vaggelis Danias Date: Tue, 29 Oct 2024 17:44:53 +0200 Subject: [PATCH] feat(spark)!: Transpile ANY to EXISTS (#4305) * feat(spark): Transpile ANY operator to EXISTS * Support EXISTS as a function --- sqlglot/dialects/databricks.py | 1 + sqlglot/dialects/hive.py | 7 +++++++ sqlglot/dialects/spark2.py | 1 + sqlglot/expressions.py | 8 ++++---- sqlglot/transforms.py | 27 +++++++++++++++++++++++++++ tests/dialects/test_databricks.py | 11 +++++++++++ tests/dialects/test_hive.py | 19 +++++++++++++++++++ tests/dialects/test_postgres.py | 10 ++++++++++ tests/fixtures/identity.sql | 1 - 9 files changed, 80 insertions(+), 5 deletions(-) diff --git a/sqlglot/dialects/databricks.py b/sqlglot/dialects/databricks.py index 0715eb7e15..c6bde19680 100644 --- a/sqlglot/dialects/databricks.py +++ b/sqlglot/dialects/databricks.py @@ -87,6 +87,7 @@ class Generator(Spark.Generator): [ transforms.eliminate_distinct_on, transforms.unnest_to_explode, + transforms.any_to_exists, ] ), exp.JSONExtract: _jsonextract_sql, diff --git a/sqlglot/dialects/hive.py b/sqlglot/dialects/hive.py index c30e7ec908..db4c997286 100644 --- a/sqlglot/dialects/hive.py +++ b/sqlglot/dialects/hive.py @@ -556,6 +556,7 @@ class Generator(generator.Generator): transforms.eliminate_qualify, transforms.eliminate_distinct_on, partial(transforms.unnest_to_explode, unnest_using_arrays_zip=False), + transforms.any_to_exists, ] ), exp.StrPosition: strposition_to_locate_sql, @@ -709,3 +710,9 @@ def serdeproperties_sql(self, expression: exp.SerdeProperties) -> str: exprs = self.expressions(expression, flat=True) return f"{prefix}SERDEPROPERTIES ({exprs})" + + def exists_sql(self, expression: exp.Exists): + if expression.expression: + return self.function_fallback_sql(expression) + + return super().exists_sql(expression) diff --git a/sqlglot/dialects/spark2.py b/sqlglot/dialects/spark2.py index 15c2d2d231..21947135dc 100644 --- a/sqlglot/dialects/spark2.py +++ b/sqlglot/dialects/spark2.py @@ -285,6 +285,7 @@ class Generator(Hive.Generator): transforms.eliminate_qualify, transforms.eliminate_distinct_on, transforms.unnest_to_explode, + transforms.any_to_exists, ] ), exp.StrToDate: _str_to_date, diff --git a/sqlglot/expressions.py b/sqlglot/expressions.py index 5facee3616..19dd979d41 100644 --- a/sqlglot/expressions.py +++ b/sqlglot/expressions.py @@ -4605,10 +4605,6 @@ class Any(SubqueryPredicate): pass -class Exists(SubqueryPredicate): - pass - - # Commands to interact with the databases or engines. For most of the command # expressions we parse whatever comes after the command's name as a string. class Command(Expression): @@ -5583,6 +5579,10 @@ class Extract(Func): arg_types = {"this": True, "expression": True} +class Exists(Func, SubqueryPredicate): + arg_types = {"this": True, "expression": False} + + class Timestamp(Func): arg_types = {"this": False, "zone": False, "with_tz": False} diff --git a/sqlglot/transforms.py b/sqlglot/transforms.py index e0e3e324d1..46c859b028 100644 --- a/sqlglot/transforms.py +++ b/sqlglot/transforms.py @@ -914,3 +914,30 @@ def eliminate_join_marks(expression: exp.Expression) -> exp.Expression: where.pop() return expression + + +def any_to_exists(expression: exp.Expression) -> exp.Expression: + """ + Transform ANY operator to Spark's EXISTS + + For example, + - Postgres: SELECT * FROM tbl WHERE 5 > ANY(tbl.col) + - Spark: SELECT * FROM tbl WHERE EXISTS(tbl.col, x -> x < 5) + + Both ANY and EXISTS accept queries but currently only array expressions are supported for this + transformation + """ + if isinstance(expression, exp.Select): + for any in expression.find_all(exp.Any): + this = any.this + if isinstance(this, exp.Query): + continue + + binop = any.parent + if isinstance(binop, exp.Binary): + lambda_arg = exp.to_identifier("x") + any.replace(lambda_arg) + lambda_expr = exp.Lambda(this=binop.copy(), expressions=[lambda_arg]) + binop.replace(exp.Exists(this=this.unnest(), expression=lambda_expr)) + + return expression diff --git a/tests/dialects/test_databricks.py b/tests/dialects/test_databricks.py index f7ec756d07..1dbbfa9ab0 100644 --- a/tests/dialects/test_databricks.py +++ b/tests/dialects/test_databricks.py @@ -116,6 +116,17 @@ def test_databricks(self): }, ) + self.validate_all( + "SELECT ANY(col) FROM VALUES (TRUE), (FALSE) AS tab(col)", + read={ + "databricks": "SELECT ANY(col) FROM VALUES (TRUE), (FALSE) AS tab(col)", + "spark": "SELECT ANY(col) FROM VALUES (TRUE), (FALSE) AS tab(col)", + }, + write={ + "spark": "SELECT ANY(col) FROM VALUES (TRUE), (FALSE) AS tab(col)", + }, + ) + # https://docs.databricks.com/sql/language-manual/functions/colonsign.html def test_json(self): self.validate_identity("SELECT c1:price, c1:price.foo, c1:price.bar[1]") diff --git a/tests/dialects/test_hive.py b/tests/dialects/test_hive.py index e40a85afd9..f13d92cf0c 100644 --- a/tests/dialects/test_hive.py +++ b/tests/dialects/test_hive.py @@ -1,5 +1,7 @@ from tests.dialects.test_dialect import Validator +from sqlglot import exp + class TestHive(Validator): dialect = "hive" @@ -787,6 +789,23 @@ def test_hive(self): }, ) + self.validate_identity("EXISTS(col, x -> x % 2 = 0)").assert_is(exp.Exists) + + self.validate_all( + "SELECT EXISTS(ARRAY(2, 3), x -> x % 2 = 0)", + read={ + "hive": "SELECT EXISTS(ARRAY(2, 3), x -> x % 2 = 0)", + "spark2": "SELECT EXISTS(ARRAY(2, 3), x -> x % 2 = 0)", + "spark": "SELECT EXISTS(ARRAY(2, 3), x -> x % 2 = 0)", + "databricks": "SELECT EXISTS(ARRAY(2, 3), x -> x % 2 = 0)", + }, + write={ + "spark2": "SELECT EXISTS(ARRAY(2, 3), x -> x % 2 = 0)", + "spark": "SELECT EXISTS(ARRAY(2, 3), x -> x % 2 = 0)", + "databricks": "SELECT EXISTS(ARRAY(2, 3), x -> x % 2 = 0)", + }, + ) + def test_escapes(self) -> None: self.validate_identity("'\n'", "'\\n'") self.validate_identity("'\\n'") diff --git a/tests/dialects/test_postgres.py b/tests/dialects/test_postgres.py index 0f9ab3c23b..4b54cd0245 100644 --- a/tests/dialects/test_postgres.py +++ b/tests/dialects/test_postgres.py @@ -804,6 +804,16 @@ def test_postgres(self): "duckdb": """SELECT JSON_EXISTS('{"a": [1,2,3]}', '$.a')""", }, ) + self.validate_all( + "WITH t AS (SELECT ARRAY[1, 2, 3] AS col) SELECT * FROM t WHERE 1 <= ANY(col) AND 2 = ANY(col)", + write={ + "postgres": "WITH t AS (SELECT ARRAY[1, 2, 3] AS col) SELECT * FROM t WHERE 1 <= ANY(col) AND 2 = ANY(col)", + "hive": "WITH t AS (SELECT ARRAY(1, 2, 3) AS col) SELECT * FROM t WHERE EXISTS(col, x -> 1 <= x) AND EXISTS(col, x -> 2 = x)", + "spark2": "WITH t AS (SELECT ARRAY(1, 2, 3) AS col) SELECT * FROM t WHERE EXISTS(col, x -> 1 <= x) AND EXISTS(col, x -> 2 = x)", + "spark": "WITH t AS (SELECT ARRAY(1, 2, 3) AS col) SELECT * FROM t WHERE EXISTS(col, x -> 1 <= x) AND EXISTS(col, x -> 2 = x)", + "databricks": "WITH t AS (SELECT ARRAY(1, 2, 3) AS col) SELECT * FROM t WHERE EXISTS(col, x -> 1 <= x) AND EXISTS(col, x -> 2 = x)", + }, + ) def test_ddl(self): # Checks that user-defined types are parsed into DataType instead of Identifier diff --git a/tests/fixtures/identity.sql b/tests/fixtures/identity.sql index bed250225d..33199de377 100644 --- a/tests/fixtures/identity.sql +++ b/tests/fixtures/identity.sql @@ -250,7 +250,6 @@ SELECT LEAD(a, 1) OVER (PARTITION BY a ORDER BY a) AS x SELECT LEAD(a, 1, b) OVER (PARTITION BY a ORDER BY a) AS x SELECT X((a, b) -> a + b, z -> z) AS x SELECT X(a -> a + ("z" - 1)) -SELECT EXISTS(ARRAY(2, 3), x -> x % 2 = 0) SELECT test.* FROM test SELECT a AS b FROM test SELECT "a"."b" FROM "a"