Skip to content

Commit

Permalink
Feat!: finalize clickhouse engine adapter (#3125)
Browse files Browse the repository at this point in the history
  • Loading branch information
treysp authored Sep 18, 2024
1 parent 2f12ad6 commit fbf941b
Show file tree
Hide file tree
Showing 15 changed files with 1,038 additions and 138 deletions.
1 change: 1 addition & 0 deletions docs/guides/configuration.md
Original file line number Diff line number Diff line change
Expand Up @@ -521,6 +521,7 @@ Other state engines with fast and reliable database transactions (less tested th

Unsupported state engines, even for development:

* [Clickhouse](../integrations/engines/clickhouse.md)
* [Spark](../integrations/engines/spark.md)
* [Trino](../integrations/engines/trino.md)

Expand Down
295 changes: 295 additions & 0 deletions docs/integrations/engines/clickhouse.md

Large diffs are not rendered by default.

2 changes: 2 additions & 0 deletions mkdocs.yml
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@ nav:
- integrations/github.md
- Execution engines:
- integrations/engines/bigquery.md
- integrations/engines/clickhouse.md
- integrations/engines/databricks.md
- integrations/engines/duckdb.md
- integrations/engines/motherduck.md
Expand Down Expand Up @@ -139,6 +140,7 @@ markdown_extensions:
- pymdownx.details
- attr_list
- md_in_html
- pymdownx.caret
extra_css:
- stylesheets/extra.css
copyright: Tobiko Data Inc.
Expand Down
30 changes: 27 additions & 3 deletions sqlmesh/core/config/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
from pydantic import Field
from sqlglot import exp
from sqlglot.helper import subclasses

from sqlmesh.core import engine_adapter
from sqlmesh.core.config.base import BaseConfig
from sqlmesh.core.config.common import (
Expand Down Expand Up @@ -1394,6 +1393,7 @@ class ClickhouseConnectionConfig(ConnectionConfig):
query_limit: int = 0
use_compression: bool = True
compression_method: t.Optional[str] = None
connection_settings: t.Optional[t.Dict[str, t.Any]] = None

concurrent_tasks: int = 1
register_comments: bool = True
Expand Down Expand Up @@ -1425,9 +1425,13 @@ def _connection_factory(self) -> t.Callable:

return connect

@property
def cloud_mode(self) -> bool:
    """True when the configured host is a ClickHouse Cloud endpoint."""
    return self.host.find("clickhouse.cloud") != -1

@property
def _extra_engine_config(self) -> t.Dict[str, t.Any]:
    """Extra kwargs forwarded to the ClickHouse engine adapter.

    `cloud_mode` lets the adapter detect ClickHouse Cloud from the
    connection config instead of querying the server (see
    `ClickhouseEngineAdapter.engine_run_mode`).
    """
    # Fix: a stale duplicate `return {"cluster": self.cluster}` preceded this
    # line, making the full return (with `cloud_mode`) unreachable.
    return {"cluster": self.cluster, "cloud_mode": self.cloud_mode}

@property
def _static_connection_kwargs(self) -> t.Dict[str, t.Any]:
Expand All @@ -1440,7 +1444,27 @@ def _static_connection_kwargs(self) -> t.Dict[str, t.Any]:
if compress and self.compression_method:
compress = self.compression_method

return {"compress": compress, "client_name": f"SQLMesh/{__version__}"}
# Clickhouse system settings passed to connection
# https://clickhouse.com/docs/en/operations/settings/settings
# - below are set to align with dbt-clickhouse
# - https://github.com/ClickHouse/dbt-clickhouse/blob/44d26308ea6a3c8ead25c280164aa88191f05f47/dbt/adapters/clickhouse/dbclient.py#L77
settings = self.connection_settings or {}
# mutations_sync = 2: "The query waits for all mutations [ALTER statements] to complete on all replicas (if they exist)"
settings["mutations_sync"] = "2"
# insert_distributed_sync = 1: "INSERT operation succeeds only after all the data is saved on all shards"
settings["insert_distributed_sync"] = "1"
if self.cluster or self.cloud_mode:
# database_replicated_enforce_synchronous_settings = 1:
# - "Enforces synchronous waiting for some queries"
# - https://github.com/ClickHouse/ClickHouse/blob/ccaa8d03a9351efc16625340268b9caffa8a22ba/src/Core/Settings.h#L709
settings["database_replicated_enforce_synchronous_settings"] = "1"
# insert_quorum = auto:
# - "INSERT succeeds only when ClickHouse manages to correctly write data to the insert_quorum of replicas during
# the insert_quorum_timeout"
# - "use majority number (number_of_replicas / 2 + 1) as quorum number"
settings["insert_quorum"] = "auto"

return {"compress": compress, "client_name": f"SQLMesh/{__version__}", **settings}


CONNECTION_CONFIG_TO_TYPE = {
Expand Down
68 changes: 46 additions & 22 deletions sqlmesh/core/engine_adapter/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -1092,14 +1092,12 @@ def drop_view(
**kwargs: t.Any,
) -> None:
"""Drop a view."""
self.execute(
exp.Drop(
this=exp.to_table(view_name),
exists=ignore_if_not_exists,
materialized=materialized and self.SUPPORTS_MATERIALIZED_VIEWS,
kind="VIEW",
**kwargs,
)
self._drop_object(
name=view_name,
exists=ignore_if_not_exists,
kind="VIEW",
materialized=materialized and self.SUPPORTS_MATERIALIZED_VIEWS,
**kwargs,
)

def columns(
Expand Down Expand Up @@ -1466,23 +1464,29 @@ def remove_managed_columns(
if check_columns:
row_check_conditions = []
for col in check_columns:
t_col = col.copy()
col_qualified = col.copy()
col_qualified.set("table", exp.to_identifier("joined"))

t_col = col_qualified.copy()
t_col.this.set("this", f"t_{col.name}")

row_check_conditions.extend(
[
col.neq(t_col),
exp.and_(t_col.is_(exp.Null()), col.is_(exp.Null()).not_()),
exp.and_(t_col.is_(exp.Null()).not_(), col.is_(exp.Null())),
col_qualified.neq(t_col),
exp.and_(t_col.is_(exp.Null()), col_qualified.is_(exp.Null()).not_()),
exp.and_(t_col.is_(exp.Null()).not_(), col_qualified.is_(exp.Null())),
]
)
row_value_check = exp.or_(*row_check_conditions)
unique_key_conditions = []
for key in unique_key:
t_key = key.copy()
key_qualified = key.copy()
key_qualified.set("table", exp.to_identifier("joined"))
t_key = key_qualified.copy()
for col in t_key.find_all(exp.Column):
col.this.set("this", f"t_{col.name}")
unique_key_conditions.extend(
[t_key.is_(exp.Null()).not_(), key.is_(exp.Null()).not_()]
[t_key.is_(exp.Null()).not_(), key_qualified.is_(exp.Null()).not_()]
)
unique_key_check = exp.and_(*unique_key_conditions)
# unique_key_check is saying "if the row is updated"
Expand All @@ -1509,11 +1513,15 @@ def remove_managed_columns(
).as_(valid_from_col.this)
else:
assert updated_at_col is not None
prefixed_updated_at_col = updated_at_col.copy()
prefixed_updated_at_col.this.set("this", f"t_{updated_at_col.name}")
updated_row_filter = updated_at_col > prefixed_updated_at_col

valid_to_case_stmt_builder = exp.Case().when(updated_row_filter, updated_at_col)
updated_at_col_qualified = updated_at_col.copy()
updated_at_col_qualified.set("table", exp.to_identifier("joined"))
prefixed_updated_at_col = updated_at_col_qualified.copy()
prefixed_updated_at_col.this.set("this", f"t_{updated_at_col_qualified.name}")
updated_row_filter = updated_at_col_qualified > prefixed_updated_at_col

valid_to_case_stmt_builder = exp.Case().when(
updated_row_filter, updated_at_col_qualified
)
if delete_check:
valid_to_case_stmt_builder = valid_to_case_stmt_builder.when(
delete_check, execution_ts
Expand Down Expand Up @@ -1573,7 +1581,11 @@ def remove_managed_columns(
"source",
exp.select(exp.true().as_("_exists"), *select_source_columns)
.distinct(*unique_key)
.from_(source_query.subquery("raw_source")), # type: ignore
.from_(
self.use_server_nulls_for_unmatched_after_join(source_query).subquery( # type: ignore
"raw_source"
)
),
)
# Historical Records that Do Not Change
.with_(
Expand Down Expand Up @@ -1714,7 +1726,7 @@ def remove_managed_columns(

self.replace_query(
target_table,
query,
self.ensure_nulls_for_unmatched_after_join(query),
columns_to_types=columns_to_types,
table_description=table_description,
column_descriptions=column_descriptions,
Expand Down Expand Up @@ -2229,7 +2241,7 @@ def _replace_by_key(
delete_filter = key_exp.isin(query=delete_query)

if not self.INSERT_OVERWRITE_STRATEGY.is_replace_where:
self.execute(exp.delete(target_table).where(delete_filter))
self.delete_from(target_table, delete_filter)
else:
insert_statement.set("where", delete_filter)
insert_statement.set("this", exp.to_table(target_table))
Expand Down Expand Up @@ -2293,6 +2305,18 @@ def _rename_table(
) -> None:
self.execute(exp.rename_table(old_table_name, new_table_name))

def ensure_nulls_for_unmatched_after_join(
    self,
    query: Query,
) -> Query:
    """Hook for adapters whose outer joins fill unmatched cells with type
    defaults rather than NULL; the base implementation returns the query
    unchanged. Overridden by the ClickHouse adapter to append
    `SETTINGS join_use_nulls = 1`."""
    return query

def use_server_nulls_for_unmatched_after_join(
    self,
    query: Query,
) -> Query:
    """Hook for adapters to re-apply the server's own NULL-handling setting
    to a user query embedded inside generated SQL; the base implementation
    is a no-op. Overridden by the ClickHouse adapter to manage the
    `join_use_nulls` setting."""
    return query

def ping(self) -> None:
try:
self._execute(exp.select("1").sql(dialect=self.dialect))
Expand Down
94 changes: 86 additions & 8 deletions sqlmesh/core/engine_adapter/clickhouse.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
CommentCreationView,
)
from sqlmesh.core.schema_diff import SchemaDiffer
from functools import cached_property

if t.TYPE_CHECKING:
from sqlmesh.core._typing import SchemaName, TableName
Expand All @@ -31,19 +30,17 @@ class ClickhouseEngineAdapter(EngineAdapterWithIndexSupport, LogicalMergeMixin):
DIALECT = "clickhouse"
SUPPORTS_TRANSACTIONS = False
SUPPORTS_VIEW_SCHEMA = False
SUPPORTS_REPLACE_TABLE = False
COMMENT_CREATION_VIEW = CommentCreationView.COMMENT_COMMAND_ONLY

SCHEMA_DIFFER = SchemaDiffer()

DEFAULT_TABLE_ENGINE = "MergeTree"
ORDER_BY_TABLE_ENGINE_REGEX = "^.*?MergeTree.*$"

@cached_property
@property
def engine_run_mode(self) -> EngineRunMode:
cloud_query_value = self.fetchone(
"select value from system.settings where name='cloud_mode'"
)
if str(cloud_query_value[0]) == "1":
if self._extra_config.get("cloud_mode"):
return EngineRunMode.CLOUD
# we use the user's specification of a cluster in the connection config to determine if
# the engine is in cluster mode
Expand Down Expand Up @@ -242,6 +239,12 @@ def _rename_table(

self.execute(f"RENAME TABLE {old_table_sql} TO {new_table_sql}{self._on_cluster_sql()}")

def delete_from(self, table_name: TableName, where: t.Union[str, exp.Expression]) -> None:
    """Issue a DELETE against `table_name`, adding an ON CLUSTER clause when
    the adapter is running in cluster mode."""
    expression = exp.delete(table_name, where)
    if self.engine_run_mode.is_cluster:
        cluster_clause = exp.OnCluster(this=exp.to_identifier(self.cluster))
        expression.set("cluster", cluster_clause)
    self.execute(expression)

def alter_table(
self,
alter_expressions: t.List[exp.Alter],
Expand Down Expand Up @@ -296,6 +299,77 @@ def _build_partitioned_by_exp(
this=exp.Schema(expressions=partitioned_by),
)

def ensure_nulls_for_unmatched_after_join(
    self,
    query: Query,
) -> Query:
    """Append `SETTINGS join_use_nulls = 1` to the query so unmatched join
    cells become NULL instead of the column type's default value."""
    join_use_nulls_on = exp.var("join_use_nulls").eq(exp.Literal.number("1"))
    query.append("settings", join_use_nulls_on)
    return query

def use_server_nulls_for_unmatched_after_join(
    self,
    query: Query,
) -> Query:
    """Pin the server's `join_use_nulls` value in the user query's SETTINGS clause.

    Used in SCD models:
    - The SCD query we build must include the setting `join_use_nulls = 1` to ensure that empty cells in a join
      are filled with NULL instead of the default data type value. The default join_use_nulls value is `0`.
    - The SCD embeds the user's original query in the `source` CTE.
    - Settings are dynamically scoped, so our setting may override the server's default setting the user expects
      for their query.
    - To prevent this, we:
        - If the user query sets `join_use_nulls`, do nothing
        - If the user query does not set `join_use_nulls`, query the server for the current setting
            - If the server value is 1, do nothing
            - If the server value is not 1, inject its `join_use_nulls` value into the user query
    - We do not need to check user subqueries because our injected setting operates at the same scope the
      server value would normally operate at.
    """
    setting_name = "join_use_nulls"
    sqlmesh_value = "1"

    user_settings = query.args.get("settings")
    # has the user already set it explicitly?
    user_set_explicitly = bool(user_settings) and any(
        # generator (not a list) avoids materializing all settings just for any()
        isinstance(setting, exp.EQ) and setting.name == setting_name
        for setting in user_settings
    )
    if not user_set_explicitly:
        server_value = self.fetchone(
            exp.select("value")
            .from_("system.settings")
            .where(exp.column("name").eq(exp.Literal.string(setting_name)))
        )[0]
        # only inject the setting if the server value isn't the one we set ("1")
        if server_value != sqlmesh_value:
            query.append(
                "settings", exp.var(setting_name).eq(exp.Literal.number(server_value))
            )

    return query

def _build_settings_property(
    self, key: str, value: exp.Expression | str | int | float
) -> exp.SettingsProperty:
    """Wrap a key/value pair as a SETTINGS property for a ClickHouse DDL statement."""
    if isinstance(value, exp.Expression):
        value_expr = value
    else:
        # non-expression values become literals; strings are quoted
        value_expr = exp.Literal(this=value, is_string=isinstance(value, str))
    setting = exp.EQ(this=exp.var(key.lower()), expression=value_expr)
    return exp.SettingsProperty(expressions=[setting])

def _build_table_properties_exp(
self,
catalog_name: t.Optional[str] = None,
Expand Down Expand Up @@ -385,7 +459,9 @@ def _build_table_properties_exp(
properties.append(exp.EmptyProperty())

if table_properties_copy:
properties.extend(self._table_or_view_properties_to_expressions(table_properties_copy))
properties.extend(
[self._build_settings_property(k, v) for k, v in table_properties_copy.items()]
)

if table_description:
properties.append(
Expand Down Expand Up @@ -414,7 +490,9 @@ def _build_view_properties_exp(
properties.append(exp.OnCluster(this=exp.to_identifier(self.cluster)))

if view_properties_copy:
properties.extend(self._table_or_view_properties_to_expressions(view_properties_copy))
properties.extend(
[self._build_settings_property(k, v) for k, v in view_properties_copy.items()]
)

if table_description:
properties.append(
Expand Down
2 changes: 0 additions & 2 deletions sqlmesh/core/engine_adapter/trino.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
from __future__ import annotations

import typing as t
from functools import lru_cache

import pandas as pd
from pandas.api.types import is_datetime64_any_dtype # type: ignore
from sqlglot import exp
Expand Down
5 changes: 4 additions & 1 deletion sqlmesh/core/model/meta.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,10 @@ def _normalize(value: t.Any) -> t.Any:
return v

@field_validator("storage_format", mode="before")
@field_validator_v1_args
def _storage_format_validator(cls, v: t.Any, values: t.Dict[str, t.Any]) -> t.Optional[str]:
    """Coerce `storage_format` to a string.

    Non-trivial sqlglot expressions (anything other than a literal or an
    identifier) are rendered as SQL in the model's dialect so engine-specific
    storage syntax survives the round-trip.
    """
    # Fix: removed the stale duplicate `def _storage_format_validator(cls, v)`
    # signature left over above this definition, which shadowed/conflicted
    # with the current two-argument validator.
    if isinstance(v, exp.Expression) and not isinstance(v, (exp.Literal, exp.Identifier)):
        return v.sql(values.get("dialect"))
    return str_or_exp_to_str(v)

@field_validator("dialect", mode="before")
Expand Down
6 changes: 4 additions & 2 deletions sqlmesh/core/table_diff.py
Original file line number Diff line number Diff line change
Expand Up @@ -332,7 +332,7 @@ def name(e: exp.Expression) -> str:
.join(target_table.as_("t"), on=self.on, join_type="FULL")
)

query = (
base_query = (
exp.Select()
.with_(source_table, source_query)
.with_(target_table, target_query)
Expand All @@ -355,7 +355,9 @@ def name(e: exp.Expression) -> str:
.from_(stats_table)
)

query = quote_identifiers(query, dialect=self.model_dialect or self.dialect)
query = self.adapter.ensure_nulls_for_unmatched_after_join(
quote_identifiers(base_query.copy(), dialect=self.model_dialect or self.dialect)
)
temp_table = exp.table_("diff", db="sqlmesh_temp", quoted=True)

with self.adapter.temp_table(query, name=temp_table) as table:
Expand Down
Loading

0 comments on commit fbf941b

Please sign in to comment.