From e0372edcfafab4b17a7019c92769701f1b675647 Mon Sep 17 00:00:00 2001 From: Erin Drummond Date: Thu, 19 Sep 2024 05:08:28 +0000 Subject: [PATCH 1/2] Feat: Athena adapter --- .circleci/continue_config.yml | 19 +- Makefile | 5 +- docs/guides/configuration.md | 1 + docs/integrations/engines/athena.md | 70 +++ docs/integrations/overview.md | 1 + mkdocs.yml | 1 + pytest.ini | 1 + setup.py | 1 + sqlmesh/core/config/connection.py | 60 +++ sqlmesh/core/engine_adapter/__init__.py | 2 + sqlmesh/core/engine_adapter/athena.py | 441 ++++++++++++++++++ sqlmesh/core/engine_adapter/base.py | 7 +- sqlmesh/core/engine_adapter/trino.py | 2 + tests/core/engine_adapter/config.yaml | 12 + tests/core/engine_adapter/test_athena.py | 295 ++++++++++++ tests/core/engine_adapter/test_integration.py | 105 ++++- 16 files changed, 1005 insertions(+), 18 deletions(-) create mode 100644 docs/integrations/engines/athena.md create mode 100644 sqlmesh/core/engine_adapter/athena.py create mode 100644 tests/core/engine_adapter/test_athena.py diff --git a/.circleci/continue_config.yml b/.circleci/continue_config.yml index 55543e728..7b282601c 100644 --- a/.circleci/continue_config.yml +++ b/.circleci/continue_config.yml @@ -326,15 +326,16 @@ workflows: matrix: parameters: engine: - - snowflake - - databricks - - redshift - - bigquery - - clickhouse-cloud - filters: - branches: - only: - - main + #- snowflake + #- databricks + #- redshift + #- bigquery + #- clickhouse-cloud + - athena + #filters: + # branches: + # only: + # - main - trigger_private_tests: requires: - style_and_slow_tests diff --git a/Makefile b/Makefile index e0e7390ba..21c921397 100644 --- a/Makefile +++ b/Makefile @@ -10,7 +10,7 @@ install-doc: pip3 install -r ./docs/requirements.txt install-engine-test: - pip3 install -e ".[dev,web,slack,mysql,postgres,databricks,redshift,bigquery,snowflake,trino,mssql,clickhouse]" + pip3 install -e ".[dev,web,slack,mysql,postgres,databricks,redshift,bigquery,snowflake,trino,mssql,clickhouse,athena]" install-pre-commit: pre-commit install @@ -209,3 +209,6 @@ redshift-test: guard-REDSHIFT_HOST guard-REDSHIFT_USER guard-REDSHIFT_PASSWORD g clickhouse-cloud-test: guard-CLICKHOUSE_CLOUD_HOST guard-CLICKHOUSE_CLOUD_USERNAME guard-CLICKHOUSE_CLOUD_PASSWORD engine-clickhouse-install pytest -n auto -x -m "clickhouse_cloud" --retries 3 --junitxml=test-results/junit-clickhouse-cloud.xml + +athena-test: guard-AWS_ACCESS_KEY_ID guard-AWS_SECRET_ACCESS_KEY guard-ATHENA_S3_WAREHOUSE_LOCATION engine-athena-install + pytest -n auto -x -m "athena" --retries 3 --junitxml=test-results/junit-athena.xml \ No newline at end of file diff --git a/docs/guides/configuration.md b/docs/guides/configuration.md index 9848b9aae..83c697f0b 100644 --- a/docs/guides/configuration.md +++ b/docs/guides/configuration.md @@ -483,6 +483,7 @@ Example snowflake connection configuration: These pages describe the connection configuration options for each execution engine. 
+* [Athena](../integrations/engines/athena.md) * [BigQuery](../integrations/engines/bigquery.md) * [Databricks](../integrations/engines/databricks.md) * [DuckDB](../integrations/engines/duckdb.md) diff --git a/docs/integrations/engines/athena.md b/docs/integrations/engines/athena.md new file mode 100644 index 000000000..f3efe6137 --- /dev/null +++ b/docs/integrations/engines/athena.md @@ -0,0 +1,70 @@ +# Athena + +## Installation + +``` +pip install "sqlmesh[athena]" +``` + +## Connection options + +### PyAthena connection options + +SQLMesh leverages the [PyAthena](https://github.com/laughingman7743/PyAthena) DBAPI driver to connect to Athena. Therefore, the connection options relate to the PyAthena connection options. +Note that PyAthena uses [boto3](https://boto3.amazonaws.com/v1/documentation/api/latest/index.html) under the hood so you can also use [boto3 environment variables](https://boto3.amazonaws.com/v1/documentation/api/latest/guide/configuration.html#using-environment-variables) for configuration. + +| Option | Description | Type | Required | +|-------------------------|--------------------------------------------------------------------------------------------------------------------------------------------------------------|:------:|:--------:| +| `type` | Engine type name - must be `athena` | string | Y | +| `aws_access_key_id` | The access key for your AWS user | string | N | +| `aws_secret_access_key` | The secret key for your AWS user | string | N | +| `role_arn` | The ARN of a role to assume once authenticated | string | N | +| `role_session_name` | The session name to use when assuming `role_arn` | string | N | +| `region_name` | The AWS region to use | string | N | +| `work_group` | The Athena [workgroup](https://docs.aws.amazon.com/athena/latest/ug/workgroups-manage-queries-control-costs.html) to send queries to | string | N | +| `s3_staging_dir` | The S3 location for Athena to write query results. Only required if not using `work_group` OR the configured `work_group` doesnt have a results location set | string | N | +| `schema_name` | The default schema to place objects in if a schema isnt specified. Defaults to `default` | string | N | +| `catalog_name` | The default catalog to place schemas in. Defaults to `AwsDataCatalog` | string | N | + +### SQLMesh connection options + +These options are specific to SQLMesh itself and are not passed to PyAthena + +| Option | Description | Type | Required | +|-------------------------|-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|--------|----------| +| `s3_warehouse_location` | Set the base path in S3 where SQLMesh will place table data. Only required if the schemas dont have default locations set or you arent specifying the location in the model. See [S3 Locations](#s3-locations) below. | string | N | + +## Model properties + +The Athena adapter recognises the following model [physical_properties](../../concepts/models/overview.md#physical_properties): + +| Name | Description | Type | Default | +|-------------------|-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|--------|---------| +| `s3_base_location`| `s3://` base URI of where the snapshot tables for this model should be located. 
Overrides `s3_warehouse_location` if one is configured. | string | |
+| `table_type`      | Sets the [table_type](https://docs.aws.amazon.com/athena/latest/ug/create-table-as.html#ctas-table-properties) Athena uses when creating the table. Valid values are `hive` or `iceberg`. | string | `hive` |
+
+
+## S3 Locations
+When creating tables, Athena needs to know where in S3 the table data is located. You cannot issue a `CREATE TABLE` statement without specifying a `LOCATION` for the table data.
+
+If the schema you're creating the table under had a `LOCATION` set when it was created, Athena places the table in this location. Otherwise, it throws an error.
+
+Therefore, in order for SQLMesh to issue correct `CREATE TABLE` statements to Athena, there are a few strategies you can use to ensure the Athena tables are pointed to the correct S3 locations:
+
+- Manually pre-create the `sqlmesh__<schema>` physical schemas via `CREATE SCHEMA <schema_name> LOCATION 's3://base/location'`. Then when SQLMesh issues `CREATE TABLE` statements for tables within that schema, Athena knows where the data should go
+- Set `s3_warehouse_location` in the connection config. SQLMesh will set the table `LOCATION` to be `<s3_warehouse_location>/<schema_name>/<table_name>` when it issues a `CREATE TABLE` statement
+- Set `s3_base_location` in the model `physical_properties`. SQLMesh will set the table `LOCATION` to be `<s3_base_location>/<table_name>`. This takes precedence over the `s3_warehouse_location` set in the connection config or the `LOCATION` property on the target schema
+
+Note that if you opt to pre-create the schemas with a `LOCATION` already configured, you might want to look at [physical_schema_mapping](../../guides/configuration.md#physical-table-schemas) for better control of the schema names.
+
+## Limitations
+Athena was initially designed to read data stored in S3 and to do so without changing that data. This means that it does not have good support for mutating tables. In particular, it will not delete data from Hive tables.
+
+Consequently, any SQLMesh model type that needs to delete or merge data from existing tables will not work. In addition, [forward only changes](../../concepts/plans.md#forward-only-change) that mutate the schemas of existing tables have a high chance of failure because Athena supports very limited schema modifications on Hive tables.
+
+However, Athena does support [Apache Iceberg](https://docs.aws.amazon.com/athena/latest/ug/querying-iceberg.html) tables, which allow a full range of operations. These can be used for more complex model types such as [`INCREMENTAL_BY_UNIQUE_KEY`](../../concepts/models/model_kinds.md#incremental_by_unique_key) and [`SCD_TYPE_2`](../../concepts/models/model_kinds.md#scd-type-2).
+
+To use an Iceberg table for a model, set `table_type='iceberg'` in the model [physical_properties](../../concepts/models/overview.md#physical_properties).
+
+In general, Iceberg tables offer the most flexibility and you'll run into the least SQLMesh limitations when using them.
+However, they're a newer feature of Athena so you may run into Athena limitations that aren't present in Hive tables, [particularly around supported data types](https://docs.aws.amazon.com/athena/latest/ug/querying-iceberg-supported-data-types.html).
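For illustration only (this sketch is not part of the patch): a minimal model definition that opts into Iceberg via `physical_properties`, loosely mirroring the syntax exercised by the adapter's unit tests later in this patch. The model name, columns, upstream table, and S3 bucket are hypothetical, and `s3_base_location` can be omitted when `s3_warehouse_location` is set in the connection config.

```sql
-- Hypothetical example model; names, columns and the S3 bucket are placeholders.
MODEL (
  name sqlmesh_example.item_latest_state,
  kind INCREMENTAL_BY_UNIQUE_KEY (
    unique_key item_id
  ),
  physical_properties (
    table_type = 'iceberg',                      -- opt in to Iceberg; MERGE/DELETE are not available on Hive tables
    s3_base_location = 's3://my-bucket/sqlmesh/' -- optional: overrides s3_warehouse_location for this model
  )
);

SELECT
  item_id,
  event_date,
  amount
FROM sqlmesh_example.raw_items
```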
\ No newline at end of file diff --git a/docs/integrations/overview.md b/docs/integrations/overview.md index 7200fc6cf..f229137df 100644 --- a/docs/integrations/overview.md +++ b/docs/integrations/overview.md @@ -11,6 +11,7 @@ SQLMesh supports integrations with the following tools: ## Execution engines SQLMesh supports the following execution engines for running SQLMesh projects: +* [Athena](./engines/athena.md) * [BigQuery](./engines/bigquery.md) * [Databricks](./engines/databricks.md) * [DuckDB](./engines/duckdb.md) diff --git a/mkdocs.yml b/mkdocs.yml index 7cf1d64fa..9e032a104 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -66,6 +66,7 @@ nav: - integrations/dbt.md - integrations/github.md - Execution engines: + - integrations/engines/athena.md - integrations/engines/bigquery.md - integrations/engines/clickhouse.md - integrations/engines/databricks.md diff --git a/pytest.ini b/pytest.ini index f1731df7c..e016a6d0f 100644 --- a/pytest.ini +++ b/pytest.ini @@ -20,6 +20,7 @@ markers = spark_pyspark: test for Spark with PySpark dependency # Engine Adapters engine: test all engine adapters + athena: test for Athena bigquery: test for BigQuery clickhouse: test for Clickhouse (standalone mode) clickhouse_cluster: test for Clickhouse (cluster mode) diff --git a/setup.py b/setup.py index 7dfa80ffb..f3fb54f3b 100644 --- a/setup.py +++ b/setup.py @@ -50,6 +50,7 @@ "sqlglot[rs]~=25.22.0", ], extras_require={ + "athena": ["PyAthena[Pandas]"], "bigquery": [ "google-cloud-bigquery[pandas]", "google-cloud-bigquery-storage", diff --git a/sqlmesh/core/config/connection.py b/sqlmesh/core/config/connection.py index bc6b10b9c..62960ae8e 100644 --- a/sqlmesh/core/config/connection.py +++ b/sqlmesh/core/config/connection.py @@ -1467,6 +1467,66 @@ def _static_connection_kwargs(self) -> t.Dict[str, t.Any]: return {"compress": compress, "client_name": f"SQLMesh/{__version__}", **settings} +class AthenaConnectionConfig(ConnectionConfig): + # PyAthena connection options + aws_access_key_id: t.Optional[str] = None + aws_secret_access_key: t.Optional[str] = None + role_arn: t.Optional[str] = None + role_session_name: t.Optional[str] = None + region_name: t.Optional[str] = None + work_group: t.Optional[str] = None + s3_staging_dir: t.Optional[str] = None + schema_name: t.Optional[str] = None + catalog_name: t.Optional[str] = None + + # SQLMesh options + s3_warehouse_location: t.Optional[str] = None + concurrent_tasks: int = 4 + register_comments: bool = False # because Athena doesnt support comments in most cases + pre_ping: Literal[False] = False + + type_: Literal["athena"] = Field(alias="type", default="athena") + + @model_validator(mode="after") + @model_validator_v1_args + def _root_validator(cls, values: t.Dict[str, t.Any]) -> t.Dict[str, t.Any]: + work_group = values.get("work_group") + s3_staging_dir = values.get("s3_staging_dir") + + if not work_group and not s3_staging_dir: + raise ConfigError("At least one of work_group or s3_staging_dir must be set") + + return values + + @property + def _connection_kwargs_keys(self) -> t.Set[str]: + return { + "aws_access_key_id", + "aws_secret_access_key", + "role_arn", + "role_session_name", + "region_name", + "work_group", + "s3_staging_dir", + "schema_name", + "catalog_name", + } + + @property + def _engine_adapter(self) -> t.Type[EngineAdapter]: + return engine_adapter.AthenaEngineAdapter + + @property + def _extra_engine_config(self) -> t.Dict[str, t.Any]: + return {"s3_warehouse_location": self.s3_warehouse_location} + + @property + def _connection_factory(self) -> 
t.Callable: + from pyathena import connect # type: ignore + + return connect + + CONNECTION_CONFIG_TO_TYPE = { # Map all subclasses of ConnectionConfig to the value of their `type_` field. tpe.all_field_infos()["type_"].default: tpe diff --git a/sqlmesh/core/engine_adapter/__init__.py b/sqlmesh/core/engine_adapter/__init__.py index 1d2fe878a..25c45d2e1 100644 --- a/sqlmesh/core/engine_adapter/__init__.py +++ b/sqlmesh/core/engine_adapter/__init__.py @@ -17,6 +17,7 @@ from sqlmesh.core.engine_adapter.snowflake import SnowflakeEngineAdapter from sqlmesh.core.engine_adapter.spark import SparkEngineAdapter from sqlmesh.core.engine_adapter.trino import TrinoEngineAdapter +from sqlmesh.core.engine_adapter.athena import AthenaEngineAdapter DIALECT_TO_ENGINE_ADAPTER = { "hive": SparkEngineAdapter, @@ -31,6 +32,7 @@ "mysql": MySQLEngineAdapter, "mssql": MSSQLEngineAdapter, "trino": TrinoEngineAdapter, + "athena": AthenaEngineAdapter, } DIALECT_ALIASES = { diff --git a/sqlmesh/core/engine_adapter/athena.py b/sqlmesh/core/engine_adapter/athena.py new file mode 100644 index 000000000..3b3a3d96f --- /dev/null +++ b/sqlmesh/core/engine_adapter/athena.py @@ -0,0 +1,441 @@ +from __future__ import annotations +from functools import lru_cache +import typing as t +import logging +from sqlglot import exp +from sqlmesh.core.dialect import to_schema +from sqlmesh.core.engine_adapter.mixins import PandasNativeFetchDFSupportMixin +from sqlmesh.core.engine_adapter.trino import TrinoEngineAdapter +from sqlmesh.core.node import IntervalUnit +import os +from sqlmesh.utils.errors import SQLMeshError +from tenacity import retry, retry_if_exception_type, stop_after_attempt, wait_fixed +from sqlmesh.core.engine_adapter.shared import ( + CatalogSupport, + DataObject, + DataObjectType, + CommentCreationTable, + CommentCreationView, +) + +if t.TYPE_CHECKING: + from sqlmesh.core._typing import SchemaName, TableName + +logger = logging.getLogger(__name__) + + +def _ensure_valid_location(value: str) -> str: + if not value.startswith("s3://"): + raise SQLMeshError(f"Location '{value}' must be a s3:// URI") + + if not value.endswith("/"): + value += "/" + + # To avoid HIVE_METASTORE_ERROR: S3 resource path length must be less than or equal to 700. 
+ if len(value) > 700: + raise SQLMeshError(f"Location '{value}' cannot be more than 700 characters") + + return value + + +# Athena's interaction with the Glue Data Catalog is a bit racey when lots of DDL queries are being fired at it, eg in integration tests +# - a DROP query will fail and then the same query will succeed a few seconds later +# - a CREATE immediately followed by a DESCRIBE will fail in the DESCRIBE but will succeed a few seconds later +def metadata_retry(func: t.Callable) -> t.Callable: + try: + from pyathena.error import OperationalError # type: ignore + + return retry( + retry=retry_if_exception_type(OperationalError), + stop=stop_after_attempt(3), + wait=wait_fixed(5), + reraise=True, + )(func) + except ImportError: + # if pyathena isnt installed, this is a no-op + return func + + +class AthenaEngineAdapter(PandasNativeFetchDFSupportMixin): + DIALECT = "athena" + SUPPORTS_TRANSACTIONS = False + SUPPORTS_REPLACE_TABLE = False + # Athena has the concept of catalogs but no notion of current_catalog or setting the current catalog + CATALOG_SUPPORT = CatalogSupport.UNSUPPORTED + # Athena's support for table and column comments is too patchy to consider "supported" + # Hive tables: Table + Column comments are supported + # Iceberg tables: Column comments only + # CTAS, Views: No comment support at all + COMMENT_CREATION_TABLE = CommentCreationTable.UNSUPPORTED + COMMENT_CREATION_VIEW = CommentCreationView.UNSUPPORTED + SCHEMA_DIFFER = TrinoEngineAdapter.SCHEMA_DIFFER + + def __init__( + self, *args: t.Any, s3_warehouse_location: t.Optional[str] = None, **kwargs: t.Any + ): + # Need to pass s3_warehouse_location to the superclass so that it goes into _extra_config + # which means that EngineAdapter.with_log_level() keeps this property when it makes a clone + super().__init__(*args, s3_warehouse_location=s3_warehouse_location, **kwargs) + self.s3_warehouse_location = s3_warehouse_location + + @property + def s3_warehouse_location(self) -> t.Optional[str]: + return self._s3_warehouse_location + + @s3_warehouse_location.setter + def s3_warehouse_location(self, value: t.Optional[str]) -> None: + if value: + value = _ensure_valid_location(value) + self._s3_warehouse_location = value + + def create_state_table( + self, + table_name: str, + columns_to_types: t.Dict[str, exp.DataType], + primary_key: t.Optional[t.Tuple[str, ...]] = None, + ) -> None: + self.create_table( + table_name, + columns_to_types, + primary_key=primary_key, + table_properties={ + # it's painfully slow, but it works + "table_type": exp.Literal.string("iceberg") + }, + ) + + def _get_data_objects( + self, schema_name: SchemaName, object_names: t.Optional[t.Set[str]] = None + ) -> t.List[DataObject]: + """ + Returns all the data objects that exist in the given schema and optionally catalog. 
+ """ + schema_name = to_schema(schema_name) + schema = schema_name.db + query = ( + exp.select( + exp.case() + .when( + # 'awsdatacatalog' is the default catalog that is invisible for all intents and purposes + # it just happens to show up in information_schema queries + exp.column("table_catalog", table="t").eq("awsdatacatalog"), + exp.Null(), + ) + .else_(exp.column("table_catalog")) + .as_("catalog"), + exp.column("table_schema", table="t").as_("schema"), + exp.column("table_name", table="t").as_("name"), + exp.case() + .when( + exp.column("table_type", table="t").eq("BASE TABLE"), + exp.Literal.string("table"), + ) + .else_(exp.column("table_type", table="t")) + .as_("type"), + ) + .from_(exp.to_table("information_schema.tables", alias="t")) + .where( + exp.and_( + exp.column("table_schema", table="t").eq(schema), + ) + ) + ) + if object_names: + query = query.where(exp.column("table_name", table="t").isin(*object_names)) + + df = self.fetchdf(query) + + return [ + DataObject( + catalog=row.catalog, # type: ignore + schema=row.schema, # type: ignore + name=row.name, # type: ignore + type=DataObjectType.from_str(row.type), # type: ignore + ) + for row in df.itertuples() + ] + + @metadata_retry + def columns( + self, table_name: TableName, include_pseudo_columns: bool = False + ) -> t.Dict[str, exp.DataType]: + table = exp.to_table(table_name) + query = ( + exp.select("column_name", "data_type") + .from_("information_schema.columns") + .where(exp.column("table_schema").eq(table.db), exp.column("table_name").eq(table.name)) + .order_by("ordinal_position") + ) + result = self.fetchdf(query, quote_identifiers=True) + return { + str(r.column_name): exp.DataType.build(str(r.data_type)) + for r in result.itertuples(index=False) + } + + def _create_schema( + self, + schema_name: SchemaName, + ignore_if_exists: bool, + warn_on_error: bool, + properties: t.List[exp.Expression], + kind: str, + ) -> None: + if location := self._table_location(table_properties=None, table=exp.to_table(schema_name)): + # don't add extra LocationProperty's if one already exists + if not any(p for p in properties if isinstance(p, exp.LocationProperty)): + properties.append(location) + + return super()._create_schema( + schema_name=schema_name, + ignore_if_exists=ignore_if_exists, + warn_on_error=warn_on_error, + properties=properties, + kind=kind, + ) + + def _build_create_table_exp( + self, + table_name_or_schema: t.Union[exp.Schema, TableName], + expression: t.Optional[exp.Expression], + exists: bool = True, + replace: bool = False, + columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None, + table_description: t.Optional[str] = None, + table_kind: t.Optional[str] = None, + partitioned_by: t.Optional[t.List[exp.Expression]] = None, + table_properties: t.Optional[t.Dict[str, exp.Expression]] = None, + **kwargs: t.Any, + ) -> exp.Create: + exists = False if replace else exists + + table: exp.Table + if isinstance(table_name_or_schema, str): + table = exp.to_table(table_name_or_schema) + elif isinstance(table_name_or_schema, exp.Schema): + table = table_name_or_schema.this + else: + table = table_name_or_schema + + properties = self._build_table_properties_exp( + table=table, + expression=expression, + columns_to_types=columns_to_types, + partitioned_by=partitioned_by, + table_properties=table_properties, + table_description=table_description, + table_kind=table_kind, + **kwargs, + ) + + is_hive = self._table_type(table_properties) == "hive" + + # Filter any PARTITIONED BY properties from the main column list 
since they cant be specified in both places + # ref: https://docs.aws.amazon.com/athena/latest/ug/partitions.html + if is_hive and partitioned_by and isinstance(table_name_or_schema, exp.Schema): + partitioned_by_column_names = {e.name for e in partitioned_by} + filtered_expressions = [ + e + for e in table_name_or_schema.expressions + if isinstance(e, exp.ColumnDef) and e.this.name not in partitioned_by_column_names + ] + table_name_or_schema.args["expressions"] = filtered_expressions + + return exp.Create( + this=table_name_or_schema, + kind=table_kind or "TABLE", + replace=replace, + exists=exists, + expression=expression, + properties=properties, + ) + + def _build_table_properties_exp( + self, + catalog_name: t.Optional[str] = None, + storage_format: t.Optional[str] = None, + partitioned_by: t.Optional[t.List[exp.Expression]] = None, + partition_interval_unit: t.Optional[IntervalUnit] = None, + clustered_by: t.Optional[t.List[str]] = None, + table_properties: t.Optional[t.Dict[str, exp.Expression]] = None, + columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None, + table_description: t.Optional[str] = None, + table_kind: t.Optional[str] = None, + table: t.Optional[exp.Table] = None, + expression: t.Optional[exp.Expression] = None, + **kwargs: t.Any, + ) -> t.Optional[exp.Properties]: + properties: t.List[exp.Expression] = [] + table_properties = table_properties or {} + + is_hive = self._table_type(table_properties) == "hive" + is_iceberg = not is_hive + + if is_hive and not expression: + # Hive tables are CREATE EXTERNAL TABLE, Iceberg tables are CREATE TABLE + # Unless it's a CTAS, those are always CREATE TABLE + properties.append(exp.ExternalProperty()) + + if table_description: + properties.append(exp.SchemaCommentProperty(this=exp.Literal.string(table_description))) + + if partitioned_by: + schema_expressions: t.List[exp.Expression] = [] + if is_hive and columns_to_types: + # For Hive-style tables, you cannot include the partitioned by columns in the main set of columns + # In the PARTITIONED BY expression, you also cant just include the column names, you need to include the data type as well + # ref: https://docs.aws.amazon.com/athena/latest/ug/partitions.html + for match_name, match_dtype in self._find_matching_columns( + partitioned_by, columns_to_types + ): + column_def = exp.ColumnDef(this=exp.to_identifier(match_name), kind=match_dtype) + schema_expressions.append(column_def) + else: + schema_expressions = partitioned_by + + properties.append( + exp.PartitionedByProperty(this=exp.Schema(expressions=schema_expressions)) + ) + + if clustered_by: + # Athena itself supports CLUSTERED BY, via the syntax CLUSTERED BY (col) INTO BUCKETS + # However, SQLMesh is more closely aligned with BigQuery's notion of clustering and + # defines `clustered_by` as a List[str] with no way of indicating the number of buckets + # + # Athena's concept of CLUSTER BY is more like Iceberg's `bucket(, col)` partition transform + logging.warning("clustered_by is not supported in the Athena adapter at this time") + + if storage_format: + if is_iceberg: + # TBLPROPERTIES('format'='parquet') + table_properties["format"] = exp.Literal.string(storage_format) + else: + # STORED AS PARQUET + properties.append(exp.FileFormatProperty(this=storage_format)) + + if table and (location := self._table_location(table_properties, table)): + properties.append(location) + + for name, value in table_properties.items(): + properties.append(exp.Property(this=exp.var(name), value=value)) + + if is_iceberg and 
expression: + # To make a CTAS expression persist as iceberg, alongside setting `table_type=iceberg` (which the user has already + # supplied in physical_properties and is thus set above), you also need to set: + # - is_external=false + # - table_location='s3://' + # ref: https://docs.aws.amazon.com/athena/latest/ug/create-table-as.html#ctas-table-properties + properties.append(exp.Property(this=exp.var("is_external"), value="false")) + + if properties: + return exp.Properties(expressions=properties) + + return None + + def _truncate_table(self, table_name: TableName) -> None: + table = exp.to_table(table_name) + # Athena doesnt support TRUNCATE TABLE. The closest thing is DELETE FROM but it only works on Iceberg + self.execute(f"DELETE FROM {table.sql(dialect=self.dialect, identify=True)}") + + def _table_type( + self, table_properties: t.Optional[t.Dict[str, exp.Expression]] = None + ) -> t.Union[t.Literal["hive"], t.Literal["iceberg"]]: + """ + Use the user-specified table_properties to figure out of this is a Hive or an Iceberg table + """ + # if table_type is not defined or is not set to "iceberg", this is a Hive table + if table_properties and (table_type := table_properties.get("table_type", None)): + if "iceberg" in table_type.sql(dialect=self.dialect).lower(): + return "iceberg" + return "hive" + + @lru_cache() + def _query_table_type( + self, table_name: TableName + ) -> t.Union[t.Literal["hive"], t.Literal["iceberg"]]: + """ + Hit the DB to check if this is a Hive or an Iceberg table + """ + table_name = exp.to_table(table_name) + # Note: SHOW TBLPROPERTIES gets parsed by SQLGlot as an exp.Command anyway so we just use a string here + # This also means we need to use dialect="hive" instead of dialect="athena" so that the identifiers get the correct quoting (backticks) + for row in self.fetchall(f"SHOW TBLPROPERTIES {table_name.sql(dialect='hive')}"): + # This query returns a single column with values like 'EXTERNAL\tTRUE' + row_lower = row[0].lower() + if "external" in row_lower and "true" in row_lower: + return "hive" + return "iceberg" + + def _table_location( + self, + table_properties: t.Optional[t.Dict[str, exp.Expression]], + table: exp.Table, + ) -> t.Optional[exp.LocationProperty]: + base_uri: str + + # If the user has manually specified a `s3_base_location`, use it + if table_properties and "s3_base_location" in table_properties: + s3_base_location_property = table_properties.pop( + "s3_base_location" + ) # pop because it's handled differently and we dont want it to end up in the TBLPROPERTIES clause + if isinstance(s3_base_location_property, exp.Expression): + base_uri = s3_base_location_property.name + else: + base_uri = s3_base_location_property + + elif self.s3_warehouse_location: + # If the user has set `s3_warehouse_location` in the connection config, the base URI is /// + catalog_name = table.catalog if hasattr(table, "catalog") else None + schema_name = table.db if hasattr(table, "db") else None + base_uri = os.path.join( + self.s3_warehouse_location, catalog_name or "", schema_name or "" + ) + else: + # Assume the user has set a default location for this schema in the metastore + return None + + table_name = table.name if hasattr(table, "name") else None + full_uri = _ensure_valid_location(os.path.join(base_uri, table_name or "")) + + return exp.LocationProperty(this=exp.Literal.string(full_uri)) + + def _find_matching_columns( + self, partitioned_by: t.List[exp.Expression], columns_to_types: t.Dict[str, exp.DataType] + ) -> t.List[t.Tuple[str, exp.DataType]]: 
+ matches = [] + for col in partitioned_by: + # TODO: do we care about normalization? + key = col.name + if isinstance(col, exp.Column) and (match_dtype := columns_to_types.get(key)): + matches.append((key, match_dtype)) + return matches + + def delete_from(self, table_name: TableName, where: t.Union[str, exp.Expression]) -> None: + table_type = self._query_table_type(table_name) + + # If Iceberg, DELETE operations work as expected + if table_type == "iceberg": + return super().delete_from(table_name, where) + + # If Hive, DELETE is an error + if table_type == "hive": + # However, if the table is empty, we can make DELETE a no-op + # This simplifies a bunch of calling code that just assumes DELETE works (which to be fair is a reasonable assumption since it does for every other engine) + empty_check = ( + exp.select("*").from_(table_name).limit(1) + ) # deliberately not count(*) because we want the engine to stop as soon as it finds a record + if len(self.fetchall(empty_check)) > 0: + # TODO: in future, if SQLMesh adds support for explicit partition management, we may + # be able to covert the DELETE query into an ALTER TABLE DROP PARTITION assuming the WHERE clause fully covers the partition bounds + raise SQLMeshError("Cannot delete from non-empty Hive table") + + return None + + @metadata_retry + def _drop_object( + self, + name: TableName | SchemaName, + exists: bool = True, + kind: str = "TABLE", + **drop_args: t.Any, + ) -> None: + return super()._drop_object(name, exists=exists, kind=kind, **drop_args) diff --git a/sqlmesh/core/engine_adapter/base.py b/sqlmesh/core/engine_adapter/base.py index 12d7cabbe..cf2e008a5 100644 --- a/sqlmesh/core/engine_adapter/base.py +++ b/sqlmesh/core/engine_adapter/base.py @@ -1034,8 +1034,9 @@ def create_schema( schema_name: SchemaName, ignore_if_exists: bool = True, warn_on_error: bool = True, - properties: t.List[exp.Expression] = [], + properties: t.Optional[t.List[exp.Expression]] = None, ) -> None: + properties = properties or [] return self._create_schema( schema_name=schema_name, ignore_if_exists=ignore_if_exists, @@ -1325,6 +1326,7 @@ def scd_type_2_by_time( table_description=table_description, column_descriptions=column_descriptions, truncate=truncate, + **kwargs, ) def scd_type_2_by_column( @@ -1358,6 +1360,7 @@ def scd_type_2_by_column( table_description=table_description, column_descriptions=column_descriptions, truncate=truncate, + **kwargs, ) def _scd_type_2( @@ -1377,6 +1380,7 @@ def _scd_type_2( table_description: t.Optional[str] = None, column_descriptions: t.Optional[t.Dict[str, str]] = None, truncate: bool = False, + **kwargs: t.Any, ) -> None: def remove_managed_columns( cols_to_types: t.Dict[str, exp.DataType], @@ -1730,6 +1734,7 @@ def remove_managed_columns( columns_to_types=columns_to_types, table_description=table_description, column_descriptions=column_descriptions, + **kwargs, ) def merge( diff --git a/sqlmesh/core/engine_adapter/trino.py b/sqlmesh/core/engine_adapter/trino.py index 822465283..61fe20c35 100644 --- a/sqlmesh/core/engine_adapter/trino.py +++ b/sqlmesh/core/engine_adapter/trino.py @@ -223,6 +223,7 @@ def _scd_type_2( table_description: t.Optional[str] = None, column_descriptions: t.Optional[t.Dict[str, str]] = None, truncate: bool = False, + **kwargs: t.Any, ) -> None: if columns_to_types and self.current_catalog_type == "delta_lake": columns_to_types = self._to_delta_ts(columns_to_types) @@ -243,6 +244,7 @@ def _scd_type_2( table_description, column_descriptions, truncate, + **kwargs, ) # delta_lake only 
supports two timestamp data types. This method converts other diff --git a/tests/core/engine_adapter/config.yaml b/tests/core/engine_adapter/config.yaml index 30d4f5bf8..bab81e2ce 100644 --- a/tests/core/engine_adapter/config.yaml +++ b/tests/core/engine_adapter/config.yaml @@ -153,5 +153,17 @@ gateways: state_connection: type: duckdb + inttest_athena: + connection: + type: athena + aws_access_key_id: {{ env_var("AWS_ACCESS_KEY_ID") }} + aws_secret_access_key: {{ env_var("AWS_SECRET_ACCESS_KEY") }} + region_name: {{ env_var("AWS_REGION") }} + work_group: {{ env_var("ATHENA_WORK_GROUP", "primary") }} + s3_warehouse_location: {{ env_var("ATHENA_S3_WAREHOUSE_LOCATION") }} + concurrent_tasks: 1 + state_connection: + type: duckdb + model_defaults: dialect: duckdb diff --git a/tests/core/engine_adapter/test_athena.py b/tests/core/engine_adapter/test_athena.py new file mode 100644 index 000000000..ed9785158 --- /dev/null +++ b/tests/core/engine_adapter/test_athena.py @@ -0,0 +1,295 @@ +import typing as t +import pytest +from pytest_mock import MockerFixture +import pandas as pd + +from sqlglot import exp, parse_one +import sqlmesh.core.dialect as d +from sqlmesh.core.engine_adapter import AthenaEngineAdapter +from sqlmesh.core.model import load_sql_based_model +from sqlmesh.core.model.definition import SqlModel + +from tests.core.engine_adapter import to_sql_calls + +pytestmark = [pytest.mark.athena, pytest.mark.engine] + + +@pytest.fixture +def adapter(make_mocked_engine_adapter: t.Callable) -> AthenaEngineAdapter: + return make_mocked_engine_adapter(AthenaEngineAdapter) + + +@pytest.mark.parametrize( + "config_s3_warehouse_location,table_properties,table,expected_location", + [ + # No s3_warehouse_location in config + (None, None, exp.to_table("schema.table"), None), + (None, {}, exp.to_table("schema.table"), None), + ( + None, + {"s3_base_location": exp.Literal.string("s3://some/location/")}, + exp.to_table("schema.table"), + "s3://some/location/table/", + ), + (None, None, exp.Table(db=exp.Identifier(this="test")), None), + # Location set to bucket + ("s3://bucket", None, exp.to_table("schema.table"), "s3://bucket/schema/table/"), + ("s3://bucket", {}, exp.to_table("schema.table"), "s3://bucket/schema/table/"), + ("s3://bucket", None, exp.to_table("schema.table"), "s3://bucket/schema/table/"), + ( + "s3://bucket", + {"s3_base_location": exp.Literal.string("s3://some/location/")}, + exp.to_table("schema.table"), + "s3://some/location/table/", + ), + ("s3://bucket", {}, exp.Table(db=exp.Identifier(this="test")), "s3://bucket/test/"), + # Location set to bucket with prefix + ( + "s3://bucket/subpath/", + None, + exp.to_table("schema.table"), + "s3://bucket/subpath/schema/table/", + ), + ("s3://bucket/subpath/", None, exp.to_table("table"), "s3://bucket/subpath/table/"), + ( + "s3://bucket/subpath/", + None, + exp.to_table("catalog.schema.table"), + "s3://bucket/subpath/catalog/schema/table/", + ), + ( + "s3://bucket/subpath/", + None, + exp.Table(db=exp.Identifier(this="test")), + "s3://bucket/subpath/test/", + ), + ], +) +def test_table_location( + adapter: AthenaEngineAdapter, + config_s3_warehouse_location: t.Optional[str], + table_properties: t.Optional[t.Dict[str, exp.Expression]], + table: exp.Table, + expected_location: t.Optional[str], +) -> None: + adapter.s3_warehouse_location = config_s3_warehouse_location + location = adapter._table_location(table_properties, table) + final_location = None + + if location and expected_location: + final_location = ( + location.this.name + ) # extract 
the unquoted location value from the LocationProperty + + assert final_location == expected_location + + if table_properties is not None: + assert "location" not in table_properties + + +def test_create_schema(adapter: AthenaEngineAdapter) -> None: + adapter.create_schema("test") + + adapter.s3_warehouse_location = "s3://base" + adapter.create_schema("test") + + assert to_sql_calls(adapter) == [ + "CREATE SCHEMA IF NOT EXISTS `test`", + "CREATE SCHEMA IF NOT EXISTS `test` LOCATION 's3://base/test/'", + ] + + +def test_create_table_hive(adapter: AthenaEngineAdapter) -> None: + expressions = d.parse( + """ + MODEL ( + name test_table, + kind FULL, + partitioned_by (cola, colb), + storage_format parquet, + physical_properties ( + s3_base_location = 's3://foo', + has_encrypted_data = 'true' + ) + ); + + SELECT 1::timestamp AS cola, 2::varchar as colb, 'foo' as colc; + """ + ) + model: SqlModel = t.cast(SqlModel, load_sql_based_model(expressions)) + + adapter.create_table( + model.name, + columns_to_types=model.columns_to_types_or_raise, + table_properties=model.physical_properties, + partitioned_by=model.partitioned_by, + storage_format=model.storage_format, + ) + + assert to_sql_calls(adapter) == [ + "CREATE EXTERNAL TABLE IF NOT EXISTS `test_table` (`colc` STRING) PARTITIONED BY (`cola` TIMESTAMP, `colb` STRING) STORED AS PARQUET LOCATION 's3://foo/test_table/' TBLPROPERTIES ('has_encrypted_data'='true')" + ] + + +def test_create_table_iceberg(adapter: AthenaEngineAdapter) -> None: + expressions = d.parse( + """ + MODEL ( + name test_table, + kind FULL, + partitioned_by (colc, bucket(16, cola)), + storage_format parquet, + physical_properties ( + table_type = 'iceberg', + s3_base_location = 's3://foo' + ) + ); + + SELECT 1::timestamp AS cola, 2::varchar as colb, 'foo' as colc; + """ + ) + model: SqlModel = t.cast(SqlModel, load_sql_based_model(expressions)) + + adapter.create_table( + model.name, + columns_to_types=model.columns_to_types_or_raise, + table_properties=model.physical_properties, + partitioned_by=model.partitioned_by, + storage_format=model.storage_format, + ) + + assert to_sql_calls(adapter) == [ + "CREATE TABLE IF NOT EXISTS `test_table` (`cola` TIMESTAMP, `colb` STRING, `colc` STRING) PARTITIONED BY (`colc`, BUCKET(16, `cola`)) LOCATION 's3://foo/test_table/' TBLPROPERTIES ('table_type'='iceberg', 'format'='parquet')" + ] + + +def test_create_table_inferred_location(adapter: AthenaEngineAdapter) -> None: + expressions = d.parse( + """ + MODEL ( + name test_table, + kind FULL + ); + + SELECT a::int FROM foo; + """ + ) + model: SqlModel = t.cast(SqlModel, load_sql_based_model(expressions)) + + adapter.create_table( + model.name, + columns_to_types=model.columns_to_types_or_raise, + table_properties=model.physical_properties, + ) + + adapter.s3_warehouse_location = "s3://bucket/prefix" + adapter.create_table( + model.name, + columns_to_types=model.columns_to_types_or_raise, + table_properties=model.physical_properties, + ) + + assert to_sql_calls(adapter) == [ + "CREATE EXTERNAL TABLE IF NOT EXISTS `test_table` (`a` INT)", + "CREATE EXTERNAL TABLE IF NOT EXISTS `test_table` (`a` INT) LOCATION 's3://bucket/prefix/test_table/'", + ] + + +def test_ctas_hive(adapter: AthenaEngineAdapter): + adapter.s3_warehouse_location = "s3://bucket/prefix/" + + adapter.ctas( + table_name="foo.bar", + columns_to_types={"a": exp.DataType.build("int")}, + query_or_df=parse_one("select 1", into=exp.Select), + ) + + assert to_sql_calls(adapter) == [ + 'CREATE TABLE IF NOT EXISTS "foo"."bar" WITH 
(external_location=\'s3://bucket/prefix/foo/bar/\') AS SELECT CAST("a" AS INTEGER) AS "a" FROM (SELECT 1) AS "_subquery"' + ] + + +def test_ctas_iceberg(adapter: AthenaEngineAdapter): + adapter.s3_warehouse_location = "s3://bucket/prefix/" + + adapter.ctas( + table_name="foo.bar", + columns_to_types={"a": exp.DataType.build("int")}, + query_or_df=parse_one("select 1", into=exp.Select), + table_properties={"table_type": exp.Literal.string("iceberg")}, + ) + + assert to_sql_calls(adapter) == [ + 'CREATE TABLE IF NOT EXISTS "foo"."bar" WITH (location=\'s3://bucket/prefix/foo/bar/\', table_type=\'iceberg\', is_external=false) AS SELECT CAST("a" AS INTEGER) AS "a" FROM (SELECT 1) AS "_subquery"' + ] + + +def test_replace_query(adapter: AthenaEngineAdapter, mocker: MockerFixture): + mocker.patch( + "sqlmesh.core.engine_adapter.athena.AthenaEngineAdapter.table_exists", return_value=True + ) + mocker.patch( + "sqlmesh.core.engine_adapter.athena.AthenaEngineAdapter._query_table_type", + return_value="iceberg", + ) + + adapter.replace_query( + table_name="test", + query_or_df=parse_one("select 1 as a", into=exp.Select), + columns_to_types={"a": exp.DataType.build("int")}, + table_properties={}, + ) + + assert to_sql_calls(adapter) == [ + 'DELETE FROM "test" WHERE TRUE', + 'INSERT INTO "test" ("a") SELECT 1 AS "a"', + ] + + mocker.patch( + "sqlmesh.core.engine_adapter.athena.AthenaEngineAdapter.table_exists", return_value=False + ) + adapter.cursor.execute.reset_mock() + + adapter.replace_query( + table_name="test", + query_or_df=parse_one("select 1 as a", into=exp.Select), + columns_to_types={"a": exp.DataType.build("int")}, + table_properties={}, + ) + + assert to_sql_calls(adapter) == [ + 'CREATE TABLE IF NOT EXISTS "test" AS SELECT CAST("a" AS INTEGER) AS "a" FROM (SELECT 1 AS "a") AS "_subquery"' + ] + + +def test_columns(adapter: AthenaEngineAdapter, mocker: MockerFixture): + mock = mocker.patch( + "pandas.io.sql.read_sql_query", + return_value=pd.DataFrame( + data=[["col1", "int"], ["col2", "varchar"]], columns=["column_name", "data_type"] + ), + ) + + assert adapter.columns("foo.bar") == { + "col1": exp.DataType.build("int"), + "col2": exp.DataType.build("varchar"), + } + + assert ( + mock.call_args_list[0][0][0] + == """SELECT "column_name", "data_type" FROM "information_schema"."columns" WHERE "table_schema" = 'foo' AND "table_name" = 'bar' ORDER BY "ordinal_position" NULLS FIRST""" + ) + + +def test_truncate_table(adapter: AthenaEngineAdapter): + adapter._truncate_table(exp.to_table("foo.bar")) + + assert to_sql_calls(adapter) == ['DELETE FROM "foo"."bar"'] + + +def test_create_state_table(adapter: AthenaEngineAdapter): + adapter.create_state_table("_snapshots", {"name": exp.DataType.build("varchar")}) + + assert to_sql_calls(adapter) == [ + "CREATE TABLE IF NOT EXISTS `_snapshots` (`name` STRING) TBLPROPERTIES ('table_type'='iceberg')" + ] diff --git a/tests/core/engine_adapter/test_integration.py b/tests/core/engine_adapter/test_integration.py index 5d0cce43e..9d6cff9cd 100644 --- a/tests/core/engine_adapter/test_integration.py +++ b/tests/core/engine_adapter/test_integration.py @@ -16,6 +16,7 @@ from sqlmesh import Config, Context, EngineAdapter from sqlmesh.cli.example_project import init_example_project from sqlmesh.core.config import load_config_from_paths +from sqlmesh.core.config.connection import AthenaConnectionConfig from sqlmesh.core.dialect import normalize_model_name import sqlmesh.core.dialect as d from sqlmesh.core.engine_adapter import SparkEngineAdapter, 
TrinoEngineAdapter @@ -180,6 +181,13 @@ def table(self, table_name: str, schema: str = TEST_SCHEMA) -> exp.Table: ) ) + def physical_properties( + self, properties_for_dialect: t.Dict[str, t.Dict[str, str | exp.Expression]] + ) -> t.Dict[str, exp.Expression]: + if props := properties_for_dialect.get(self.dialect): + return {k: exp.Literal.string(v) if isinstance(v, str) else v for k, v in props.items()} + return {} + def schema(self, schema_name: str, catalog_name: t.Optional[str] = None) -> str: return exp.table_name( normalize_model_name( @@ -687,6 +695,15 @@ def config() -> Config: pytest.mark.xdist_group("engine_integration_clickhouse_cloud"), ], ), + pytest.param( + "athena", + marks=[ + pytest.mark.engine, + pytest.mark.remote, + pytest.mark.athena, + pytest.mark.xdist_group("engine_integration_athena"), + ], + ), ] ) def mark_gateway(request) -> t.Tuple[str, str]: @@ -694,12 +711,28 @@ def mark_gateway(request) -> t.Tuple[str, str]: @pytest.fixture -def engine_adapter(mark_gateway: t.Tuple[str, str], config) -> EngineAdapter: +def engine_adapter(mark_gateway: t.Tuple[str, str], config, testrun_uid) -> EngineAdapter: mark, gateway = mark_gateway if gateway not in config.gateways: # TODO: Once everything is fully setup we want to error if a gateway is not configured that we expect pytest.skip(f"Gateway {gateway} not configured") connection_config = config.gateways[gateway].connection + + if mark == "athena": + connection_config = t.cast(AthenaConnectionConfig, connection_config) + # S3 files need to go into a unique location for each test run + # This is because DROP TABLE on a Hive table just drops the table from the metastore + # The files still exist in S3, so if you CREATE TABLE to the same location, the old data shows back up + # This is a problem for any tests like `test_init_project` that use a consistent schema like `sqlmesh_example` between runs + # Note that the `testrun_uid` fixture comes from the xdist plugin + testrun_path = f"testrun_{testrun_uid}" + if current_location := connection_config.s3_warehouse_location: + if testrun_path not in current_location: + # only add it if its not already there (since this setup code gets called multiple times in a full test run) + connection_config.s3_warehouse_location = os.path.join( + current_location, testrun_path + ) + engine_adapter = connection_config.create_engine_adapter() # Trino: If we batch up the requests then when running locally we get a table not found error after creating the # table and then immediately after trying to insert rows into it. 
There seems to be a delay between when the @@ -1276,7 +1309,16 @@ def test_merge(ctx: TestContext): ctx.init() table = ctx.table("test_table") - ctx.engine_adapter.create_table(table, ctx.columns_to_types) + + table_properties = ctx.physical_properties( + { + # Athena only supports MERGE on Iceberg tables + # And it cant fall back to a logical merge on Hive tables because it cant delete records + "athena": {"table_type": "iceberg"} + } + ) + + ctx.engine_adapter.create_table(table, ctx.columns_to_types, table_properties=table_properties) input_data = pd.DataFrame( [ {"id": 1, "ds": "2022-01-01"}, @@ -1346,7 +1388,13 @@ def test_scd_type_2_by_time(ctx: TestContext): input_schema = { k: v for k, v in ctx.columns_to_types.items() if k not in ("valid_from", "valid_to") } - ctx.engine_adapter.create_table(table, ctx.columns_to_types) + table_properties = ctx.physical_properties( + { + # Athena only supports the operations required for SCD models on Iceberg tables + "athena": {"table_type": "iceberg"} + } + ) + ctx.engine_adapter.create_table(table, ctx.columns_to_types, table_properties=table_properties) input_data = pd.DataFrame( [ {"id": 1, "name": "a", "updated_at": "2022-01-01 00:00:00"}, @@ -1364,6 +1412,7 @@ def test_scd_type_2_by_time(ctx: TestContext): execution_time="2023-01-01 00:00:00", updated_at_as_valid_from=False, columns_to_types=input_schema, + table_properties=table_properties, ) results = ctx.get_metadata_results() assert len(results.views) == 0 @@ -1424,6 +1473,7 @@ def test_scd_type_2_by_time(ctx: TestContext): execution_time="2023-01-05 00:00:00", updated_at_as_valid_from=False, columns_to_types=input_schema, + table_properties=table_properties, ) results = ctx.get_metadata_results() assert len(results.views) == 0 @@ -1489,7 +1539,13 @@ def test_scd_type_2_by_column(ctx: TestContext): input_schema = { k: v for k, v in ctx.columns_to_types.items() if k not in ("valid_from", "valid_to") } - ctx.engine_adapter.create_table(table, ctx.columns_to_types) + table_properties = ctx.physical_properties( + { + # Athena only supports the operations required for SCD models on Iceberg tables + "athena": {"table_type": "iceberg"} + } + ) + ctx.engine_adapter.create_table(table, ctx.columns_to_types, table_properties=table_properties) input_data = pd.DataFrame( [ {"id": 1, "name": "a", "status": "active"}, @@ -1721,7 +1777,15 @@ def test_truncate_table(ctx: TestContext): ctx.init() table = ctx.table("test_table") - ctx.engine_adapter.create_table(table, ctx.columns_to_types) + + table_properties = ctx.physical_properties( + { + # Athena only supports TRUNCATE (DELETE FROM
) on Iceberg tables + "athena": {"table_type": "iceberg"} + } + ) + + ctx.engine_adapter.create_table(table, ctx.columns_to_types, table_properties=table_properties) input_data = pd.DataFrame( [ {"id": 1, "ds": "2022-01-01"}, @@ -1773,7 +1837,13 @@ def test_sushi(mark_gateway: t.Tuple[str, str], ctx: TestContext): ], personal_paths=[pathlib.Path("~/.sqlmesh/config.yaml").expanduser()], ) - _, gateway = mark_gateway + mark, gateway = mark_gateway + + if mark == "athena": + # Ensure that this test is using the same s3_warehouse_location as TestContext (which includes the testrun_id) + config.gateways[ + gateway + ].connection.s3_warehouse_location = ctx.engine_adapter.s3_warehouse_location # clear cache from prior runs cache_dir = pathlib.Path("./examples/sushi/.cache") @@ -1854,6 +1924,13 @@ def test_sushi(mark_gateway: t.Tuple[str, str], ctx: TestContext): "CREATE VIEW raw.demographics ON CLUSTER cluster1 AS SELECT 1 AS customer_id, '00000' AS zip;" ) + # Athena needs models that get mutated after creation to be using Iceberg + if ctx.dialect == "athena": + for model_name in {"sushi.customer_revenue_lifetime"}: + context.get_model(model_name).physical_properties["table_type"] = exp.Literal.string( + "iceberg" + ) + plan: Plan = context.plan( environment="test_prod", start=start, @@ -2136,7 +2213,14 @@ def _normalize_snowflake(name: str, prefix_regex: str = "(sqlmesh__)(.*)"): if config.model_defaults.dialect != ctx.dialect: config.model_defaults = config.model_defaults.copy(update={"dialect": ctx.dialect}) - _, gateway = mark_gateway + mark, gateway = mark_gateway + + if mark == "athena": + # Ensure that this test is using the same s3_warehouse_location as TestContext (which includes the testrun_id) + config.gateways[ + gateway + ].connection.s3_warehouse_location = ctx.engine_adapter.s3_warehouse_location + context = Context(paths=tmp_path, config=config, gateway=gateway) ctx.engine_adapter = context.engine_adapter @@ -2368,6 +2452,12 @@ def _mutate_config(current_gateway_name: str, config: Config): context.upsert_model( create_sql_model(name=f"{schema}.seed_model", query=seed_query, kind="FULL") ) + + physical_properties = "" + if ctx.dialect == "athena": + # INCREMENTAL_BY_UNIQUE_KEY uses MERGE which is only supported in Athena on Iceberg tables + physical_properties = "physical_properties (table_type = 'iceberg')," + context.upsert_model( load_sql_based_model( d.parse( @@ -2377,6 +2467,7 @@ def _mutate_config(current_gateway_name: str, config: Config): unique_key item_id, batch_size 1 ), + {physical_properties} start '2020-01-01', end '2020-01-07', cron '@daily' From 957cb97cbd21792497e4ffabb3ff95e4a4437dfd Mon Sep 17 00:00:00 2001 From: Erin Drummond Date: Fri, 20 Sep 2024 04:04:28 +0000 Subject: [PATCH 2/2] PR feedback --- .gitignore | 1 + Makefile | 2 +- docs/integrations/engines/athena.md | 23 ++++---- sqlmesh/core/engine_adapter/athena.py | 48 +++++++--------- tests/core/engine_adapter/test_athena.py | 57 ++++++++++++------- tests/core/engine_adapter/test_integration.py | 21 ++++++- 6 files changed, 89 insertions(+), 63 deletions(-) diff --git a/.gitignore b/.gitignore index 64bef2bed..f4b62e829 100644 --- a/.gitignore +++ b/.gitignore @@ -108,6 +108,7 @@ venv/ ENV/ env.bak/ venv.bak/ +venv*/ # Spyder project settings .spyderproject diff --git a/Makefile b/Makefile index 21c921397..a62080604 100644 --- a/Makefile +++ b/Makefile @@ -211,4 +211,4 @@ clickhouse-cloud-test: guard-CLICKHOUSE_CLOUD_HOST guard-CLICKHOUSE_CLOUD_USERNA pytest -n auto -x -m "clickhouse_cloud" --retries 3 
--junitxml=test-results/junit-clickhouse-cloud.xml athena-test: guard-AWS_ACCESS_KEY_ID guard-AWS_SECRET_ACCESS_KEY guard-ATHENA_S3_WAREHOUSE_LOCATION engine-athena-install - pytest -n auto -x -m "athena" --retries 3 --junitxml=test-results/junit-athena.xml \ No newline at end of file + pytest -n auto -x -m "athena" --retries 3 --retry-delay 10 --junitxml=test-results/junit-athena.xml \ No newline at end of file diff --git a/docs/integrations/engines/athena.md b/docs/integrations/engines/athena.md index f3efe6137..fbf4fef77 100644 --- a/docs/integrations/engines/athena.md +++ b/docs/integrations/engines/athena.md @@ -30,9 +30,9 @@ Note that PyAthena uses [boto3](https://boto3.amazonaws.com/v1/documentation/api These options are specific to SQLMesh itself and are not passed to PyAthena -| Option | Description | Type | Required | -|-------------------------|-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|--------|----------| -| `s3_warehouse_location` | Set the base path in S3 where SQLMesh will place table data. Only required if the schemas dont have default locations set or you arent specifying the location in the model. See [S3 Locations](#s3-locations) below. | string | N | +| Option | Description | Type | Required | +|-------------------------|--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|--------|----------| +| `s3_warehouse_location` | Set the base path in S3 where SQLMesh will instruct Athena to place table data. Only required if you arent specifying the location in the model itself. See [S3 Locations](#s3-locations) below. | string | N | ## Model properties @@ -40,31 +40,28 @@ The Athena adapter recognises the following model [physical_properties](../../co | Name | Description | Type | Default | |-------------------|-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|--------|---------| -| `s3_base_location`| `s3://` base URI of where the snapshot tables for this model should be located. Overrides `s3_warehouse_location` if one is configured. | string | | +| `s3_base_location`| `s3://` base URI of where the snapshot tables for this model should be written. Overrides `s3_warehouse_location` if one is configured. | string | | | `table_type` | Sets the [table_type](https://docs.aws.amazon.com/athena/latest/ug/create-table-as.html#ctas-table-properties) Athena uses when creating the table. Valid values are `hive` or `iceberg`. | string | `hive` | ## S3 Locations When creating tables, Athena needs to know where in S3 the table data is located. You cannot issue a `CREATE TABLE` statement without specifying a `LOCATION` for the table data. -If the schema you're creating the table under had a `LOCATION` set when it was created, Athena places the table in this location. Otherwise, it throws an error. +In addition, unlike other engines such as Trino, Athena will not infer a table location if you set a _schema_ location via `CREATE SCHEMA LOCATION 's3://schema/location'`. 
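As a sketch of the per-model option described just below (illustration only, not part of the patch; the model name, query, and bucket are hypothetical), a model whose physical table data is written under an explicit S3 prefix. Based on the adapter behaviour in this patch, SQLMesh appends the physical (snapshot) table name to this base URI when it issues the `CREATE TABLE`.

```sql
-- Hypothetical example model; the resulting table LOCATION would be 's3://my-bucket/sqlmesh-tables/<physical_table_name>/'
MODEL (
  name analytics.daily_totals,
  kind FULL,
  physical_properties (
    s3_base_location = 's3://my-bucket/sqlmesh-tables/'
  )
);

SELECT
  event_date,
  COUNT(*) AS total_events
FROM analytics.events
GROUP BY event_date
```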
-Therefore, in order for SQLMesh to issue correct `CREATE TABLE` statements to Athena, there are a few strategies you can use to ensure the Athena tables are pointed to the correct S3 locations: +Therefore, in order for SQLMesh to issue correct `CREATE TABLE` statements to Athena, you need to configure where the tables should be stored. There are two options for this: -- Manually pre-create the `sqlmesh__` physical schemas via `CREATE SCHEMA LOCATION 's3://base/location'`. Then when SQLMesh issues `CREATE TABLE` statements for tables within that schema, Athena knows where the data should go -- Set `s3_warehouse_location` in the connection config. SQLMesh will set the table `LOCATION` to be `//` when it issues a `CREATE TABLE` statement -- Set `s3_base_location` in the model `physical_properties`. SQLMesh will set the table `LOCATION` to be `/`. This takes precedence over the `s3_warehouse_location` set in the connection config or the `LOCATION` property on the target schema +- **Project-wide:** set `s3_warehouse_location` in the connection config. SQLMesh will set the table `LOCATION` to be `//` when it creates a snapshot of your model. +- **Per-model:** set `s3_base_location` in the model `physical_properties`. SQLMesh will set the table `LOCATION` to be `/` every time it creates a snapshot of your model. This takes precedence over any `s3_warehouse_location` set in the connection config. -Note that if you opt to pre-create the schemas with a `LOCATION` already configured, you might want to look at [physical_schema_mapping](../../guides/configuration.md#physical-table-schemas) for better control of the schema names. ## Limitations Athena was initially designed to read data stored in S3 and to do so without changing that data. This means that it does not have good support for mutating tables. In particular, it will not delete data from Hive tables. -Consequently, any SQLMesh model types that needs to delete or merge data from existing tables will not work. In addition, [forward only changes](../../concepts/plans.md#forward-only-change) that mutate the schemas of existing tables have a high chance of failure because Athena supports very limited schema modifications on Hive tables. +Consequently, any SQLMesh model types that needs to delete or merge data from existing tables will not work on Hive tables. In addition, [forward only changes](../../concepts/plans.md#forward-only-change) that mutate the schemas of existing tables have a high chance of failure because Athena supports very limited schema modifications on Hive tables. However, Athena does support [Apache Iceberg](https://docs.aws.amazon.com/athena/latest/ug/querying-iceberg.html) tables which allow a full range of operations. These can be used for more complex model types such as [`INCREMENTAL_BY_UNIQUE_KEY`](../../concepts/models/model_kinds.md#incremental_by_unique_key) and [`SCD_TYPE_2`](../../concepts/models/model_kinds.md#scd-type-2). To use an Iceberg table for a model, set `table_type='iceberg'` in the model [physical_properties](../../concepts/models/overview.md#physical_properties). -In general, Iceberg tables offer the most flexibility and you'll run into the least SQLMesh limitations when using them. -However, they're a newer feature of Athena so you may run into Athena limitations that arent present in Hive tables, [particularly around supported data types](https://docs.aws.amazon.com/athena/latest/ug/querying-iceberg-supported-data-types.html). 
\ No newline at end of file
+In general, Iceberg tables offer the most flexibility and you'll run into the fewest SQLMesh limitations when using them. However, we create Hive tables by default because Athena creates Hive tables by default, so Iceberg tables are opt-in rather than opt-out.
\ No newline at end of file
diff --git a/sqlmesh/core/engine_adapter/athena.py b/sqlmesh/core/engine_adapter/athena.py
index 3b3a3d96f..359631da1 100644
--- a/sqlmesh/core/engine_adapter/athena.py
+++ b/sqlmesh/core/engine_adapter/athena.py
@@ -116,8 +116,7 @@ def _get_data_objects(
             exp.select(
                 exp.case()
                 .when(
-                    # 'awsdatacatalog' is the default catalog that is invisible for all intents and purposes
-                    # it just happens to show up in information_schema queries
+                    # calling code expects data objects in the default catalog to have their catalog set to None
                     exp.column("table_catalog", table="t").eq("awsdatacatalog"),
                     exp.Null(),
                 )
@@ -134,11 +133,7 @@
                 .as_("type"),
             )
             .from_(exp.to_table("information_schema.tables", alias="t"))
-            .where(
-                exp.and_(
-                    exp.column("table_schema", table="t").eq(schema),
-                )
-            )
+            .where(exp.column("table_schema", table="t").eq(schema))
         )
         if object_names:
             query = query.where(exp.column("table_name", table="t").isin(*object_names))
@@ -312,20 +307,18 @@ def _build_table_properties_exp(
            # STORED AS PARQUET
            properties.append(exp.FileFormatProperty(this=storage_format))
 
-        if table and (location := self._table_location(table_properties, table)):
+        if table and (location := self._table_location_or_raise(table_properties, table)):
             properties.append(location)
 
+        if is_iceberg and expression:
+            # To make a CTAS expression persist as iceberg, alongside setting `table_type=iceberg`, you also need to set is_external=false
+            # Note that SQLGlot does the right thing with LocationProperty and writes it as `location` (Iceberg) instead of `external_location` (Hive)
+            # ref: https://docs.aws.amazon.com/athena/latest/ug/create-table-as.html#ctas-table-properties
+            properties.append(exp.Property(this=exp.var("is_external"), value="false"))
+
         for name, value in table_properties.items():
             properties.append(exp.Property(this=exp.var(name), value=value))
 
-        if is_iceberg and expression:
-            # To make a CTAS expression persist as iceberg, alongside setting `table_type=iceberg` (which the user has already
-            # supplied in physical_properties and is thus set above), you also need to set:
-            # - is_external=false
-            # - table_location='s3://'
-            # ref: https://docs.aws.amazon.com/athena/latest/ug/create-table-as.html#ctas-table-properties
-            properties.append(exp.Property(this=exp.var("is_external"), value="false"))
-
         if properties:
             return exp.Properties(expressions=properties)
@@ -342,7 +335,7 @@ def _table_type(
         """
         Use the user-specified table_properties to figure out of this is a Hive or an Iceberg table
         """
-        # if table_type is not defined or is not set to "iceberg", this is a Hive table
+        # if we can't detect any indication of Iceberg, this is a Hive table
         if table_properties and (table_type := table_properties.get("table_type", None)):
             if "iceberg" in table_type.sql(dialect=self.dialect).lower():
                 return "iceberg"
@@ -365,6 +358,16 @@ def _query_table_type(
             return "hive"
         return "iceberg"
 
+    def _table_location_or_raise(
+        self, table_properties: t.Optional[t.Dict[str, exp.Expression]], table: exp.Table
+    ) -> exp.LocationProperty:
+        location = self._table_location(table_properties, table)
+        if not location:
+            raise SQLMeshError(
+                f"Cannot figure out location for table {table}. 
Please either set `s3_base_location` in `physical_properties` or set `s3_warehouse_location` in the Athena connection config" + ) + return location + def _table_location( self, table_properties: t.Optional[t.Dict[str, exp.Expression]], @@ -384,18 +387,11 @@ def _table_location( elif self.s3_warehouse_location: # If the user has set `s3_warehouse_location` in the connection config, the base URI is /// - catalog_name = table.catalog if hasattr(table, "catalog") else None - schema_name = table.db if hasattr(table, "db") else None - base_uri = os.path.join( - self.s3_warehouse_location, catalog_name or "", schema_name or "" - ) + base_uri = os.path.join(self.s3_warehouse_location, table.catalog or "", table.db or "") else: - # Assume the user has set a default location for this schema in the metastore return None - table_name = table.name if hasattr(table, "name") else None - full_uri = _ensure_valid_location(os.path.join(base_uri, table_name or "")) - + full_uri = _ensure_valid_location(os.path.join(base_uri, table.text("this") or "")) return exp.LocationProperty(this=exp.Literal.string(full_uri)) def _find_matching_columns( diff --git a/tests/core/engine_adapter/test_athena.py b/tests/core/engine_adapter/test_athena.py index ed9785158..a642cf642 100644 --- a/tests/core/engine_adapter/test_athena.py +++ b/tests/core/engine_adapter/test_athena.py @@ -8,6 +8,7 @@ from sqlmesh.core.engine_adapter import AthenaEngineAdapter from sqlmesh.core.model import load_sql_based_model from sqlmesh.core.model.definition import SqlModel +from sqlmesh.utils.errors import SQLMeshError from tests.core.engine_adapter import to_sql_calls @@ -31,7 +32,6 @@ def adapter(make_mocked_engine_adapter: t.Callable) -> AthenaEngineAdapter: exp.to_table("schema.table"), "s3://some/location/table/", ), - (None, None, exp.Table(db=exp.Identifier(this="test")), None), # Location set to bucket ("s3://bucket", None, exp.to_table("schema.table"), "s3://bucket/schema/table/"), ("s3://bucket", {}, exp.to_table("schema.table"), "s3://bucket/schema/table/"), @@ -73,18 +73,18 @@ def test_table_location( expected_location: t.Optional[str], ) -> None: adapter.s3_warehouse_location = config_s3_warehouse_location - location = adapter._table_location(table_properties, table) - final_location = None - - if location and expected_location: - final_location = ( - location.this.name - ) # extract the unquoted location value from the LocationProperty - - assert final_location == expected_location + if expected_location is None: + with pytest.raises(SQLMeshError, match=r"Cannot figure out location for table.*"): + adapter._table_location_or_raise(table_properties, table) + else: + location = adapter._table_location_or_raise( + table_properties, table + ).this.name # extract the unquoted location value from the LocationProperty + assert location == expected_location if table_properties is not None: - assert "location" not in table_properties + # this get consumed by _table_location because we dont want it to end up in a TBLPROPERTIES clause + assert "s3_base_location" not in table_properties def test_create_schema(adapter: AthenaEngineAdapter) -> None: @@ -163,7 +163,7 @@ def test_create_table_iceberg(adapter: AthenaEngineAdapter) -> None: ] -def test_create_table_inferred_location(adapter: AthenaEngineAdapter) -> None: +def test_create_table_no_location(adapter: AthenaEngineAdapter) -> None: expressions = d.parse( """ MODEL ( @@ -176,11 +176,12 @@ def test_create_table_inferred_location(adapter: AthenaEngineAdapter) -> None: ) model: SqlModel = 
t.cast(SqlModel, load_sql_based_model(expressions)) - adapter.create_table( - model.name, - columns_to_types=model.columns_to_types_or_raise, - table_properties=model.physical_properties, - ) + with pytest.raises(SQLMeshError, match=r"Cannot figure out location.*"): + adapter.create_table( + model.name, + columns_to_types=model.columns_to_types_or_raise, + table_properties=model.physical_properties, + ) adapter.s3_warehouse_location = "s3://bucket/prefix" adapter.create_table( @@ -190,7 +191,6 @@ def test_create_table_inferred_location(adapter: AthenaEngineAdapter) -> None: ) assert to_sql_calls(adapter) == [ - "CREATE EXTERNAL TABLE IF NOT EXISTS `test_table` (`a` INT)", "CREATE EXTERNAL TABLE IF NOT EXISTS `test_table` (`a` INT) LOCATION 's3://bucket/prefix/test_table/'", ] @@ -220,10 +220,22 @@ def test_ctas_iceberg(adapter: AthenaEngineAdapter): ) assert to_sql_calls(adapter) == [ - 'CREATE TABLE IF NOT EXISTS "foo"."bar" WITH (location=\'s3://bucket/prefix/foo/bar/\', table_type=\'iceberg\', is_external=false) AS SELECT CAST("a" AS INTEGER) AS "a" FROM (SELECT 1) AS "_subquery"' + 'CREATE TABLE IF NOT EXISTS "foo"."bar" WITH (location=\'s3://bucket/prefix/foo/bar/\', is_external=false, table_type=\'iceberg\') AS SELECT CAST("a" AS INTEGER) AS "a" FROM (SELECT 1) AS "_subquery"' ] +def test_ctas_iceberg_no_specific_location(adapter: AthenaEngineAdapter): + with pytest.raises(SQLMeshError, match=r"Cannot figure out location.*"): + adapter.ctas( + table_name="foo.bar", + columns_to_types={"a": exp.DataType.build("int")}, + query_or_df=parse_one("select 1", into=exp.Select), + table_properties={"table_type": exp.Literal.string("iceberg")}, + ) + + assert to_sql_calls(adapter) == [] + + def test_replace_query(adapter: AthenaEngineAdapter, mocker: MockerFixture): mocker.patch( "sqlmesh.core.engine_adapter.athena.AthenaEngineAdapter.table_exists", return_value=True @@ -250,6 +262,7 @@ def test_replace_query(adapter: AthenaEngineAdapter, mocker: MockerFixture): ) adapter.cursor.execute.reset_mock() + adapter.s3_warehouse_location = "s3://foo" adapter.replace_query( table_name="test", query_or_df=parse_one("select 1 as a", into=exp.Select), @@ -257,8 +270,9 @@ def test_replace_query(adapter: AthenaEngineAdapter, mocker: MockerFixture): table_properties={}, ) + # gets recreated as a Hive table because table_exists=False and nothing in the properties indicates it should be Iceberg assert to_sql_calls(adapter) == [ - 'CREATE TABLE IF NOT EXISTS "test" AS SELECT CAST("a" AS INTEGER) AS "a" FROM (SELECT 1 AS "a") AS "_subquery"' + 'CREATE TABLE IF NOT EXISTS "test" WITH (external_location=\'s3://foo/test/\') AS SELECT CAST("a" AS INTEGER) AS "a" FROM (SELECT 1 AS "a") AS "_subquery"' ] @@ -288,8 +302,9 @@ def test_truncate_table(adapter: AthenaEngineAdapter): def test_create_state_table(adapter: AthenaEngineAdapter): + adapter.s3_warehouse_location = "s3://base" adapter.create_state_table("_snapshots", {"name": exp.DataType.build("varchar")}) assert to_sql_calls(adapter) == [ - "CREATE TABLE IF NOT EXISTS `_snapshots` (`name` STRING) TBLPROPERTIES ('table_type'='iceberg')" + "CREATE TABLE IF NOT EXISTS `_snapshots` (`name` STRING) LOCATION 's3://base/_snapshots/' TBLPROPERTIES ('table_type'='iceberg')" ] diff --git a/tests/core/engine_adapter/test_integration.py b/tests/core/engine_adapter/test_integration.py index 9d6cff9cd..3ef5e606b 100644 --- a/tests/core/engine_adapter/test_integration.py +++ b/tests/core/engine_adapter/test_integration.py @@ -422,6 +422,13 @@ def create_context( ) if 
config_mutator: config_mutator(self.gateway, config) + + if "athena" in self.gateway: + # Ensure that s3_warehouse_location is propagated + config.gateways[ + self.gateway + ].connection.s3_warehouse_location = self.engine_adapter.s3_warehouse_location + self._context = Context(paths=".", config=config, gateway=self.gateway) return self._context @@ -710,8 +717,18 @@ def mark_gateway(request) -> t.Tuple[str, str]: return request.param, f"inttest_{request.param}" +@pytest.fixture(scope="session") +def run_count(request) -> t.Iterable[int]: + count: int = request.config.cache.get("run_count", 0) + count += 1 + yield count + request.config.cache.set("run_count", count) + + @pytest.fixture -def engine_adapter(mark_gateway: t.Tuple[str, str], config, testrun_uid) -> EngineAdapter: +def engine_adapter( + mark_gateway: t.Tuple[str, str], config, testrun_uid, run_count +) -> EngineAdapter: mark, gateway = mark_gateway if gateway not in config.gateways: # TODO: Once everything is fully setup we want to error if a gateway is not configured that we expect @@ -730,7 +747,7 @@ def engine_adapter(mark_gateway: t.Tuple[str, str], config, testrun_uid) -> Engi if testrun_path not in current_location: # only add it if its not already there (since this setup code gets called multiple times in a full test run) connection_config.s3_warehouse_location = os.path.join( - current_location, testrun_path + current_location, testrun_path, str(run_count) ) engine_adapter = connection_config.create_engine_adapter()