From 5ff7c226213705e0d7c92e6c39e06ee7e0af914c Mon Sep 17 00:00:00 2001 From: Erin Drummond Date: Tue, 10 Sep 2024 03:03:59 +0000 Subject: [PATCH 1/3] Feat(athena): Improve DDL query support --- sqlglot/dialects/athena.py | 142 +++++++++++++++++++++++++++------- tests/dialects/test_athena.py | 82 ++++++++++++++++++-- 2 files changed, 192 insertions(+), 32 deletions(-) diff --git a/sqlglot/dialects/athena.py b/sqlglot/dialects/athena.py index 513f309fc9..483a9fbcc7 100644 --- a/sqlglot/dialects/athena.py +++ b/sqlglot/dialects/athena.py @@ -1,31 +1,132 @@ from __future__ import annotations +import typing as t + from sqlglot import exp from sqlglot.dialects.trino import Trino -from sqlglot.tokens import TokenType +from sqlglot.dialects.hive import Hive +from sqlglot.tokens import Token, TokenType + + +def _parse_as_hive(raw_tokens: t.List[Token]) -> bool: + if len(raw_tokens) > 0: + first_token = raw_tokens[0] + if first_token.token_type == TokenType.CREATE: + # CREATE is Hive (except for CREATE VIEW and CREATE TABLE... AS SELECT) + return not any(t.token_type in (TokenType.VIEW, TokenType.SELECT) for t in raw_tokens) + + # ALTER and DROP are Hive + return first_token.token_type in (TokenType.ALTER, TokenType.DROP) + return False + + +def _generate_as_hive(expression: exp.Expression) -> bool: + if isinstance(expression, exp.Create): + if expression.kind == "TABLE": + properties: t.Optional[exp.Properties] + if properties := expression.args.get("properties"): + if properties.find(exp.ExternalProperty): + return True # CREATE EXTERNAL TABLE is Hive + + if not isinstance(expression.expression, exp.Select): + return True # any CREATE TABLE other than CREATE TABLE AS SELECT is Hive + else: + return expression.kind != "VIEW" # CREATE VIEW is never Hive but CREATE SCHEMA etc is + + elif isinstance(expression, exp.Alter) or isinstance(expression, exp.Drop): + return True # all ALTER and DROP statements are Hive + + return False class Athena(Trino): + """ + Over the years, it looks like AWS has taken various execution engines, bolted on AWS-specific modifications and then + built the Athena service around them. + + Thus, Athena is not simply hosted Trino, it's more like a router that routes SQL queries to an execution engine depending + on the query type. + + As at 2024-09-10, assuming your Athena workgroup is configured to use "Athena engine version 3", the following engines exist: + + Hive: + - Accepts mostly the same syntax as Hadoop / Hive + - Uses backticks to quote identifiers + - Has a distinctive DDL syntax (around things like setting table properties, storage locations etc) that is different from Trino + - Used for *most* DDL, with some exceptions that get routed to the Trino engine instead: + - CREATE [EXTERNAL] TABLE (without AS SELECT) + - ALTER + - DROP + + Trino: + - Uses double quotes to quote identifiers + - Used for DDL operations that involve SELECT queries, eg: + - CREATE VIEW + - CREATE TABLE... AS SELECT + - Used for DML operations + - SELECT, INSERT, UPDATE, DELETE, MERGE + + The SQLGlot Athena dialect tries to identify which engine a query would be routed to and then uses the parser / generator for that engine + rather than trying to create a universal syntax that can handle both types. + """ + class Tokenizer(Trino.Tokenizer): + """ + The Tokenizer is flexible enough to tokenize queries across both the Hive and Trino engines + """ + IDENTIFIERS = ['"', "`"] KEYWORDS = { + **Hive.Tokenizer.KEYWORDS, **Trino.Tokenizer.KEYWORDS, "UNLOAD": TokenType.COMMAND, } + class HiveParser(Hive.Parser): + """ + Parse queries for the Athena Hive execution engine + """ + + pass + class Parser(Trino.Parser): + """ + Parse queries for the Athena Trino execution engine + """ + STATEMENT_PARSERS = { **Trino.Parser.STATEMENT_PARSERS, TokenType.USING: lambda self: self._parse_as_command(self._prev), } - class Generator(Trino.Generator): - WITH_PROPERTIES_PREFIX = "TBLPROPERTIES" + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) - PROPERTIES_LOCATION = { - **Trino.Generator.PROPERTIES_LOCATION, - exp.LocationProperty: exp.Properties.Location.POST_SCHEMA, - } + self._hive_parser = Athena.HiveParser(*args, **kwargs) + + def parse( + self, raw_tokens: t.List[Token], sql: t.Optional[str] = None + ) -> t.List[t.Optional[exp.Expression]]: + if _parse_as_hive(raw_tokens): + return self._hive_parser.parse(raw_tokens, sql) + + return super().parse(raw_tokens, sql) + + class HiveGenerator(Hive.Generator): + """ + Generating queries for the Athena Hive execution engine + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + self._identifier_start = "`" + self._identifier_end = "`" + + class Generator(Trino.Generator): + """ + Generate queries for the Athena Trino execution engine + """ TYPE_MAPPING = { **Trino.Generator.TYPE_MAPPING, @@ -37,23 +138,12 @@ class Generator(Trino.Generator): exp.FileFormatProperty: lambda self, e: f"'FORMAT'={self.sql(e, 'this')}", } + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self._hive_generator = Athena.HiveGenerator(*args, **kwargs) + def generate(self, expression: exp.Expression, copy: bool = True) -> str: - if isinstance(expression, exp.DDL) or isinstance(expression, exp.Drop): - # Athena DDL uses backticks for quoting, unlike Athena DML which uses double quotes - # ...unless the DDL is CREATE VIEW, then it uses DML quoting, I guess because the view is based on a SELECT query - # ref: https://docs.aws.amazon.com/athena/latest/ug/reserved-words.html - # ref: https://docs.aws.amazon.com/athena/latest/ug/tables-databases-columns-names.html#table-names-that-include-numbers - if not (isinstance(expression, exp.Create) and expression.kind == "VIEW"): - self._identifier_start = "`" - self._identifier_end = "`" - - try: - return super().generate(expression, copy) - finally: - self._identifier_start = self.dialect.IDENTIFIER_START - self._identifier_end = self.dialect.IDENTIFIER_END - - def property_sql(self, expression: exp.Property) -> str: - return ( - f"{self.property_name(expression, string_key=True)}={self.sql(expression, 'value')}" - ) + if _generate_as_hive(expression): + return self._hive_generator.generate(expression, copy) + + return super().generate(expression, copy) diff --git a/tests/dialects/test_athena.py b/tests/dialects/test_athena.py index 6ec870be05..182f0bcea2 100644 --- a/tests/dialects/test_athena.py +++ b/tests/dialects/test_athena.py @@ -24,16 +24,54 @@ def test_athena(self): check_command_warning=True, ) + def test_ddl(self): + # Hive-like, https://docs.aws.amazon.com/athena/latest/ug/create-table.html + self.validate_identity("CREATE EXTERNAL TABLE foo (id INT) COMMENT 'test comment'") + self.validate_identity( + "CREATE EXTERNAL TABLE foo (id INT, val STRING) CLUSTERED BY (id, val) INTO 10 BUCKETS" + ) + self.validate_identity( + "CREATE EXTERNAL TABLE foo (id INT, val STRING) STORED AS PARQUET LOCATION 's3://foo' TBLPROPERTIES ('has_encryped_data'='true', 'classification'='test')" + ) + self.validate_identity( + "CREATE EXTERNAL TABLE IF NOT EXISTS foo (a INT, b STRING) ROW FORMAT SERDE 'org.openx.data.jsonserde.JsonSerDe' WITH SERDEPROPERTIES ('case.insensitive'='FALSE') LOCATION 's3://table/path'" + ) + self.validate_identity( + """CREATE EXTERNAL TABLE x (y INT) ROW FORMAT SERDE 'serde' ROW FORMAT DELIMITED FIELDS TERMINATED BY '1' WITH SERDEPROPERTIES ('input.regex'='')""", + ) + self.validate_identity( + """CREATE EXTERNAL TABLE `my_table` (`a7` ARRAY) ROW FORMAT SERDE 'a' STORED AS INPUTFORMAT 'b' OUTPUTFORMAT 'c' LOCATION 'd' TBLPROPERTIES ('e'='f')""" + ) + + # Iceberg, https://docs.aws.amazon.com/athena/latest/ug/querying-iceberg-creating-tables.html + self.validate_identity( + "CREATE TABLE iceberg_table (`id` BIGINT, `data` STRING, category STRING) PARTITIONED BY (category, BUCKET(16, id)) LOCATION 's3://amzn-s3-demo-bucket/your-folder/' TBLPROPERTIES ('table_type'='ICEBERG', 'write_compression'='snappy')" + ) + + # CTAS goes to the Trino engine, where the table properties cant be encased in single quotes like they can for Hive + # ref: https://docs.aws.amazon.com/athena/latest/ug/create-table-as.html#ctas-table-properties + self.validate_identity( + "CREATE TABLE foo WITH (table_type='ICEBERG', external_location='s3://foo/') AS SELECT * FROM a" + ) + def test_ddl_quoting(self): self.validate_identity("CREATE SCHEMA `foo`") self.validate_identity("CREATE SCHEMA foo") self.validate_identity("CREATE SCHEMA foo", write_sql="CREATE SCHEMA `foo`", identify=True) - self.validate_identity("CREATE EXTERNAL TABLE `foo` (`id` INTEGER) LOCATION 's3://foo/'") - self.validate_identity("CREATE EXTERNAL TABLE foo (id INTEGER) LOCATION 's3://foo/'") + self.validate_identity("CREATE EXTERNAL TABLE `foo` (`id` INT) LOCATION 's3://foo/'") + self.validate_identity("CREATE EXTERNAL TABLE foo (id INT) LOCATION 's3://foo/'") + self.validate_identity( + "CREATE EXTERNAL TABLE foo (id INT) LOCATION 's3://foo/'", + write_sql="CREATE EXTERNAL TABLE `foo` (`id` INT) LOCATION 's3://foo/'", + identify=True, + ) + + self.validate_identity("CREATE TABLE foo AS SELECT * FROM a") + self.validate_identity('CREATE TABLE "foo" AS SELECT * FROM "a"') self.validate_identity( - "CREATE EXTERNAL TABLE foo (id INTEGER) LOCATION 's3://foo/'", - write_sql="CREATE EXTERNAL TABLE `foo` (`id` INTEGER) LOCATION 's3://foo/'", + "CREATE TABLE `foo` AS SELECT * FROM `a`", + write_sql='CREATE TABLE "foo" AS SELECT * FROM "a"', identify=True, ) @@ -52,8 +90,8 @@ def test_ddl_quoting(self): # As a side effect of being able to parse both quote types, we can also fix the quoting on incorrectly quoted source queries self.validate_identity('CREATE SCHEMA "foo"', write_sql="CREATE SCHEMA `foo`") self.validate_identity( - 'CREATE EXTERNAL TABLE "foo" ("id" INTEGER) LOCATION \'s3://foo/\'', - write_sql="CREATE EXTERNAL TABLE `foo` (`id` INTEGER) LOCATION 's3://foo/'", + 'CREATE EXTERNAL TABLE "foo" ("id" INT) LOCATION \'s3://foo/\'', + write_sql="CREATE EXTERNAL TABLE `foo` (`id` INT) LOCATION 's3://foo/'", ) self.validate_identity('DROP TABLE "foo"', write_sql="DROP TABLE `foo`") self.validate_identity( @@ -61,6 +99,14 @@ def test_ddl_quoting(self): write_sql='CREATE VIEW "foo" AS SELECT "id" FROM "tbl"', ) + self.validate_identity( + 'ALTER TABLE "foo" ADD COLUMNS ("id" STRING)', + write_sql="ALTER TABLE `foo` ADD COLUMNS (`id` STRING)", + ) + self.validate_identity( + 'ALTER TABLE "foo" DROP COLUMN "id"', write_sql="ALTER TABLE `foo` DROP COLUMN `id`" + ) + def test_dml_quoting(self): self.validate_identity("SELECT a AS foo FROM tbl") self.validate_identity('SELECT "a" AS "foo" FROM "tbl"') @@ -69,3 +115,27 @@ def test_dml_quoting(self): write_sql='SELECT "a" AS "foo" FROM "tbl"', identify=True, ) + + self.validate_identity("INSERT INTO foo (id) VALUES (1)") + self.validate_identity('INSERT INTO "foo" ("id") VALUES (1)') + self.validate_identity( + 'INSERT INTO `foo` ("id") VALUES (1)', + write_sql='INSERT INTO "foo" ("id") VALUES (1)', + identify=True, + ) + + self.validate_identity("UPDATE foo SET id = 3 WHERE id = 7") + self.validate_identity('UPDATE "foo" SET "id" = 3 WHERE "id" = 7') + self.validate_identity( + 'UPDATE `foo` SET "id" = 3 WHERE `id` = 7', + write_sql='UPDATE "foo" SET "id" = 3 WHERE "id" = 7', + identify=True, + ) + + self.validate_identity("DELETE FROM foo WHERE id > 10") + self.validate_identity('DELETE FROM "foo" WHERE "id" > 10') + self.validate_identity( + "DELETE FROM `foo` WHERE `id` > 10", + write_sql='DELETE FROM "foo" WHERE "id" > 10', + identify=True, + ) From 1da1fd18258d24e67b45fcefcc19a4b3cd48bc13 Mon Sep 17 00:00:00 2001 From: Erin Drummond Date: Tue, 10 Sep 2024 03:41:21 +0000 Subject: [PATCH 2/3] Remove usage of walrus operator --- sqlglot/dialects/athena.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/sqlglot/dialects/athena.py b/sqlglot/dialects/athena.py index 483a9fbcc7..685ed4cf82 100644 --- a/sqlglot/dialects/athena.py +++ b/sqlglot/dialects/athena.py @@ -23,10 +23,9 @@ def _parse_as_hive(raw_tokens: t.List[Token]) -> bool: def _generate_as_hive(expression: exp.Expression) -> bool: if isinstance(expression, exp.Create): if expression.kind == "TABLE": - properties: t.Optional[exp.Properties] - if properties := expression.args.get("properties"): - if properties.find(exp.ExternalProperty): - return True # CREATE EXTERNAL TABLE is Hive + properties: t.Optional[exp.Properties] = expression.args.get("properties") + if properties and properties.find(exp.ExternalProperty): + return True # CREATE EXTERNAL TABLE is Hive if not isinstance(expression.expression, exp.Select): return True # any CREATE TABLE other than CREATE TABLE AS SELECT is Hive From 6038a604ad1e61e6944d883048a35a9e0e23958d Mon Sep 17 00:00:00 2001 From: Erin Drummond Date: Tue, 10 Sep 2024 21:21:46 +0000 Subject: [PATCH 3/3] PR feedback --- sqlglot/dialects/athena.py | 50 ++++------------------------------- tests/dialects/test_athena.py | 30 ++++++++++++++++++++- 2 files changed, 34 insertions(+), 46 deletions(-) diff --git a/sqlglot/dialects/athena.py b/sqlglot/dialects/athena.py index 685ed4cf82..40bbe7e6ea 100644 --- a/sqlglot/dialects/athena.py +++ b/sqlglot/dialects/athena.py @@ -5,19 +5,7 @@ from sqlglot import exp from sqlglot.dialects.trino import Trino from sqlglot.dialects.hive import Hive -from sqlglot.tokens import Token, TokenType - - -def _parse_as_hive(raw_tokens: t.List[Token]) -> bool: - if len(raw_tokens) > 0: - first_token = raw_tokens[0] - if first_token.token_type == TokenType.CREATE: - # CREATE is Hive (except for CREATE VIEW and CREATE TABLE... AS SELECT) - return not any(t.token_type in (TokenType.VIEW, TokenType.SELECT) for t in raw_tokens) - - # ALTER and DROP are Hive - return first_token.token_type in (TokenType.ALTER, TokenType.DROP) - return False +from sqlglot.tokens import TokenType def _generate_as_hive(expression: exp.Expression) -> bool: @@ -81,13 +69,6 @@ class Tokenizer(Trino.Tokenizer): "UNLOAD": TokenType.COMMAND, } - class HiveParser(Hive.Parser): - """ - Parse queries for the Athena Hive execution engine - """ - - pass - class Parser(Trino.Parser): """ Parse queries for the Athena Trino execution engine @@ -98,30 +79,6 @@ class Parser(Trino.Parser): TokenType.USING: lambda self: self._parse_as_command(self._prev), } - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - - self._hive_parser = Athena.HiveParser(*args, **kwargs) - - def parse( - self, raw_tokens: t.List[Token], sql: t.Optional[str] = None - ) -> t.List[t.Optional[exp.Expression]]: - if _parse_as_hive(raw_tokens): - return self._hive_parser.parse(raw_tokens, sql) - - return super().parse(raw_tokens, sql) - - class HiveGenerator(Hive.Generator): - """ - Generating queries for the Athena Hive execution engine - """ - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - - self._identifier_start = "`" - self._identifier_end = "`" - class Generator(Trino.Generator): """ Generate queries for the Athena Trino execution engine @@ -139,7 +96,10 @@ class Generator(Trino.Generator): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - self._hive_generator = Athena.HiveGenerator(*args, **kwargs) + + hive_kwargs = {**kwargs, "dialect": "hive"} + + self._hive_generator = Hive.Generator(*args, **hive_kwargs) def generate(self, expression: exp.Expression, copy: bool = True) -> str: if _generate_as_hive(expression): diff --git a/tests/dialects/test_athena.py b/tests/dialects/test_athena.py index 182f0bcea2..bf5491418d 100644 --- a/tests/dialects/test_athena.py +++ b/tests/dialects/test_athena.py @@ -24,6 +24,17 @@ def test_athena(self): check_command_warning=True, ) + self.validate_identity( + "/* leading comment */CREATE SCHEMA foo", + write_sql="/* leading comment */ CREATE SCHEMA `foo`", + identify=True, + ) + self.validate_identity( + "/* leading comment */SELECT * FROM foo", + write_sql='/* leading comment */ SELECT * FROM "foo"', + identify=True, + ) + def test_ddl(self): # Hive-like, https://docs.aws.amazon.com/athena/latest/ug/create-table.html self.validate_identity("CREATE EXTERNAL TABLE foo (id INT) COMMENT 'test comment'") @@ -53,11 +64,13 @@ def test_ddl(self): self.validate_identity( "CREATE TABLE foo WITH (table_type='ICEBERG', external_location='s3://foo/') AS SELECT * FROM a" ) + self.validate_identity( + "CREATE TABLE foo AS WITH foo AS (SELECT a, b FROM bar) SELECT * FROM foo" + ) def test_ddl_quoting(self): self.validate_identity("CREATE SCHEMA `foo`") self.validate_identity("CREATE SCHEMA foo") - self.validate_identity("CREATE SCHEMA foo", write_sql="CREATE SCHEMA `foo`", identify=True) self.validate_identity("CREATE EXTERNAL TABLE `foo` (`id` INT) LOCATION 's3://foo/'") self.validate_identity("CREATE EXTERNAL TABLE foo (id INT) LOCATION 's3://foo/'") @@ -107,6 +120,14 @@ def test_ddl_quoting(self): 'ALTER TABLE "foo" DROP COLUMN "id"', write_sql="ALTER TABLE `foo` DROP COLUMN `id`" ) + self.validate_identity( + 'CREATE TABLE "foo" AS WITH "foo" AS (SELECT "a", "b" FROM "bar") SELECT * FROM "foo"' + ) + self.validate_identity( + 'CREATE TABLE `foo` AS WITH `foo` AS (SELECT "a", `b` FROM "bar") SELECT * FROM "foo"', + write_sql='CREATE TABLE "foo" AS WITH "foo" AS (SELECT "a", "b" FROM "bar") SELECT * FROM "foo"', + ) + def test_dml_quoting(self): self.validate_identity("SELECT a AS foo FROM tbl") self.validate_identity('SELECT "a" AS "foo" FROM "tbl"') @@ -139,3 +160,10 @@ def test_dml_quoting(self): write_sql='DELETE FROM "foo" WHERE "id" > 10', identify=True, ) + + self.validate_identity("WITH foo AS (SELECT a, b FROM bar) SELECT * FROM foo") + self.validate_identity( + "WITH foo AS (SELECT a, b FROM bar) SELECT * FROM foo", + write_sql='WITH "foo" AS (SELECT "a", "b" FROM "bar") SELECT * FROM "foo"', + identify=True, + )