Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(duckdb)!: Remove extra MAP bracket and ARRAY wrap #4712

Merged
merged 3 commits into from
Feb 11, 2025
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 18 additions & 0 deletions sqlglot/dialects/dialect.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
import importlib
import logging
import typing as t
import sys

from enum import Enum, auto
from functools import reduce

Expand Down Expand Up @@ -103,6 +105,18 @@ class NormalizationStrategy(str, AutoName):
"""Always case-insensitive, regardless of quotes."""


class Version(int):
def __new__(cls, version_str: t.Optional[str], *args, **kwargs):
if version_str:
parts = version_str.split(".")
parts.extend(["0"] * (3 - len(parts)))
v = int("".join([p.zfill(3) for p in parts]))
else:
v = sys.maxsize

return super(Version, cls).__new__(cls, v)


class _Dialect(type):
_classes: t.Dict[str, t.Type[Dialect]] = {}

Expand Down Expand Up @@ -1002,6 +1016,10 @@ def parser(self, **opts) -> Parser:
def generator(self, **opts) -> Generator:
return self.generator_class(dialect=self, **opts)

@property
def version(self) -> Version:
return Version(self.settings.get("version", None))


DialectType = t.Union[str, Dialect, t.Type[Dialect], None]

Expand Down
9 changes: 8 additions & 1 deletion sqlglot/dialects/duckdb.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
Dialect,
JSON_EXTRACT_TYPE,
NormalizationStrategy,
Version,
approx_count_distinct_sql,
arrow_json_extract_sql,
binary_from_function,
Expand Down Expand Up @@ -470,7 +471,9 @@ def _parse_bracket(
self, this: t.Optional[exp.Expression] = None
) -> t.Optional[exp.Expression]:
bracket = super()._parse_bracket(this)
if isinstance(bracket, exp.Bracket):

if self.dialect.version < Version("1.2.0") and isinstance(bracket, exp.Bracket):
# https://duckdb.org/2025/02/05/announcing-duckdb-120.html#breaking-changes
bracket.set("returns_list_for_maps", True)

return bracket
Expand Down Expand Up @@ -895,6 +898,10 @@ def generateseries_sql(self, expression: exp.GenerateSeries) -> str:
return self.function_fallback_sql(expression)

def bracket_sql(self, expression: exp.Bracket) -> str:
if self.dialect.version >= Version("1.2"):
return super().bracket_sql(expression)

# https://duckdb.org/2025/02/05/announcing-duckdb-120.html#breaking-changes
this = expression.this
if isinstance(this, exp.Array):
this.replace(exp.paren(this))
Expand Down
12 changes: 12 additions & 0 deletions tests/dialects/test_dialect.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,18 @@ def test_compare_dialects(self):
self.assertFalse(snowflake_class in {"bigquery", "redshift"})
self.assertFalse(snowflake_object in {"bigquery", "redshift"})

def test_compare_dialect_versions(self):
ddb_v1 = Dialect.get_or_raise("duckdb, version=1.0")
ddb_v1_2 = Dialect.get_or_raise("duckdb, foo=bar, version=1.0")
ddb_v2 = Dialect.get_or_raise("duckdb, version=2.2.4")
ddb_latest = Dialect.get_or_raise("duckdb")

self.assertTrue(ddb_latest.version > ddb_v2.version)
self.assertTrue(ddb_v1.version < ddb_v2.version)

self.assertTrue(ddb_v1.version == ddb_v1_2.version)
self.assertTrue(ddb_latest.version == Dialect.get_or_raise("duckdb").version)

def test_cast(self):
self.validate_all(
"CAST(a AS TEXT)",
Expand Down
6 changes: 3 additions & 3 deletions tests/dialects/test_duckdb.py
Original file line number Diff line number Diff line change
Expand Up @@ -929,15 +929,15 @@ def test_array_index(self):
)
self.validate_identity(
"""SELECT LIST_VALUE(1)[i]""",
"""SELECT ([1])[i]""",
"""SELECT [1][i]""",
)
self.validate_identity(
"""{'x': LIST_VALUE(1)[i]}""",
"""{'x': ([1])[i]}""",
"""{'x': [1][i]}""",
)
self.validate_identity(
"""SELECT LIST_APPLY(RANGE(1, 4), i -> {'f1': LIST_VALUE(1, 2, 3)[i], 'f2': LIST_VALUE(1, 2, 3)[i]})""",
"""SELECT LIST_APPLY(RANGE(1, 4), i -> {'f1': ([1, 2, 3])[i], 'f2': ([1, 2, 3])[i]})""",
"""SELECT LIST_APPLY(RANGE(1, 4), i -> {'f1': [1, 2, 3][i], 'f2': [1, 2, 3][i]})""",
)

self.assertEqual(
Expand Down
6 changes: 4 additions & 2 deletions tests/dialects/test_spark.py
Original file line number Diff line number Diff line change
Expand Up @@ -312,7 +312,8 @@ def test_spark(self):
write={
"databricks": "SELECT TRY_ELEMENT_AT(ARRAY(1, 2, 3), 2)",
"spark": "SELECT TRY_ELEMENT_AT(ARRAY(1, 2, 3), 2)",
"duckdb": "SELECT ([1, 2, 3])[2]",
"duckdb": "SELECT [1, 2, 3][2]",
"duckdb, version=1.1.0": "SELECT ([1, 2, 3])[2]",
"presto": "SELECT ELEMENT_AT(ARRAY[1, 2, 3], 2)",
},
)
Expand Down Expand Up @@ -352,7 +353,8 @@ def test_spark(self):
},
write={
"databricks": "SELECT TRY_ELEMENT_AT(MAP(1, 'a', 2, 'b'), 2)",
"duckdb": "SELECT (MAP([1, 2], ['a', 'b'])[2])[1]",
"duckdb": "SELECT MAP([1, 2], ['a', 'b'])[2]",
"duckdb, version=1.1.0": "SELECT (MAP([1, 2], ['a', 'b'])[2])[1]",
"spark": "SELECT TRY_ELEMENT_AT(MAP(1, 'a', 2, 'b'), 2)",
},
)
Expand Down
2 changes: 1 addition & 1 deletion tests/dialects/test_tsql.py
Original file line number Diff line number Diff line change
Expand Up @@ -457,7 +457,7 @@ def test_tsql(self):
parse_one("SELECT begin", read="tsql")

self.validate_identity("CREATE PROCEDURE test(@v1 INTEGER = 1, @v2 CHAR(1) = 'c')")
self.validate_identity("DECLARE @v1 AS INTEGER = 1, @v2 AS CHAR(1) = 'c')")
self.validate_identity("DECLARE @v1 AS INTEGER = 1, @v2 AS CHAR(1) = 'c'")

for output in ("OUT", "OUTPUT", "READ_ONLY"):
self.validate_identity(
Expand Down