Skip to content

Commit

Permalink
Refactor: implement decorator to easily mark args as unsupported
Browse files Browse the repository at this point in the history
  • Loading branch information
georgesittas committed Sep 11, 2024
1 parent 67a9ad8 commit cf31875
Show file tree
Hide file tree
Showing 6 changed files with 50 additions and 29 deletions.
19 changes: 5 additions & 14 deletions sqlglot/dialects/dialect.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

from sqlglot import exp
from sqlglot.errors import ParseError
from sqlglot.generator import Generator
from sqlglot.generator import Generator, unsupported_args
from sqlglot.helper import AutoName, flatten, is_int, seq_get, subclasses
from sqlglot.jsonpath import JSONPathTokenizer, parse as parse_json_path
from sqlglot.parser import Parser
Expand Down Expand Up @@ -957,9 +957,8 @@ def rename_func(name: str) -> t.Callable[[Generator, exp.Expression], str]:
return lambda self, expression: self.func(name, *flatten(expression.args.values()))


@unsupported_args("accuracy")
def approx_count_distinct_sql(self: Generator, expression: exp.ApproxDistinct) -> str:
if expression.args.get("accuracy"):
self.unsupported("APPROX_COUNT_DISTINCT does not support accuracy")
return self.func("APPROX_COUNT_DISTINCT", expression.this)


Expand Down Expand Up @@ -1359,11 +1358,8 @@ def concat_ws_to_dpipe_sql(self: Generator, expression: exp.ConcatWs) -> str:
)


@unsupported_args("position", "occurrence", "parameters")
def regexp_extract_sql(self: Generator, expression: exp.RegexpExtract) -> str:
bad_args = list(filter(expression.args.get, ("position", "occurrence", "parameters")))
if bad_args:
self.unsupported(f"REGEXP_EXTRACT does not support the following arg(s): {bad_args}")

group = expression.args.get("group")

# Do not render group if it's the default value for this dialect
Expand All @@ -1373,11 +1369,8 @@ def regexp_extract_sql(self: Generator, expression: exp.RegexpExtract) -> str:
return self.func("REGEXP_EXTRACT", expression.this, expression.expression, group)


@unsupported_args("position", "occurrence", "modifiers")
def regexp_replace_sql(self: Generator, expression: exp.RegexpReplace) -> str:
bad_args = list(filter(expression.args.get, ("position", "occurrence", "modifiers")))
if bad_args:
self.unsupported(f"REGEXP_REPLACE does not support the following arg(s): {bad_args}")

return self.func(
"REGEXP_REPLACE", expression.this, expression.expression, expression.args["replacement"]
)
Expand Down Expand Up @@ -1445,10 +1438,8 @@ def generatedasidentitycolumnconstraint_sql(


def arg_max_or_min_no_count(name: str) -> t.Callable[[Generator, exp.ArgMax | exp.ArgMin], str]:
@unsupported_args("count")
def _arg_max_or_min_sql(self: Generator, expression: exp.ArgMax | exp.ArgMin) -> str:
if expression.args.get("count"):
self.unsupported(f"Only two arguments are supported in function {name}.")

return self.func(name, expression.this, expression.expression)

return _arg_max_or_min_sql
Expand Down
3 changes: 1 addition & 2 deletions sqlglot/dialects/duckdb.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,9 +102,8 @@ def _timediff_sql(self: DuckDB.Generator, expression: exp.TimeDiff) -> str:
return self.func("DATE_DIFF", unit_to_str(expression), expr, this)


@generator.unsupported_args(("expression", "DuckDB's ARRAY_SORT does not support a comparator."))
def _array_sort_sql(self: DuckDB.Generator, expression: exp.ArraySort) -> str:
if expression.expression:
self.unsupported("DuckDB ARRAY_SORT does not support a comparator")
return self.func("ARRAY_SORT", expression.this)


Expand Down
3 changes: 1 addition & 2 deletions sqlglot/dialects/hive.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,9 +131,8 @@ def _json_format_sql(self: Hive.Generator, expression: exp.JSONFormat) -> str:
return self.func("TO_JSON", this, expression.args.get("options"))


@generator.unsupported_args(("expression", "Hive's SORT_ARRAY does not support a comparator."))
def _array_sort_sql(self: Hive.Generator, expression: exp.ArraySort) -> str:
if expression.expression:
self.unsupported("Hive SORT_ARRAY does not support a comparator")
return self.func("SORT_ARRAY", expression.this)


Expand Down
6 changes: 1 addition & 5 deletions sqlglot/dialects/snowflake.py
Original file line number Diff line number Diff line change
Expand Up @@ -1059,12 +1059,8 @@ def struct_sql(self, expression: exp.Struct) -> str:

return self.func("OBJECT_CONSTRUCT", *flatten(zip(keys, values)))

@generator.unsupported_args("weight", "accuracy")
def approxquantile_sql(self, expression: exp.ApproxQuantile) -> str:
if expression.args.get("weight") or expression.args.get("accuracy"):
self.unsupported(
"APPROX_PERCENTILE with weight and/or accuracy arguments are not supported in Snowflake"
)

return self.func("APPROX_PERCENTILE", expression.this, expression.args.get("quantile"))

def alterset_sql(self, expression: exp.AlterSet) -> str:
Expand Down
44 changes: 40 additions & 4 deletions sqlglot/generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import re
import typing as t
from collections import defaultdict
from functools import reduce
from functools import reduce, wraps

from sqlglot import exp
from sqlglot.errors import ErrorLevel, UnsupportedError, concat_messages
Expand All @@ -17,9 +17,47 @@
from sqlglot._typing import E
from sqlglot.dialects.dialect import DialectType

G = t.TypeVar("G", bound="Generator")
GeneratorMethod = t.Callable[[G, E], str]

logger = logging.getLogger("sqlglot")

ESCAPED_UNICODE_RE = re.compile(r"\\(\d+)")
UNSUPPORTED_TEMPLATE = "Argument '{}' is not supported for expression '{}' when targeting {}."


def unsupported_args(
*args: t.Union[str, t.Tuple[str, str]],
) -> t.Callable[[GeneratorMethod], GeneratorMethod]:
"""
Decorator that can be used to mark certain args of an `Expression` subclass as unsupported.
It expects a sequence of argument names or pairs of the form (argument_name, diagnostic_msg).
"""
diagnostic_by_arg: t.Dict[str, t.Optional[str]] = {}
for arg in args:
if isinstance(arg, str):
diagnostic_by_arg[arg] = None
else:
diagnostic_by_arg[arg[0]] = arg[1]

def decorator(func: GeneratorMethod) -> GeneratorMethod:
@wraps(func)
def _func(generator: G, expression: E) -> str:
expression_name = expression.__class__.__name__
dialect_name = generator.dialect.__class__.__name__

for arg_name, diagnostic in diagnostic_by_arg.items():
if arg_name in expression.args:
diagnostic = diagnostic or UNSUPPORTED_TEMPLATE.format(
arg_name, expression_name, dialect_name
)
generator.unsupported(diagnostic)

return func(generator, expression)

return _func

return decorator


class _Generator(type):
Expand Down Expand Up @@ -3594,10 +3632,8 @@ def merge_sql(self, expression: exp.Merge) -> str:
f"MERGE INTO {this}{table_alias}{sep}{using}{sep}{on}{sep}{expressions}{sep}{returning}",
)

@unsupported_args("format")
def tochar_sql(self, expression: exp.ToChar) -> str:
if expression.args.get("format"):
self.unsupported("Format argument unsupported for TO_CHAR/TO_VARCHAR function")

return self.sql(exp.cast(expression.this, exp.DataType.Type.TEXT))

def tonumber_sql(self, expression: exp.ToNumber) -> str:
Expand Down
4 changes: 2 additions & 2 deletions tests/dialects/test_duckdb.py
Original file line number Diff line number Diff line change
Expand Up @@ -822,8 +822,8 @@ def test_duckdb(self):
"SELECT COALESCE(*COLUMNS(*)) FROM (SELECT NULL, 2, 3) AS t(a, b, c)"
)
self.validate_identity(
"SELECT id, STRUCT_PACK(*COLUMNS('m\d')) AS measurements FROM many_measurements",
"""SELECT id, {'_0': *COLUMNS('m\d')} AS measurements FROM many_measurements""",
"SELECT id, STRUCT_PACK(*COLUMNS('m\\d')) AS measurements FROM many_measurements",
"""SELECT id, {'_0': *COLUMNS('m\\d')} AS measurements FROM many_measurements""",
)

def test_array_index(self):
Expand Down

0 comments on commit cf31875

Please sign in to comment.