Feat(athena): Improve DDL query support #4099

Merged · 3 commits · Sep 12, 2024

Changes from 2 commits
141 changes: 115 additions & 26 deletions sqlglot/dialects/athena.py
@@ -1,31 +1,131 @@
from __future__ import annotations

import typing as t

from sqlglot import exp
from sqlglot.dialects.hive import Hive
from sqlglot.dialects.trino import Trino
from sqlglot.tokens import Token, TokenType


def _parse_as_hive(raw_tokens: t.List[Token]) -> bool:
    if len(raw_tokens) > 0:
        first_token = raw_tokens[0]
        if first_token.token_type == TokenType.CREATE:
@georgesittas (Collaborator) commented on Sep 10, 2024:

I'm not sure this approach is robust. For example, what if you have leading CTEs? The first token would then be WITH. You could also have (redundant) leading semicolons, etc.

@erindru (Collaborator, Author) replied:

Yeah, I was anticipating needing to tweak this. I forgot about CTEs; I'll add support for that.

How would you deal with redundant semicolons? Would you first filter the token stream down to just the tokens we care about and then do the checks?

@erindru (Collaborator, Author) commented on Sep 10, 2024:

So if you have leading CTEs, the existing logic works and returns False. It doesn't need to check for them, because only a SELECT query would have leading CTEs. All SELECT queries should use the Trino tokenizer, so returning False triggers the use of the Trino tokenizer.

If you have a CREATE TABLE ... AS WITH (...) SELECT query, a SELECT still appears in the tokens, so this is still correctly detected as a CTAS and returns False, which triggers the use of the Trino tokenizer.

@erindru (Collaborator, Author) replied:

How do you test redundant semicolons? Trying to parse something like `; CREATE SCHEMA FOO;` just throws a parse error.

@georgesittas (Collaborator) replied:

> So if you have leading CTEs, the existing logic works and returns False. It doesn't need to check for them, because only a SELECT query would have leading CTEs.

Ah, interesting, I didn't realize that. That simplifies the problem, then.

> How do you test redundant semicolons?

Using parse:

>>> import sqlglot
>>> sqlglot.parse("; create schema foo ;;")
[None, Create(
  this=Table(
    db=Identifier(this=foo, quoted=False)),
  kind=SCHEMA), None]

To be honest, this is an edge case, but it's good to handle anyway; you can skip the leading tokens until you find the first non-semicolon or something.
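For instance, a minimal sketch of that skip (a hypothetical helper, not code from this PR):

import typing as t

from sqlglot.tokens import Token, TokenType

def _first_significant_token(raw_tokens: t.List[Token]) -> t.Optional[Token]:
    # Skip redundant leading semicolons so "; CREATE SCHEMA foo" is still
    # recognized by its CREATE token rather than the stray SEMICOLON.
    for token in raw_tokens:
        if token.token_type != TokenType.SEMICOLON:
            return token
    return None

The engine checks in _parse_as_hive could then look at _first_significant_token(raw_tokens) instead of raw_tokens[0].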

@erindru (Collaborator, Author) replied:

I struggled to write a test for this and eventually realized that the issues I was hitting, which I thought were parsing issues, were really tokenization issues. It appears the Trino parser is capable of handling all the Hive DDL as long as it is tokenized correctly.

So I was able to remove the delegation to the Hive parser entirely and just worry about the generation side. I may need to revisit this in the future, but I wasn't able to find a query to add to the tests that uses Hive syntax and also causes the Trino parser to fail.
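To illustrate the tokenization point, a quick sketch (assuming this branch is installed; sqlglot.tokenize is sqlglot's public tokenizer entry point):

import sqlglot

# Backtick identifiers tokenize cleanly under the Athena dialect because its
# tokenizer accepts both Hive backticks and Trino double quotes.
tokens = sqlglot.tokenize("ALTER TABLE `foo` DROP COLUMN `id`", read="athena")
print([token.token_type for token in tokens])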

            # CREATE is Hive (except for CREATE VIEW and CREATE TABLE ... AS SELECT)
            return not any(t.token_type in (TokenType.VIEW, TokenType.SELECT) for t in raw_tokens)

        # ALTER and DROP are Hive
        return first_token.token_type in (TokenType.ALTER, TokenType.DROP)
    return False


def _generate_as_hive(expression: exp.Expression) -> bool:
    if isinstance(expression, exp.Create):
        if expression.kind == "TABLE":
            properties: t.Optional[exp.Properties] = expression.args.get("properties")
            if properties and properties.find(exp.ExternalProperty):
                return True  # CREATE EXTERNAL TABLE is Hive

            if not isinstance(expression.expression, exp.Select):
                return True  # any CREATE TABLE other than CREATE TABLE ... AS SELECT is Hive
        else:
            return expression.kind != "VIEW"  # CREATE VIEW is never Hive, but CREATE SCHEMA etc. is

    elif isinstance(expression, (exp.Alter, exp.Drop)):
        return True  # all ALTER and DROP statements are Hive

    return False
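For example, the predicate routes statements like this (a sketch assuming this branch is installed; _generate_as_hive is module-private and imported here purely for illustration):

from sqlglot import parse_one
from sqlglot.dialects.athena import _generate_as_hive

# DROP statements are routed to the Hive generator; CTAS stays with Trino.
print(_generate_as_hive(parse_one("DROP TABLE foo", read="athena")))  # True
print(_generate_as_hive(parse_one("CREATE TABLE foo AS SELECT 1", read="athena")))  # False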


class Athena(Trino):
    """
    Over the years, it looks like AWS has taken various execution engines, bolted on AWS-specific
    modifications and then built the Athena service around them.

    Thus, Athena is not simply hosted Trino; it's more like a router that routes SQL queries to an
    execution engine depending on the query type.

    As of 2024-09-10, assuming your Athena workgroup is configured to use "Athena engine version 3",
    the following engines exist:

    Hive:
     - Accepts mostly the same syntax as Hadoop / Hive
     - Uses backticks to quote identifiers
     - Has a distinctive DDL syntax (around things like setting table properties, storage locations, etc.)
       that is different from Trino
     - Used for *most* DDL:
        - CREATE [EXTERNAL] TABLE (without AS SELECT)
        - ALTER
        - DROP
       (the exceptions, CREATE VIEW and CREATE TABLE ... AS SELECT, get routed to the Trino engine instead)

    Trino:
     - Uses double quotes to quote identifiers
     - Used for DDL operations that involve SELECT queries, e.g.:
        - CREATE VIEW
        - CREATE TABLE ... AS SELECT
     - Used for DML operations:
        - SELECT, INSERT, UPDATE, DELETE, MERGE

    The SQLGlot Athena dialect tries to identify which engine a query would be routed to, and then uses
    the parser / generator for that engine rather than trying to create a universal syntax that can
    handle both.
    """

    class Tokenizer(Trino.Tokenizer):
        """
        The Tokenizer is flexible enough to tokenize queries across both the Hive and Trino engines.
        """

        IDENTIFIERS = ['"', "`"]
        KEYWORDS = {
            **Hive.Tokenizer.KEYWORDS,
            **Trino.Tokenizer.KEYWORDS,
            "UNLOAD": TokenType.COMMAND,
        }

    class HiveParser(Hive.Parser):
        """
        Parse queries for the Athena Hive execution engine.
        """

    class Parser(Trino.Parser):
        """
        Parse queries for the Athena Trino execution engine.
        """

        STATEMENT_PARSERS = {
            **Trino.Parser.STATEMENT_PARSERS,
            TokenType.USING: lambda self: self._parse_as_command(self._prev),
        }

        def __init__(self, *args, **kwargs):
            super().__init__(*args, **kwargs)
            self._hive_parser = Athena.HiveParser(*args, **kwargs)

        def parse(
            self, raw_tokens: t.List[Token], sql: t.Optional[str] = None
        ) -> t.List[t.Optional[exp.Expression]]:
            if _parse_as_hive(raw_tokens):
                return self._hive_parser.parse(raw_tokens, sql)

            return super().parse(raw_tokens, sql)

    class HiveGenerator(Hive.Generator):
        """
        Generate queries for the Athena Hive execution engine.
        """

        def __init__(self, *args, **kwargs):
            super().__init__(*args, **kwargs)

            self._identifier_start = "`"
            self._identifier_end = "`"

    class Generator(Trino.Generator):
        """
        Generate queries for the Athena Trino execution engine.
        """

        TYPE_MAPPING = {
            **Trino.Generator.TYPE_MAPPING,

@@ -37,23 +137,12 @@ class Generator(Trino.Generator):
            exp.FileFormatProperty: lambda self, e: f"'FORMAT'={self.sql(e, 'this')}",
        }

        def __init__(self, *args, **kwargs):
            super().__init__(*args, **kwargs)
            self._hive_generator = Athena.HiveGenerator(*args, **kwargs)

        def generate(self, expression: exp.Expression, copy: bool = True) -> str:
            # Athena DDL uses backticks for quoting, unlike Athena DML which uses double quotes
            # ...unless the DDL is CREATE VIEW, which uses DML quoting, presumably because the view is based on a SELECT query
            # ref: https://docs.aws.amazon.com/athena/latest/ug/reserved-words.html
            # ref: https://docs.aws.amazon.com/athena/latest/ug/tables-databases-columns-names.html#table-names-that-include-numbers
            if _generate_as_hive(expression):
                return self._hive_generator.generate(expression, copy)

            return super().generate(expression, copy)
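As an end-to-end sketch of the quoting behaviour this routing produces (assuming this branch is installed; the expected outputs mirror the test expectations below):

import sqlglot

# DDL is generated by the Hive generator, so identifiers get backticks...
print(sqlglot.transpile("CREATE SCHEMA foo", read="athena", write="athena", identify=True)[0])
# CREATE SCHEMA `foo`

# ...while DML stays with the Trino generator and gets double quotes.
print(sqlglot.transpile("SELECT a AS foo FROM tbl", read="athena", write="athena", identify=True)[0])
# SELECT "a" AS "foo" FROM "tbl"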
82 changes: 76 additions & 6 deletions tests/dialects/test_athena.py
@@ -24,16 +24,54 @@ def test_athena(self):
            check_command_warning=True,
        )

    def test_ddl(self):
        # Hive-like, https://docs.aws.amazon.com/athena/latest/ug/create-table.html
        self.validate_identity("CREATE EXTERNAL TABLE foo (id INT) COMMENT 'test comment'")
        self.validate_identity(
            "CREATE EXTERNAL TABLE foo (id INT, val STRING) CLUSTERED BY (id, val) INTO 10 BUCKETS"
        )
        self.validate_identity(
            "CREATE EXTERNAL TABLE foo (id INT, val STRING) STORED AS PARQUET LOCATION 's3://foo' TBLPROPERTIES ('has_encryped_data'='true', 'classification'='test')"
        )
        self.validate_identity(
            "CREATE EXTERNAL TABLE IF NOT EXISTS foo (a INT, b STRING) ROW FORMAT SERDE 'org.openx.data.jsonserde.JsonSerDe' WITH SERDEPROPERTIES ('case.insensitive'='FALSE') LOCATION 's3://table/path'"
        )
        self.validate_identity(
            """CREATE EXTERNAL TABLE x (y INT) ROW FORMAT SERDE 'serde' ROW FORMAT DELIMITED FIELDS TERMINATED BY '1' WITH SERDEPROPERTIES ('input.regex'='')""",
        )
        self.validate_identity(
            """CREATE EXTERNAL TABLE `my_table` (`a7` ARRAY<DATE>) ROW FORMAT SERDE 'a' STORED AS INPUTFORMAT 'b' OUTPUTFORMAT 'c' LOCATION 'd' TBLPROPERTIES ('e'='f')"""
        )

        # Iceberg, https://docs.aws.amazon.com/athena/latest/ug/querying-iceberg-creating-tables.html
        self.validate_identity(
            "CREATE TABLE iceberg_table (`id` BIGINT, `data` STRING, category STRING) PARTITIONED BY (category, BUCKET(16, id)) LOCATION 's3://amzn-s3-demo-bucket/your-folder/' TBLPROPERTIES ('table_type'='ICEBERG', 'write_compression'='snappy')"
        )

        # CTAS goes to the Trino engine, where the table properties can't be wrapped in single quotes like they can for Hive
        # ref: https://docs.aws.amazon.com/athena/latest/ug/create-table-as.html#ctas-table-properties
        self.validate_identity(
            "CREATE TABLE foo WITH (table_type='ICEBERG', external_location='s3://foo/') AS SELECT * FROM a"
        )

    def test_ddl_quoting(self):
        self.validate_identity("CREATE SCHEMA `foo`")
        self.validate_identity("CREATE SCHEMA foo")
        self.validate_identity("CREATE SCHEMA foo", write_sql="CREATE SCHEMA `foo`", identify=True)

        self.validate_identity("CREATE EXTERNAL TABLE `foo` (`id` INT) LOCATION 's3://foo/'")
        self.validate_identity("CREATE EXTERNAL TABLE foo (id INT) LOCATION 's3://foo/'")
        self.validate_identity(
            "CREATE EXTERNAL TABLE foo (id INT) LOCATION 's3://foo/'",
            write_sql="CREATE EXTERNAL TABLE `foo` (`id` INT) LOCATION 's3://foo/'",
            identify=True,
        )

        self.validate_identity("CREATE TABLE foo AS SELECT * FROM a")
        self.validate_identity('CREATE TABLE "foo" AS SELECT * FROM "a"')
        self.validate_identity(
            "CREATE TABLE `foo` AS SELECT * FROM `a`",
            write_sql='CREATE TABLE "foo" AS SELECT * FROM "a"',
            identify=True,
        )

@@ -52,15 +90,23 @@ def test_ddl_quoting(self):
        # As a side effect of being able to parse both quote types, we can also fix the quoting on incorrectly quoted source queries
        self.validate_identity('CREATE SCHEMA "foo"', write_sql="CREATE SCHEMA `foo`")
        self.validate_identity(
            'CREATE EXTERNAL TABLE "foo" ("id" INT) LOCATION \'s3://foo/\'',
            write_sql="CREATE EXTERNAL TABLE `foo` (`id` INT) LOCATION 's3://foo/'",
        )
        self.validate_identity('DROP TABLE "foo"', write_sql="DROP TABLE `foo`")
        self.validate_identity(
            'CREATE VIEW `foo` AS SELECT "id" FROM `tbl`',
            write_sql='CREATE VIEW "foo" AS SELECT "id" FROM "tbl"',
        )

        self.validate_identity(
            'ALTER TABLE "foo" ADD COLUMNS ("id" STRING)',
            write_sql="ALTER TABLE `foo` ADD COLUMNS (`id` STRING)",
        )
        self.validate_identity(
            'ALTER TABLE "foo" DROP COLUMN "id"', write_sql="ALTER TABLE `foo` DROP COLUMN `id`"
        )

    def test_dml_quoting(self):
        self.validate_identity("SELECT a AS foo FROM tbl")
        self.validate_identity('SELECT "a" AS "foo" FROM "tbl"')

@@ -69,3 +115,27 @@ def test_dml_quoting(self):
            write_sql='SELECT "a" AS "foo" FROM "tbl"',
            identify=True,
        )

        self.validate_identity("INSERT INTO foo (id) VALUES (1)")
        self.validate_identity('INSERT INTO "foo" ("id") VALUES (1)')
        self.validate_identity(
            'INSERT INTO `foo` ("id") VALUES (1)',
            write_sql='INSERT INTO "foo" ("id") VALUES (1)',
            identify=True,
        )

        self.validate_identity("UPDATE foo SET id = 3 WHERE id = 7")
        self.validate_identity('UPDATE "foo" SET "id" = 3 WHERE "id" = 7')
        self.validate_identity(
            'UPDATE `foo` SET "id" = 3 WHERE `id` = 7',
            write_sql='UPDATE "foo" SET "id" = 3 WHERE "id" = 7',
            identify=True,
        )

        self.validate_identity("DELETE FROM foo WHERE id > 10")
        self.validate_identity('DELETE FROM "foo" WHERE "id" > 10')
        self.validate_identity(
            "DELETE FROM `foo` WHERE `id` > 10",
            write_sql='DELETE FROM "foo" WHERE "id" > 10',
            identify=True,
        )
)
Loading