From 296104950f2e679aea37810a48eb490e170518d3 Mon Sep 17 00:00:00 2001 From: Jo <46752250+georgesittas@users.noreply.github.com> Date: Wed, 11 Sep 2024 19:40:34 +0300 Subject: [PATCH] Refactor: implement decorator to easily mark args as unsupported (#4111) * Refactor: implement decorator to easily mark args as unsupported * Use args.get(arg_name) instead of arg_name in args to avoid None issues --- sqlglot/dialects/dialect.py | 19 ++++----------- sqlglot/dialects/duckdb.py | 3 +-- sqlglot/dialects/hive.py | 3 +-- sqlglot/dialects/snowflake.py | 6 +---- sqlglot/generator.py | 44 +++++++++++++++++++++++++++++++---- tests/dialects/test_duckdb.py | 4 ++-- 6 files changed, 50 insertions(+), 29 deletions(-) diff --git a/sqlglot/dialects/dialect.py b/sqlglot/dialects/dialect.py index 21e52c9198..1d271f306d 100644 --- a/sqlglot/dialects/dialect.py +++ b/sqlglot/dialects/dialect.py @@ -7,7 +7,7 @@ from sqlglot import exp from sqlglot.errors import ParseError -from sqlglot.generator import Generator +from sqlglot.generator import Generator, unsupported_args from sqlglot.helper import AutoName, flatten, is_int, seq_get, subclasses from sqlglot.jsonpath import JSONPathTokenizer, parse as parse_json_path from sqlglot.parser import Parser @@ -957,9 +957,8 @@ def rename_func(name: str) -> t.Callable[[Generator, exp.Expression], str]: return lambda self, expression: self.func(name, *flatten(expression.args.values())) +@unsupported_args("accuracy") def approx_count_distinct_sql(self: Generator, expression: exp.ApproxDistinct) -> str: - if expression.args.get("accuracy"): - self.unsupported("APPROX_COUNT_DISTINCT does not support accuracy") return self.func("APPROX_COUNT_DISTINCT", expression.this) @@ -1359,11 +1358,8 @@ def concat_ws_to_dpipe_sql(self: Generator, expression: exp.ConcatWs) -> str: ) +@unsupported_args("position", "occurrence", "parameters") def regexp_extract_sql(self: Generator, expression: exp.RegexpExtract) -> str: - bad_args = list(filter(expression.args.get, ("position", "occurrence", "parameters"))) - if bad_args: - self.unsupported(f"REGEXP_EXTRACT does not support the following arg(s): {bad_args}") - group = expression.args.get("group") # Do not render group if it's the default value for this dialect @@ -1373,11 +1369,8 @@ def regexp_extract_sql(self: Generator, expression: exp.RegexpExtract) -> str: return self.func("REGEXP_EXTRACT", expression.this, expression.expression, group) +@unsupported_args("position", "occurrence", "modifiers") def regexp_replace_sql(self: Generator, expression: exp.RegexpReplace) -> str: - bad_args = list(filter(expression.args.get, ("position", "occurrence", "modifiers"))) - if bad_args: - self.unsupported(f"REGEXP_REPLACE does not support the following arg(s): {bad_args}") - return self.func( "REGEXP_REPLACE", expression.this, expression.expression, expression.args["replacement"] ) @@ -1445,10 +1438,8 @@ def generatedasidentitycolumnconstraint_sql( def arg_max_or_min_no_count(name: str) -> t.Callable[[Generator, exp.ArgMax | exp.ArgMin], str]: + @unsupported_args("count") def _arg_max_or_min_sql(self: Generator, expression: exp.ArgMax | exp.ArgMin) -> str: - if expression.args.get("count"): - self.unsupported(f"Only two arguments are supported in function {name}.") - return self.func(name, expression.this, expression.expression) return _arg_max_or_min_sql diff --git a/sqlglot/dialects/duckdb.py b/sqlglot/dialects/duckdb.py index e8f5837680..3932c11bc4 100644 --- a/sqlglot/dialects/duckdb.py +++ b/sqlglot/dialects/duckdb.py @@ -102,9 +102,8 @@ def _timediff_sql(self: DuckDB.Generator, expression: exp.TimeDiff) -> str: return self.func("DATE_DIFF", unit_to_str(expression), expr, this) +@generator.unsupported_args(("expression", "DuckDB's ARRAY_SORT does not support a comparator.")) def _array_sort_sql(self: DuckDB.Generator, expression: exp.ArraySort) -> str: - if expression.expression: - self.unsupported("DuckDB ARRAY_SORT does not support a comparator") return self.func("ARRAY_SORT", expression.this) diff --git a/sqlglot/dialects/hive.py b/sqlglot/dialects/hive.py index ad6e92586c..d59e3a84c2 100644 --- a/sqlglot/dialects/hive.py +++ b/sqlglot/dialects/hive.py @@ -131,9 +131,8 @@ def _json_format_sql(self: Hive.Generator, expression: exp.JSONFormat) -> str: return self.func("TO_JSON", this, expression.args.get("options")) +@generator.unsupported_args(("expression", "Hive's SORT_ARRAY does not support a comparator.")) def _array_sort_sql(self: Hive.Generator, expression: exp.ArraySort) -> str: - if expression.expression: - self.unsupported("Hive SORT_ARRAY does not support a comparator") return self.func("SORT_ARRAY", expression.this) diff --git a/sqlglot/dialects/snowflake.py b/sqlglot/dialects/snowflake.py index aec2e8aaf7..bc1a3154de 100644 --- a/sqlglot/dialects/snowflake.py +++ b/sqlglot/dialects/snowflake.py @@ -1059,12 +1059,8 @@ def struct_sql(self, expression: exp.Struct) -> str: return self.func("OBJECT_CONSTRUCT", *flatten(zip(keys, values))) + @generator.unsupported_args("weight", "accuracy") def approxquantile_sql(self, expression: exp.ApproxQuantile) -> str: - if expression.args.get("weight") or expression.args.get("accuracy"): - self.unsupported( - "APPROX_PERCENTILE with weight and/or accuracy arguments are not supported in Snowflake" - ) - return self.func("APPROX_PERCENTILE", expression.this, expression.args.get("quantile")) def alterset_sql(self, expression: exp.AlterSet) -> str: diff --git a/sqlglot/generator.py b/sqlglot/generator.py index 26c20c80ad..307138cd4d 100644 --- a/sqlglot/generator.py +++ b/sqlglot/generator.py @@ -4,7 +4,7 @@ import re import typing as t from collections import defaultdict -from functools import reduce +from functools import reduce, wraps from sqlglot import exp from sqlglot.errors import ErrorLevel, UnsupportedError, concat_messages @@ -17,9 +17,47 @@ from sqlglot._typing import E from sqlglot.dialects.dialect import DialectType + G = t.TypeVar("G", bound="Generator") + GeneratorMethod = t.Callable[[G, E], str] + logger = logging.getLogger("sqlglot") ESCAPED_UNICODE_RE = re.compile(r"\\(\d+)") +UNSUPPORTED_TEMPLATE = "Argument '{}' is not supported for expression '{}' when targeting {}." + + +def unsupported_args( + *args: t.Union[str, t.Tuple[str, str]], +) -> t.Callable[[GeneratorMethod], GeneratorMethod]: + """ + Decorator that can be used to mark certain args of an `Expression` subclass as unsupported. + It expects a sequence of argument names or pairs of the form (argument_name, diagnostic_msg). + """ + diagnostic_by_arg: t.Dict[str, t.Optional[str]] = {} + for arg in args: + if isinstance(arg, str): + diagnostic_by_arg[arg] = None + else: + diagnostic_by_arg[arg[0]] = arg[1] + + def decorator(func: GeneratorMethod) -> GeneratorMethod: + @wraps(func) + def _func(generator: G, expression: E) -> str: + expression_name = expression.__class__.__name__ + dialect_name = generator.dialect.__class__.__name__ + + for arg_name, diagnostic in diagnostic_by_arg.items(): + if expression.args.get(arg_name): + diagnostic = diagnostic or UNSUPPORTED_TEMPLATE.format( + arg_name, expression_name, dialect_name + ) + generator.unsupported(diagnostic) + + return func(generator, expression) + + return _func + + return decorator class _Generator(type): @@ -3594,10 +3632,8 @@ def merge_sql(self, expression: exp.Merge) -> str: f"MERGE INTO {this}{table_alias}{sep}{using}{sep}{on}{sep}{expressions}{sep}{returning}", ) + @unsupported_args("format") def tochar_sql(self, expression: exp.ToChar) -> str: - if expression.args.get("format"): - self.unsupported("Format argument unsupported for TO_CHAR/TO_VARCHAR function") - return self.sql(exp.cast(expression.this, exp.DataType.Type.TEXT)) def tonumber_sql(self, expression: exp.ToNumber) -> str: diff --git a/tests/dialects/test_duckdb.py b/tests/dialects/test_duckdb.py index fbf81dd8fb..2cc11d6c2b 100644 --- a/tests/dialects/test_duckdb.py +++ b/tests/dialects/test_duckdb.py @@ -822,8 +822,8 @@ def test_duckdb(self): "SELECT COALESCE(*COLUMNS(*)) FROM (SELECT NULL, 2, 3) AS t(a, b, c)" ) self.validate_identity( - "SELECT id, STRUCT_PACK(*COLUMNS('m\d')) AS measurements FROM many_measurements", - """SELECT id, {'_0': *COLUMNS('m\d')} AS measurements FROM many_measurements""", + "SELECT id, STRUCT_PACK(*COLUMNS('m\\d')) AS measurements FROM many_measurements", + """SELECT id, {'_0': *COLUMNS('m\\d')} AS measurements FROM many_measurements""", ) def test_array_index(self):