Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(spark)!: Transpile ANY to EXISTS #4305

Merged
merged 2 commits into from
Oct 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions sqlglot/dialects/databricks.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,7 @@ class Generator(Spark.Generator):
[
transforms.eliminate_distinct_on,
transforms.unnest_to_explode,
transforms.any_to_exists,
]
),
exp.JSONExtract: _jsonextract_sql,
Expand Down
7 changes: 7 additions & 0 deletions sqlglot/dialects/hive.py
Original file line number Diff line number Diff line change
Expand Up @@ -556,6 +556,7 @@ class Generator(generator.Generator):
transforms.eliminate_qualify,
transforms.eliminate_distinct_on,
partial(transforms.unnest_to_explode, unnest_using_arrays_zip=False),
transforms.any_to_exists,
]
),
exp.StrPosition: strposition_to_locate_sql,
Expand Down Expand Up @@ -709,3 +710,9 @@ def serdeproperties_sql(self, expression: exp.SerdeProperties) -> str:
exprs = self.expressions(expression, flat=True)

return f"{prefix}SERDEPROPERTIES ({exprs})"

def exists_sql(self, expression: exp.Exists) -> str:
    """Generate SQL for an EXISTS node.

    Hive/Spark support a higher-order function form ``EXISTS(array, x -> pred)``
    in addition to the classic subquery predicate ``EXISTS (SELECT ...)``. The
    optional lambda lives in the node's ``expression`` arg.
    """
    if expression.expression:
        # Lambda present -> render as a function call: EXISTS(arr, x -> pred).
        return self.function_fallback_sql(expression)

    # No lambda -> fall back to the base generator's subquery-predicate rendering.
    return super().exists_sql(expression)
1 change: 1 addition & 0 deletions sqlglot/dialects/spark2.py
Original file line number Diff line number Diff line change
Expand Up @@ -285,6 +285,7 @@ class Generator(Hive.Generator):
transforms.eliminate_qualify,
transforms.eliminate_distinct_on,
transforms.unnest_to_explode,
transforms.any_to_exists,
]
),
exp.StrToDate: _str_to_date,
Expand Down
8 changes: 4 additions & 4 deletions sqlglot/expressions.py
Original file line number Diff line number Diff line change
Expand Up @@ -4605,10 +4605,6 @@ class Any(SubqueryPredicate):
pass


class Exists(SubqueryPredicate):
pass


# Commands to interact with the databases or engines. For most of the command
# expressions we parse whatever comes after the command's name as a string.
class Command(Expression):
Expand Down Expand Up @@ -5583,6 +5579,10 @@ class Extract(Func):
arg_types = {"this": True, "expression": True}


class Exists(Func, SubqueryPredicate):
    """EXISTS in both of its forms: the subquery predicate ``EXISTS (SELECT ...)``
    and the Hive/Spark higher-order function ``EXISTS(array, x -> predicate)``.

    ``this`` holds the subquery or array operand (required); ``expression``
    holds the optional lambda used by the higher-order function form.
    """

    arg_types = {"this": True, "expression": False}


class Timestamp(Func):
arg_types = {"this": False, "zone": False, "with_tz": False}

Expand Down
27 changes: 27 additions & 0 deletions sqlglot/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -914,3 +914,30 @@ def eliminate_join_marks(expression: exp.Expression) -> exp.Expression:
where.pop()

return expression


def any_to_exists(expression: exp.Expression) -> exp.Expression:
    """Transform a comparison against ANY(array) into Spark's EXISTS function.

    For example:
        - Postgres: SELECT * FROM tbl WHERE 5 > ANY(tbl.col)
        - Spark:    SELECT * FROM tbl WHERE EXISTS(tbl.col, x -> x < 5)

    Both ANY and EXISTS accept queries, but currently only array expressions are
    supported for this transformation; ANY over a subquery is left untouched.

    Args:
        expression: The expression to transform. Only `exp.Select` nodes are
            rewritten; anything else is returned unchanged.

    Returns:
        The (possibly mutated in place) expression.
    """
    if isinstance(expression, exp.Select):
        # NOTE: renamed from `any` to avoid shadowing the `any` builtin.
        for any_expr in expression.find_all(exp.Any):
            this = any_expr.this
            if isinstance(this, exp.Query):
                # ANY over a subquery has no array-lambda equivalent; skip it.
                continue

            binop = any_expr.parent
            if isinstance(binop, exp.Binary):
                # Swap the ANY node for the lambda variable `x`, then wrap the
                # resulting comparison in EXISTS(<array>, x -> <comparison>).
                lambda_arg = exp.to_identifier("x")
                any_expr.replace(lambda_arg)
                lambda_expr = exp.Lambda(this=binop.copy(), expressions=[lambda_arg])
                binop.replace(exp.Exists(this=this.unnest(), expression=lambda_expr))

    return expression
11 changes: 11 additions & 0 deletions tests/dialects/test_databricks.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,17 @@ def test_databricks(self):
},
)

self.validate_all(
"SELECT ANY(col) FROM VALUES (TRUE), (FALSE) AS tab(col)",
read={
"databricks": "SELECT ANY(col) FROM VALUES (TRUE), (FALSE) AS tab(col)",
"spark": "SELECT ANY(col) FROM VALUES (TRUE), (FALSE) AS tab(col)",
},
write={
"spark": "SELECT ANY(col) FROM VALUES (TRUE), (FALSE) AS tab(col)",
},
)

# https://docs.databricks.com/sql/language-manual/functions/colonsign.html
def test_json(self):
self.validate_identity("SELECT c1:price, c1:price.foo, c1:price.bar[1]")
Expand Down
19 changes: 19 additions & 0 deletions tests/dialects/test_hive.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from tests.dialects.test_dialect import Validator

from sqlglot import exp


class TestHive(Validator):
dialect = "hive"
Expand Down Expand Up @@ -787,6 +789,23 @@ def test_hive(self):
},
)

self.validate_identity("EXISTS(col, x -> x % 2 = 0)").assert_is(exp.Exists)

self.validate_all(
"SELECT EXISTS(ARRAY(2, 3), x -> x % 2 = 0)",
read={
"hive": "SELECT EXISTS(ARRAY(2, 3), x -> x % 2 = 0)",
"spark2": "SELECT EXISTS(ARRAY(2, 3), x -> x % 2 = 0)",
"spark": "SELECT EXISTS(ARRAY(2, 3), x -> x % 2 = 0)",
"databricks": "SELECT EXISTS(ARRAY(2, 3), x -> x % 2 = 0)",
},
write={
"spark2": "SELECT EXISTS(ARRAY(2, 3), x -> x % 2 = 0)",
"spark": "SELECT EXISTS(ARRAY(2, 3), x -> x % 2 = 0)",
"databricks": "SELECT EXISTS(ARRAY(2, 3), x -> x % 2 = 0)",
},
)

def test_escapes(self) -> None:
self.validate_identity("'\n'", "'\\n'")
self.validate_identity("'\\n'")
Expand Down
10 changes: 10 additions & 0 deletions tests/dialects/test_postgres.py
Original file line number Diff line number Diff line change
Expand Up @@ -804,6 +804,16 @@ def test_postgres(self):
"duckdb": """SELECT JSON_EXISTS('{"a": [1,2,3]}', '$.a')""",
},
)
self.validate_all(
"WITH t AS (SELECT ARRAY[1, 2, 3] AS col) SELECT * FROM t WHERE 1 <= ANY(col) AND 2 = ANY(col)",
write={
"postgres": "WITH t AS (SELECT ARRAY[1, 2, 3] AS col) SELECT * FROM t WHERE 1 <= ANY(col) AND 2 = ANY(col)",
"hive": "WITH t AS (SELECT ARRAY(1, 2, 3) AS col) SELECT * FROM t WHERE EXISTS(col, x -> 1 <= x) AND EXISTS(col, x -> 2 = x)",
"spark2": "WITH t AS (SELECT ARRAY(1, 2, 3) AS col) SELECT * FROM t WHERE EXISTS(col, x -> 1 <= x) AND EXISTS(col, x -> 2 = x)",
"spark": "WITH t AS (SELECT ARRAY(1, 2, 3) AS col) SELECT * FROM t WHERE EXISTS(col, x -> 1 <= x) AND EXISTS(col, x -> 2 = x)",
"databricks": "WITH t AS (SELECT ARRAY(1, 2, 3) AS col) SELECT * FROM t WHERE EXISTS(col, x -> 1 <= x) AND EXISTS(col, x -> 2 = x)",
},
)

def test_ddl(self):
# Checks that user-defined types are parsed into DataType instead of Identifier
Expand Down
1 change: 0 additions & 1 deletion tests/fixtures/identity.sql
Original file line number Diff line number Diff line change
Expand Up @@ -250,7 +250,6 @@ SELECT LEAD(a, 1) OVER (PARTITION BY a ORDER BY a) AS x
SELECT LEAD(a, 1, b) OVER (PARTITION BY a ORDER BY a) AS x
SELECT X((a, b) -> a + b, z -> z) AS x
SELECT X(a -> a + ("z" - 1))
SELECT EXISTS(ARRAY(2, 3), x -> x % 2 = 0)
SELECT test.* FROM test
SELECT a AS b FROM test
SELECT "a"."b" FROM "a"
Expand Down
Loading