Feat(athena): Improve DDL query support #4099

Merged Sep 12, 2024 (3 commits)

Changes from all commits
99 changes: 74 additions & 25 deletions sqlglot/dialects/athena.py
@@ -1,31 +1,88 @@
from __future__ import annotations

import typing as t

from sqlglot import exp
from sqlglot.dialects.trino import Trino
from sqlglot.dialects.hive import Hive
from sqlglot.tokens import TokenType


def _generate_as_hive(expression: exp.Expression) -> bool:
if isinstance(expression, exp.Create):
if expression.kind == "TABLE":
properties: t.Optional[exp.Properties] = expression.args.get("properties")
if properties and properties.find(exp.ExternalProperty):
return True # CREATE EXTERNAL TABLE is Hive

if not isinstance(expression.expression, exp.Select):
return True # any CREATE TABLE other than CREATE TABLE AS SELECT is Hive
else:
return expression.kind != "VIEW" # CREATE VIEW is never Hive but CREATE SCHEMA etc is

elif isinstance(expression, exp.Alter) or isinstance(expression, exp.Drop):
return True # all ALTER and DROP statements are Hive

return False
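
# Illustrative sketch (not part of this diff): how the helper classifies parsed statements,
# assuming sqlglot.parse_one with read="athena":
#
#   from sqlglot import parse_one
#   _generate_as_hive(parse_one("DROP TABLE foo", read="athena"))                # True  -> Hive DDL
#   _generate_as_hive(parse_one("CREATE TABLE foo (id INT)", read="athena"))     # True  -> Hive DDL
#   _generate_as_hive(parse_one("CREATE TABLE foo AS SELECT 1", read="athena"))  # False -> Trino (CTAS)
#   _generate_as_hive(parse_one("CREATE VIEW v AS SELECT 1", read="athena"))     # False -> Trino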


class Athena(Trino):
"""
Over the years, it looks like AWS has taken various execution engines, bolted on AWS-specific modifications and then
built the Athena service around them.

Thus, Athena is not simply hosted Trino; it's more like a router that routes SQL queries to an execution engine depending
on the query type.

As of 2024-09-10, assuming your Athena workgroup is configured to use "Athena engine version 3", the following engines exist:

Hive:
- Accepts mostly the same syntax as Hadoop / Hive
- Uses backticks to quote identifiers
- Has a distinctive DDL syntax (around things like setting table properties, storage locations etc) that is different from Trino
- Used for *most* DDL, i.e.:
- CREATE [EXTERNAL] TABLE (without AS SELECT)
- ALTER
- DROP
(the exceptions, CREATE VIEW and CREATE TABLE ... AS SELECT, get routed to the Trino engine instead)

Trino:
- Uses double quotes to quote identifiers
- Used for DDL operations that involve SELECT queries, e.g.:
- CREATE VIEW
- CREATE TABLE... AS SELECT
- Used for DML operations
- SELECT, INSERT, UPDATE, DELETE, MERGE

The SQLGlot Athena dialect tries to identify which engine a query would be routed to and then uses the parser / generator for that engine
rather than trying to create a universal syntax that can handle both types.
"""

class Tokenizer(Trino.Tokenizer):
"""
The Tokenizer is flexible enough to tokenize queries across both the Hive and Trino engines
"""

IDENTIFIERS = ['"', "`"]
KEYWORDS = {
**Hive.Tokenizer.KEYWORDS,
**Trino.Tokenizer.KEYWORDS,
"UNLOAD": TokenType.COMMAND,
}

class Parser(Trino.Parser):
"""
Parse queries for the Athena Trino execution engine
"""

STATEMENT_PARSERS = {
**Trino.Parser.STATEMENT_PARSERS,
TokenType.USING: lambda self: self._parse_as_command(self._prev),
}

class Generator(Trino.Generator):
"""
Generate queries for the Athena Trino execution engine
"""

WITH_PROPERTIES_PREFIX = "TBLPROPERTIES"

PROPERTIES_LOCATION = {
**Trino.Generator.PROPERTIES_LOCATION,
exp.LocationProperty: exp.Properties.Location.POST_SCHEMA,
}

TYPE_MAPPING = {
**Trino.Generator.TYPE_MAPPING,
@@ -37,23 +94,15 @@ class Generator(Trino.Generator):
exp.FileFormatProperty: lambda self, e: f"'FORMAT'={self.sql(e, 'this')}",
}

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)

hive_kwargs = {**kwargs, "dialect": "hive"}

self._hive_generator = Hive.Generator(*args, **hive_kwargs)

def generate(self, expression: exp.Expression, copy: bool = True) -> str:
if isinstance(expression, exp.DDL) or isinstance(expression, exp.Drop):
# Athena DDL uses backticks for quoting, unlike Athena DML which uses double quotes
# ...unless the DDL is CREATE VIEW, then it uses DML quoting, I guess because the view is based on a SELECT query
# ref: https://docs.aws.amazon.com/athena/latest/ug/reserved-words.html
# ref: https://docs.aws.amazon.com/athena/latest/ug/tables-databases-columns-names.html#table-names-that-include-numbers
if not (isinstance(expression, exp.Create) and expression.kind == "VIEW"):
self._identifier_start = "`"
self._identifier_end = "`"

try:
return super().generate(expression, copy)
finally:
self._identifier_start = self.dialect.IDENTIFIER_START
self._identifier_end = self.dialect.IDENTIFIER_END

def property_sql(self, expression: exp.Property) -> str:
return (
f"{self.property_name(expression, string_key=True)}={self.sql(expression, 'value')}"
)
if _generate_as_hive(expression):
return self._hive_generator.generate(expression, copy)

return super().generate(expression, copy)
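
The routing and the per-engine quoting are easiest to see from the dialect's output. The snippet below is an illustrative sketch using sqlglot.transpile (it is not part of the diff); the expected outputs mirror the test cases that follow:

    import sqlglot

    # DDL other than views / CTAS is generated by the Hive engine, so identifiers
    # come out backtick-quoted even if the input used Trino-style double quotes.
    print(sqlglot.transpile('DROP TABLE "foo"', read="athena", write="athena")[0])
    # DROP TABLE `foo`

    # DML stays on the Trino engine and uses double quotes.
    print(sqlglot.transpile("SELECT a AS foo FROM tbl", read="athena", write="athena", identify=True)[0])
    # SELECT "a" AS "foo" FROM "tbl"

    # CTAS involves a SELECT, so it is generated as Trino as well.
    print(sqlglot.transpile("CREATE TABLE `foo` AS SELECT * FROM `a`", read="athena", write="athena", identify=True)[0])
    # CREATE TABLE "foo" AS SELECT * FROM "a"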
112 changes: 105 additions & 7 deletions tests/dialects/test_athena.py
@@ -24,16 +24,67 @@ def test_athena(self):
check_command_warning=True,
)

self.validate_identity(
"/* leading comment */CREATE SCHEMA foo",
write_sql="/* leading comment */ CREATE SCHEMA `foo`",
identify=True,
)
self.validate_identity(
"/* leading comment */SELECT * FROM foo",
write_sql='/* leading comment */ SELECT * FROM "foo"',
identify=True,
)

def test_ddl(self):
# Hive-like, https://docs.aws.amazon.com/athena/latest/ug/create-table.html
self.validate_identity("CREATE EXTERNAL TABLE foo (id INT) COMMENT 'test comment'")
self.validate_identity(
"CREATE EXTERNAL TABLE foo (id INT, val STRING) CLUSTERED BY (id, val) INTO 10 BUCKETS"
)
self.validate_identity(
"CREATE EXTERNAL TABLE foo (id INT, val STRING) STORED AS PARQUET LOCATION 's3://foo' TBLPROPERTIES ('has_encryped_data'='true', 'classification'='test')"
)
self.validate_identity(
"CREATE EXTERNAL TABLE IF NOT EXISTS foo (a INT, b STRING) ROW FORMAT SERDE 'org.openx.data.jsonserde.JsonSerDe' WITH SERDEPROPERTIES ('case.insensitive'='FALSE') LOCATION 's3://table/path'"
)
self.validate_identity(
"""CREATE EXTERNAL TABLE x (y INT) ROW FORMAT SERDE 'serde' ROW FORMAT DELIMITED FIELDS TERMINATED BY '1' WITH SERDEPROPERTIES ('input.regex'='')""",
)
self.validate_identity(
"""CREATE EXTERNAL TABLE `my_table` (`a7` ARRAY<DATE>) ROW FORMAT SERDE 'a' STORED AS INPUTFORMAT 'b' OUTPUTFORMAT 'c' LOCATION 'd' TBLPROPERTIES ('e'='f')"""
)

# Iceberg, https://docs.aws.amazon.com/athena/latest/ug/querying-iceberg-creating-tables.html
self.validate_identity(
"CREATE TABLE iceberg_table (`id` BIGINT, `data` STRING, category STRING) PARTITIONED BY (category, BUCKET(16, id)) LOCATION 's3://amzn-s3-demo-bucket/your-folder/' TBLPROPERTIES ('table_type'='ICEBERG', 'write_compression'='snappy')"
)

# CTAS goes to the Trino engine, where the table properties can't be enclosed in single quotes like they can for Hive
# ref: https://docs.aws.amazon.com/athena/latest/ug/create-table-as.html#ctas-table-properties
self.validate_identity(
"CREATE TABLE foo WITH (table_type='ICEBERG', external_location='s3://foo/') AS SELECT * FROM a"
)
self.validate_identity(
"CREATE TABLE foo AS WITH foo AS (SELECT a, b FROM bar) SELECT * FROM foo"
)

def test_ddl_quoting(self):
self.validate_identity("CREATE SCHEMA `foo`")
self.validate_identity("CREATE SCHEMA foo")
self.validate_identity("CREATE SCHEMA foo", write_sql="CREATE SCHEMA `foo`", identify=True)

self.validate_identity("CREATE EXTERNAL TABLE `foo` (`id` INTEGER) LOCATION 's3://foo/'")
self.validate_identity("CREATE EXTERNAL TABLE foo (id INTEGER) LOCATION 's3://foo/'")
self.validate_identity("CREATE EXTERNAL TABLE `foo` (`id` INT) LOCATION 's3://foo/'")
self.validate_identity("CREATE EXTERNAL TABLE foo (id INT) LOCATION 's3://foo/'")
self.validate_identity(
"CREATE EXTERNAL TABLE foo (id INT) LOCATION 's3://foo/'",
write_sql="CREATE EXTERNAL TABLE `foo` (`id` INT) LOCATION 's3://foo/'",
identify=True,
)

self.validate_identity("CREATE TABLE foo AS SELECT * FROM a")
self.validate_identity('CREATE TABLE "foo" AS SELECT * FROM "a"')
self.validate_identity(
"CREATE EXTERNAL TABLE foo (id INTEGER) LOCATION 's3://foo/'",
write_sql="CREATE EXTERNAL TABLE `foo` (`id` INTEGER) LOCATION 's3://foo/'",
"CREATE TABLE `foo` AS SELECT * FROM `a`",
write_sql='CREATE TABLE "foo" AS SELECT * FROM "a"',
identify=True,
)

@@ -52,15 +103,31 @@ def test_ddl_quoting(self):
# As a side effect of being able to parse both quote types, we can also fix the quoting on incorrectly quoted source queries
self.validate_identity('CREATE SCHEMA "foo"', write_sql="CREATE SCHEMA `foo`")
self.validate_identity(
'CREATE EXTERNAL TABLE "foo" ("id" INTEGER) LOCATION \'s3://foo/\'',
write_sql="CREATE EXTERNAL TABLE `foo` (`id` INTEGER) LOCATION 's3://foo/'",
'CREATE EXTERNAL TABLE "foo" ("id" INT) LOCATION \'s3://foo/\'',
write_sql="CREATE EXTERNAL TABLE `foo` (`id` INT) LOCATION 's3://foo/'",
)
self.validate_identity('DROP TABLE "foo"', write_sql="DROP TABLE `foo`")
self.validate_identity(
'CREATE VIEW `foo` AS SELECT "id" FROM `tbl`',
write_sql='CREATE VIEW "foo" AS SELECT "id" FROM "tbl"',
)

self.validate_identity(
'ALTER TABLE "foo" ADD COLUMNS ("id" STRING)',
write_sql="ALTER TABLE `foo` ADD COLUMNS (`id` STRING)",
)
self.validate_identity(
'ALTER TABLE "foo" DROP COLUMN "id"', write_sql="ALTER TABLE `foo` DROP COLUMN `id`"
)

self.validate_identity(
'CREATE TABLE "foo" AS WITH "foo" AS (SELECT "a", "b" FROM "bar") SELECT * FROM "foo"'
)
self.validate_identity(
'CREATE TABLE `foo` AS WITH `foo` AS (SELECT "a", `b` FROM "bar") SELECT * FROM "foo"',
write_sql='CREATE TABLE "foo" AS WITH "foo" AS (SELECT "a", "b" FROM "bar") SELECT * FROM "foo"',
)

def test_dml_quoting(self):
self.validate_identity("SELECT a AS foo FROM tbl")
self.validate_identity('SELECT "a" AS "foo" FROM "tbl"')
@@ -69,3 +136,34 @@
write_sql='SELECT "a" AS "foo" FROM "tbl"',
identify=True,
)

self.validate_identity("INSERT INTO foo (id) VALUES (1)")
self.validate_identity('INSERT INTO "foo" ("id") VALUES (1)')
self.validate_identity(
'INSERT INTO `foo` ("id") VALUES (1)',
write_sql='INSERT INTO "foo" ("id") VALUES (1)',
identify=True,
)

self.validate_identity("UPDATE foo SET id = 3 WHERE id = 7")
self.validate_identity('UPDATE "foo" SET "id" = 3 WHERE "id" = 7')
self.validate_identity(
'UPDATE `foo` SET "id" = 3 WHERE `id` = 7',
write_sql='UPDATE "foo" SET "id" = 3 WHERE "id" = 7',
identify=True,
)

self.validate_identity("DELETE FROM foo WHERE id > 10")
self.validate_identity('DELETE FROM "foo" WHERE "id" > 10')
self.validate_identity(
"DELETE FROM `foo` WHERE `id` > 10",
write_sql='DELETE FROM "foo" WHERE "id" > 10',
identify=True,
)

self.validate_identity("WITH foo AS (SELECT a, b FROM bar) SELECT * FROM foo")
self.validate_identity(
"WITH foo AS (SELECT a, b FROM bar) SELECT * FROM foo",
write_sql='WITH "foo" AS (SELECT "a", "b" FROM "bar") SELECT * FROM "foo"',
identify=True,
)
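
A note on the helper used throughout these tests: validate_identity parses the input SQL with the Athena dialect and asserts that regenerating it produces write_sql (or the original string when write_sql is omitted), with identify=True forcing all identifiers to be quoted. Roughly the same check, expressed against the public API (an illustrative sketch, not the helper's implementation):

    import sqlglot

    # Approximates validate_identity("DELETE FROM `foo` WHERE `id` > 10", write_sql=..., identify=True)
    out = sqlglot.parse_one("DELETE FROM `foo` WHERE `id` > 10", read="athena").sql(
        dialect="athena", identify=True
    )
    assert out == 'DELETE FROM "foo" WHERE "id" > 10'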