Skip to content

Commit

Permalink
Basic working dagster asset
Browse files Browse the repository at this point in the history
  • Loading branch information
ravenac95 committed Aug 27, 2024
1 parent becaf09 commit cb64aa1
Show file tree
Hide file tree
Showing 13 changed files with 260 additions and 78 deletions.
1 change: 1 addition & 0 deletions dagster_sqlmesh/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,3 +2,4 @@

from .asset import *
from .config import *
from .resource import *
10 changes: 4 additions & 6 deletions dagster_sqlmesh/asset.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,11 +34,6 @@
MultiAssetResponse = Iterable[Union[AssetCheckResult, AssetMaterialization]]


# Define a SQLMesh Resource
class SQLMeshResource:
pass


@dataclass(kw_only=True)
class SQLMeshParsedFQN:
catalog: str
Expand Down Expand Up @@ -164,7 +159,10 @@ class SQLMeshController:
context: Context

def add_event_handler(self, handler: ConsoleEventHandler):
self.console.listen(handler)
return self.console.add_handler(handler)

def remove_event_handler(self, handler_id: str):
return self.console.remove_handler(handler_id)


def debug_events(ev: ConsoleEvent):
Expand Down
19 changes: 17 additions & 2 deletions dagster_sqlmesh/config.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,19 @@
from typing import Optional
from typing import Optional, Dict, Any
from dataclasses import dataclass

from dagster import Config
from sqlmesh.core.config import Config as MeshConfig
from pydantic import Field


@dataclass
class ConfigOverride:
config_as_dict: Dict

def dict(self):
return self.config_as_dict


class SQLMeshContextConfig(Config):
"""A very basic sqlmesh configuration. Currently you cannot specify the
sqlmesh configuration entirely from dagster. It is intended that your
Expand All @@ -15,4 +24,10 @@ class SQLMeshContextConfig(Config):

path: str
gateway: str
sqlmesh_config: Optional[MeshConfig] = Field(default_factory=lambda: None)
config_override: Optional[Dict[str, Any]] = Field(default_factory=lambda: None)

@property
def sqlmesh_config(self):
if self.config_override:
return MeshConfig.parse_obj(self.config_override)
return None
12 changes: 6 additions & 6 deletions dagster_sqlmesh/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
)

from dagster_sqlmesh.config import SQLMeshContextConfig
from dagster_sqlmesh.events import StatefulConsoleEventHandler, show_plan_summary
from dagster_sqlmesh.events import ConsoleRecorder, show_plan_summary
from dagster_sqlmesh.asset import setup_sqlmesh_controller

logger = logging.getLogger(__name__)
Expand All @@ -46,7 +46,6 @@ def sample_sqlmesh_project():
os.remove(os.path.join(project_dir, "db.db"))

# Initialize the "source" data

yield str(project_dir)


Expand Down Expand Up @@ -95,7 +94,7 @@ def append_to_test_source(self, df: polars.DataFrame):
"""
)

def sqlmesh_plan(
def run(
self,
*,
environment: str,
Expand All @@ -107,7 +106,7 @@ def sqlmesh_plan(
restate_models: Optional[List[str]] = None,
):
controller = self.create_controller(enable_debug_console=enable_debug_console)
controller.add_event_handler(StatefulConsoleEventHandler())
controller.add_event_handler(ConsoleRecorder())
plan_options: Dict[str, Any] = dict(
environment=environment,
enable_preview=True,
Expand All @@ -134,7 +133,7 @@ def sqlmesh_plan(
if apply:
logger.debug("making plan")
plan = builder.build()
show_plan_summary(plan, lambda x: x.is_model)
show_plan_summary(logger, plan, lambda x: x.is_model)
logger.debug("applying plan")
controller.context.apply(plan)
logger.debug("running through the scheduler")
Expand All @@ -152,8 +151,9 @@ def sample_sqlmesh_test_context(sample_sqlmesh_project: str):
default_gateway="local",
model_defaults=ModelDefaultsConfig(dialect="duckdb"),
)
config_as_dict = config.dict()
context_config = SQLMeshContextConfig(
path=sample_sqlmesh_project, gateway="local", sqlmesh_config=config
path=sample_sqlmesh_project, gateway="local", config_override=config_as_dict
)
test_context = SQLMeshTestContext(db_path=db_path, context_config=context_config)
test_context.initialize_test_source()
Expand Down
15 changes: 10 additions & 5 deletions dagster_sqlmesh/console.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Optional, Dict, Set, Union, Callable, List
from typing import Optional, Dict, Set, Union, Callable
from dataclasses import dataclass
import uuid
import unittest
Expand Down Expand Up @@ -262,7 +262,7 @@ class ShowRowDiff:

class EventConsole(Console):
def __init__(self):
self._handlers: List[ConsoleEventHandler] = []
self._handlers: Dict[str, ConsoleEventHandler] = {}

def start_plan_evaluation(self, plan: Plan) -> None:
self.publish(StartPlanEvaluation(plan))
Expand Down Expand Up @@ -426,11 +426,16 @@ def show_row_diff(
self.publish(ShowRowDiff(row_diff, show_sample, skip_grain_check))

def publish(self, event: ConsoleEvent) -> None:
for handler in self._handlers:
for handler in self._handlers.values():
handler(event)

def listen(self, handler: ConsoleEventHandler):
self._handlers.append(handler)
def add_handler(self, handler: ConsoleEventHandler):
handler_id = str(uuid.uuid4())
self._handlers[handler_id] = handler
return handler_id

def remove_handler(self, handler_id: str):
del self._handlers[handler_id]


class DebugEventConsole(EventConsole):
Expand Down
56 changes: 34 additions & 22 deletions dagster_sqlmesh/events.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import logging

from sqlmesh.core.model import Model
from sqlmesh.core.snapshot import SnapshotInfoLike, SnapshotId
from sqlmesh.core.snapshot import SnapshotInfoLike, SnapshotId, Snapshot
from sqlmesh.core.plan import Plan

from dagster_sqlmesh import console
Expand All @@ -12,6 +12,7 @@


def show_plan_summary(
logger: logging.Logger,
plan: Plan,
snapshot_selector: Callable[[SnapshotInfoLike], bool],
ignored_snapshot_ids: Optional[Set[SnapshotId]] = None,
Expand Down Expand Up @@ -58,52 +59,63 @@ def show_plan_summary(
logger.debug(restated_snapshots)


class StatefulConsoleEventHandler:
def __init__(self, enable_unknown_event_logging: bool = True):
class ConsoleRecorder:
def __init__(
self,
log_override: Optional[logging.Logger] = None,
enable_unknown_event_logging: bool = True,
):
self.logger = log_override or logger
self._planned_models: List[Model] = []
self._updated: List[Snapshot] = []
self._successful = False
self._enable_unknown_event_logging = enable_unknown_event_logging

def __call__(self, event: console.ConsoleEvent):
match event:
case console.StartPlanEvaluation(plan):
logger.debug("Starting plan evaluation")
self.logger.debug("Starting plan evaluation")
self._show_summary_for(
plan,
lambda x: x.is_model,
)
case console.StartEvaluationProgress(
batches, environment_naming_info, default_catalog
):
logger.debug("STARTING EVALUATION")
logger.debug(batches)
logger.debug(environment_naming_info)
logger.debug(default_catalog)
self.logger.debug("STARTING EVALUATION")
self.logger.debug(batches)
self.logger.debug(environment_naming_info)
self.logger.debug(default_catalog)
case console.UpdatePromotionProgress(snapshot, promoted):
logger.debug("UPDATE PROMOTION PROGRESS")
logger.debug(snapshot)
logger.debug(promoted)
self.logger.debug("UPDATE PROMOTION PROGRESS")
self.logger.debug(snapshot)
self.logger.debug(promoted)
case console.StopPromotionProgress(success):
logger.debug("STOP PROMOTION")
logger.debug(success)
self.logger.debug("STOP PROMOTION")
self.logger.debug(success)
self._successful = True
case console.StartSnapshotEvaluationProgress(snapshot):
logger.debug("START SNAPSHOT EVALUATION")
logger.debug(snapshot.name)
self.logger.debug("START SNAPSHOT EVALUATION")
self.logger.debug(snapshot.name)
case console.UpdateSnapshotEvaluationProgress(
snapshot, batch_idx, duration_ms
):
logger.debug("UPDATE SNAPSHOT EVALUATION")
logger.debug(snapshot.name)
logger.debug(batch_idx)
logger.debug(duration_ms)
self._updated.append(snapshot)
self.logger.debug("UPDATE SNAPSHOT EVALUATION")
self.logger.debug(snapshot.name)
self.logger.debug(batch_idx)
self.logger.debug(duration_ms)
case _:
if self._enable_unknown_event_logging:
logger.debug("Unhandled event")
logger.debug(event)
self.logger.debug("Unhandled event")
self.logger.debug(event)

def _show_summary_for(
self,
plan: Plan,
snapshot_selector: Callable[[SnapshotInfoLike], bool],
ignored_snapshot_ids: Optional[Set[SnapshotId]] = None,
):
return show_plan_summary(plan, snapshot_selector, ignored_snapshot_ids)
return show_plan_summary(
self.logger, plan, snapshot_selector, ignored_snapshot_ids
)
90 changes: 85 additions & 5 deletions dagster_sqlmesh/resource.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,94 @@
from dagster import ConfigurableResource, AssetExecutionContext
import typing as t

from dagster import ConfigurableResource, AssetExecutionContext, MaterializeResult
from sqlmesh.utils.date import TimeLike
from sqlmesh.core.plan.builder import PlanBuilder
from sqlmesh.core.config import CategorizerConfig

from dagster_sqlmesh.asset import SQLMeshDagsterTranslator, setup_sqlmesh_controller
from dagster_sqlmesh.events import ConsoleRecorder
from .config import SQLMeshContextConfig


class PlanOptions(t.TypedDict):
start: t.NotRequired[TimeLike]
end: t.NotRequired[TimeLike]
execution_time: t.NotRequired[TimeLike]
create_from: t.NotRequired[str]
skip_tests: t.NotRequired[bool]
restate_models: t.NotRequired[t.Iterable[str]]
no_gaps: t.NotRequired[bool]
skip_backfill: t.NotRequired[bool]
forward_only: t.NotRequired[bool]
allow_destructive_models: t.NotRequired[t.Collection[str]]
no_auto_categorization: t.NotRequired[bool]
effective_from: t.NotRequired[TimeLike]
include_unmodified: t.NotRequired[bool]
select_models: t.NotRequired[t.Collection[str]]
backfill_models: t.NotRequired[t.Collection[str]]
categorizer_config: t.NotRequired[CategorizerConfig]
enable_preview: t.NotRequired[bool]
run: t.NotRequired[bool]


class RunOptions(t.TypedDict):
start: t.NotRequired[TimeLike]
end: t.NotRequired[TimeLike]
execution_time: t.NotRequired[TimeLike]
skip_janitor: t.NotRequired[bool]
ignore_cron: t.NotRequired[bool]


class SQLMeshResource(ConfigurableResource):
config: SQLMeshContextConfig

def run(self, context: AssetExecutionContext):
def run(
self,
context: AssetExecutionContext,
translator: SQLMeshDagsterTranslator,
environment: str = "dev",
plan_options: t.Optional[PlanOptions] = None,
run_options: t.Optional[RunOptions] = None,
) -> t.List[MaterializeResult]:
"""Execute SQLMesh based on the configuration given"""
pass
logger = context.log
controller = self.get_controller()
controller.context.plan()
recorder = ConsoleRecorder()

recorder_handler_id = controller.add_event_handler(recorder)
logger.debug("start")
builder = t.cast(
PlanBuilder,
controller.context.plan_builder(
environment=environment,
**(plan_options or {}),
),
)
logger.debug("making plan")
plan = builder.build()
logger.debug("applying plan")
controller.context.apply(plan)
logger.debug("running through the scheduler")
controller.context.run(environment=environment, **(run_options or {}))
controller.remove_event_handler(recorder_handler_id)
controller.context.close()

materialized: t.List[MaterializeResult] = []
for updated in recorder._updated:
asset_key = translator.get_asset_key_from_model(
controller.context, updated.model
)
materialized.append(
MaterializeResult(
asset_key=asset_key,
metadata={
"updated": True,
},
)
)
logger.debug(recorder._updated)
return materialized

def plan(self, context: AssetExecutionContext):
pass
def get_controller(self):
return setup_sqlmesh_controller(self.config)
2 changes: 1 addition & 1 deletion dagster_sqlmesh/signals.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,9 @@
class DagsterSource(Signal):
# concrete implementation of abstraction method from Signal
def check_intervals(self, batch: Batch) -> bool | Batch:
""" "Filter the batch to only return the intervals for which the file exists"""
logger.debug("batches")
logger.debug(batch)
""" "Filter the batch to only return the intervals for which the file exists"""
return True


Expand Down
Loading

0 comments on commit cb64aa1

Please sign in to comment.