Skip to content

Commit

Permalink
feat: support multiple when matched inc unique key
Browse files Browse the repository at this point in the history
  • Loading branch information
eakmanrq committed Sep 12, 2024
1 parent aa41dee commit bf89546
Show file tree
Hide file tree
Showing 8 changed files with 259 additions and 34 deletions.
6 changes: 4 additions & 2 deletions sqlmesh/core/dialect.py
Original file line number Diff line number Diff line change
Expand Up @@ -398,7 +398,9 @@ def _parse_props(self: Parser) -> t.Optional[exp.Expression]:

name = key.name.lower()
if name == "when_matched":
value: t.Optional[exp.Expression] = self._parse_when_matched()[0]
value: t.Optional[t.Union[exp.Expression, t.List[exp.Expression]]] = (
self._parse_when_matched() # type: ignore
)
elif name == "time_data_type":
# TODO: if we make *_data_type a convention to parse things into exp.DataType, we could make this more generic
value = self._parse_types(schema=True)
Expand All @@ -410,7 +412,7 @@ def _parse_props(self: Parser) -> t.Optional[exp.Expression]:

if name == "path" and value:
# Make sure if we get a windows path that it is converted to posix
value = exp.Literal.string(value.this.replace("\\", "/"))
value = exp.Literal.string(value.this.replace("\\", "/")) # type: ignore

return self.expression(exp.Property, this=name, value=value)

Expand Down
6 changes: 4 additions & 2 deletions sqlmesh/core/engine_adapter/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -1726,7 +1726,7 @@ def merge(
source_table: QueryOrDF,
columns_to_types: t.Optional[t.Dict[str, exp.DataType]],
unique_key: t.Sequence[exp.Expression],
when_matched: t.Optional[exp.When] = None,
when_matched: t.Optional[t.Union[exp.When, t.List[exp.When]]] = None,
) -> None:
source_queries, columns_to_types = self._get_source_queries_and_columns_to_types(
source_table, columns_to_types, target_table=target_table
Expand All @@ -1749,6 +1749,7 @@ def merge(
],
),
)
when_matched = ensure_list(when_matched)
when_not_matched = exp.When(
matched=False,
source=False,
Expand All @@ -1759,13 +1760,14 @@ def merge(
),
),
)
match_expressions = when_matched + [when_not_matched]
for source_query in source_queries:
with source_query as query:
self._merge(
target_table=target_table,
query=query,
on=on,
match_expressions=[when_matched, when_not_matched],
match_expressions=match_expressions,
)

def rename_table(
Expand Down
2 changes: 1 addition & 1 deletion sqlmesh/core/engine_adapter/mixins.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ def merge(
source_table: QueryOrDF,
columns_to_types: t.Optional[t.Dict[str, exp.DataType]],
unique_key: t.Sequence[exp.Expression],
when_matched: t.Optional[exp.When] = None,
when_matched: t.Optional[t.Union[exp.When, t.List[exp.When]]] = None,
) -> None:
"""
Merge implementation for engine adapters that do not support merge natively.
Expand Down
27 changes: 18 additions & 9 deletions sqlmesh/core/model/kind.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

from pydantic import Field
from sqlglot import exp
from sqlglot.helper import ensure_list
from sqlglot.optimizer.normalize_identifiers import normalize_identifiers
from sqlglot.optimizer.qualify_columns import quote_identifiers
from sqlglot.optimizer.simplify import gen
Expand Down Expand Up @@ -424,14 +425,16 @@ def data_hash_values(self) -> t.List[t.Optional[str]]:
class IncrementalByUniqueKeyKind(_IncrementalBy):
name: Literal[ModelKindName.INCREMENTAL_BY_UNIQUE_KEY] = ModelKindName.INCREMENTAL_BY_UNIQUE_KEY
unique_key: SQLGlotListOfFields
when_matched: t.Optional[exp.When] = None
when_matched: t.Optional[t.List[exp.When]] = None
batch_concurrency: Literal[1] = 1

@field_validator("when_matched", mode="before")
@field_validator_v1_args
def _when_matched_validator(
cls, v: t.Optional[t.Union[exp.When, str]], values: t.Dict[str, t.Any]
) -> t.Optional[exp.When]:
cls,
v: t.Optional[t.Union[exp.When, str, t.List[exp.When], t.List[str]]],
values: t.Dict[str, t.Any],
) -> t.Optional[t.List[exp.When]]:
def replace_table_references(expression: exp.Expression) -> exp.Expression:
from sqlmesh.core.engine_adapter.base import (
MERGE_SOURCE_ALIAS,
Expand All @@ -451,13 +454,19 @@ def replace_table_references(expression: exp.Expression) -> exp.Expression:
)
return expression

if isinstance(v, str):
return t.cast(exp.When, d.parse_one(v, into=exp.When, dialect=get_dialect(values)))

if not v:
return v

return t.cast(exp.When, v.transform(replace_table_references))
return v # type: ignore

result = []
list_v = ensure_list(v)
for value in ensure_list(list_v):
if isinstance(value, str):
result.append(
t.cast(exp.When, d.parse_one(value, into=exp.When, dialect=get_dialect(values)))
)
else:
result.append(t.cast(exp.When, value.transform(replace_table_references))) # type: ignore
return result

@property
def data_hash_values(self) -> t.List[t.Optional[str]]:
Expand Down
2 changes: 1 addition & 1 deletion sqlmesh/core/model/meta.py
Original file line number Diff line number Diff line change
Expand Up @@ -439,7 +439,7 @@ def managed_columns(self) -> t.Dict[str, exp.DataType]:
return getattr(self.kind, "managed_columns", {})

@property
def when_matched(self) -> t.Optional[exp.When]:
def when_matched(self) -> t.Optional[t.List[exp.When]]:
if isinstance(self.kind, IncrementalByUniqueKeyKind):
return self.kind.when_matched
return None
Expand Down
69 changes: 69 additions & 0 deletions tests/core/engine_adapter/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -974,6 +974,75 @@ def test_merge_when_matched(make_mocked_engine_adapter: t.Callable, assert_exp_e
)


def test_merge_when_matched_multiple(make_mocked_engine_adapter: t.Callable, assert_exp_eq):
    """When ``when_matched`` is a list, every WHEN MATCHED clause must be
    emitted in order, ahead of the generated WHEN NOT MATCHED THEN INSERT."""
    adapter = make_mocked_engine_adapter(EngineAdapter)

    def _coalesce_update() -> exp.Update:
        # Both matched branches apply the same SET list: take the source's
        # val, and keep the target's ts when the source ts is NULL.
        return exp.Update(
            expressions=[
                exp.column("val", "__MERGE_TARGET__").eq(
                    exp.column("val", "__MERGE_SOURCE__")
                ),
                exp.column("ts", "__MERGE_TARGET__").eq(
                    exp.Coalesce(
                        this=exp.column("ts", "__MERGE_SOURCE__"),
                        expressions=[exp.column("ts", "__MERGE_TARGET__")],
                    )
                ),
            ],
        )

    matched_clauses = [
        # Conditional branch: only rows whose source ID is 1.
        exp.When(
            matched=True,
            condition=exp.column("ID", "__MERGE_SOURCE__").eq(exp.Literal.number(1)),
            then=_coalesce_update(),
        ),
        # Unconditional fallback branch for every other matched row.
        exp.When(
            matched=True,
            source=False,
            then=_coalesce_update(),
        ),
    ]

    adapter.merge(
        target_table="target",
        source_table=t.cast(exp.Select, parse_one('SELECT "ID", ts, val FROM source')),
        columns_to_types={
            "ID": exp.DataType.build("int"),
            "ts": exp.DataType.build("timestamp"),
            "val": exp.DataType.build("int"),
        },
        unique_key=[exp.to_identifier("ID", quoted=True)],
        when_matched=matched_clauses,
    )

    assert_exp_eq(
        adapter.cursor.execute.call_args[0][0],
        """
        MERGE INTO "target" AS "__MERGE_TARGET__" USING (
          SELECT
            "ID",
            "ts",
            "val"
          FROM "source"
        ) AS "__MERGE_SOURCE__"
        ON "__MERGE_TARGET__"."ID" = "__MERGE_SOURCE__"."ID"
        WHEN MATCHED AND "__MERGE_SOURCE__"."ID" = 1 THEN UPDATE SET "__MERGE_TARGET__"."val" = "__MERGE_SOURCE__"."val", "__MERGE_TARGET__"."ts" = COALESCE("__MERGE_SOURCE__"."ts", "__MERGE_TARGET__"."ts"),
        WHEN MATCHED THEN UPDATE SET "__MERGE_TARGET__"."val" = "__MERGE_SOURCE__"."val", "__MERGE_TARGET__"."ts" = COALESCE("__MERGE_SOURCE__"."ts", "__MERGE_TARGET__"."ts")
        WHEN NOT MATCHED THEN INSERT ("ID", "ts", "val")
        VALUES ("__MERGE_SOURCE__"."ID", "__MERGE_SOURCE__"."ts", "__MERGE_SOURCE__"."val")
        """,
    )


def test_scd_type_2_by_time(make_mocked_engine_adapter: t.Callable):
adapter = make_mocked_engine_adapter(EngineAdapter)

Expand Down
68 changes: 65 additions & 3 deletions tests/core/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -3904,10 +3904,44 @@ def test_when_matched():
expected_when_matched = "WHEN MATCHED THEN UPDATE SET __MERGE_TARGET__.salary = COALESCE(__MERGE_SOURCE__.salary, __MERGE_TARGET__.salary)"

model = load_sql_based_model(expressions, dialect="hive")
assert model.kind.when_matched.sql() == expected_when_matched
assert len(model.kind.when_matched) == 1
assert model.kind.when_matched[0].sql() == expected_when_matched

model = SqlModel.parse_raw(model.json())
assert model.kind.when_matched.sql() == expected_when_matched
assert len(model.kind.when_matched) == 1
assert model.kind.when_matched[0].sql() == expected_when_matched


def test_when_matched_multiple():
    """Multiple WHEN MATCHED clauses in an INCREMENTAL_BY_UNIQUE_KEY kind are
    parsed in order, with source/target references rewritten to the merge
    aliases, and survive a serialization round-trip."""
    expressions = d.parse(
        """
        MODEL (
            name db.employees,
            kind INCREMENTAL_BY_UNIQUE_KEY (
                unique_key name,
                when_matched WHEN MATCHED AND source.x = 1 THEN UPDATE SET target.salary = COALESCE(source.salary, target.salary),
                WHEN MATCHED THEN UPDATE SET target.salary = COALESCE(source.salary, target.salary)
            )
        );
        SELECT 'name' AS name, 1 AS salary;
        """
    )

    expected_when_matched = [
        "WHEN MATCHED AND __MERGE_SOURCE__.x = 1 THEN UPDATE SET __MERGE_TARGET__.salary = COALESCE(__MERGE_SOURCE__.salary, __MERGE_TARGET__.salary)",
        "WHEN MATCHED THEN UPDATE SET __MERGE_TARGET__.salary = COALESCE(__MERGE_SOURCE__.salary, __MERGE_TARGET__.salary)",
    ]

    loaded = load_sql_based_model(expressions, dialect="hive")
    # Re-parsing the serialized model must preserve the clause list exactly.
    round_tripped = SqlModel.parse_raw(loaded.json())

    for candidate in (loaded, round_tripped):
        rendered = [when.sql() for when in candidate.kind.when_matched]
        assert rendered == expected_when_matched


def test_default_catalog_sql(assert_exp_eq):
Expand Down Expand Up @@ -5438,7 +5472,35 @@ def test_model_kind_to_expression():
.sql()
== """INCREMENTAL_BY_UNIQUE_KEY (
unique_key ("a"),
when_matched WHEN MATCHED THEN UPDATE SET __MERGE_TARGET__.b = COALESCE(__MERGE_SOURCE__.b, __MERGE_TARGET__.b),
when_matched ARRAY(WHEN MATCHED THEN UPDATE SET __MERGE_TARGET__.b = COALESCE(__MERGE_SOURCE__.b, __MERGE_TARGET__.b)),
batch_concurrency 1,
forward_only FALSE,
disable_restatement FALSE,
on_destructive_change 'ERROR'
)"""
)

assert (
load_sql_based_model(
d.parse(
"""
MODEL (
name db.table,
kind INCREMENTAL_BY_UNIQUE_KEY(
unique_key a,
when_matched WHEN MATCHED AND source.x = 1 THEN UPDATE SET target.b = COALESCE(source.b, target.b),
WHEN MATCHED THEN UPDATE SET target.b = COALESCE(source.b, target.b)
),
);
SELECT a, b
"""
)
)
.kind.to_expression()
.sql()
== """INCREMENTAL_BY_UNIQUE_KEY (
unique_key ("a"),
when_matched ARRAY(WHEN MATCHED AND __MERGE_SOURCE__.x = 1 THEN UPDATE SET __MERGE_TARGET__.b = COALESCE(__MERGE_SOURCE__.b, __MERGE_TARGET__.b), WHEN MATCHED THEN UPDATE SET __MERGE_TARGET__.b = COALESCE(__MERGE_SOURCE__.b, __MERGE_TARGET__.b)),
batch_concurrency 1,
forward_only FALSE,
disable_restatement FALSE,
Expand Down
Loading

0 comments on commit bf89546

Please sign in to comment.