Skip to content

Commit

Permalink
Refactor loader methods
Browse files Browse the repository at this point in the history
  • Loading branch information
Themiscodes committed Sep 13, 2024
1 parent 2eaaa47 commit 0328993
Showing 1 changed file with 29 additions and 16 deletions.
45 changes: 29 additions & 16 deletions sqlmesh/core/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@
)
from sqlmesh.core.engine_adapter import EngineAdapter
from sqlmesh.core.environment import Environment, EnvironmentNamingInfo
from sqlmesh.core.loader import Loader, update_model_schemas
from sqlmesh.core.loader import LoadedProject, Loader, update_model_schemas
from sqlmesh.core.macros import ExecutableOrMacro, macro
from sqlmesh.core.metric import Metric, rewrite
from sqlmesh.core.model import Model
Expand Down Expand Up @@ -541,14 +541,17 @@ def load(self, update_schemas: bool = True) -> GenericContext[C]:
"""Load all files in the context's path."""
load_start_ts = time.perf_counter()

projects = []
if self._dbt_loader and self.dbt_configs:
with sys_path(*self.dbt_configs):
projects.append(self._dbt_loader.load(self, update_schemas))

if self._sqlmesh_loader and self.sqlmesh_configs:
with sys_path(*self.sqlmesh_configs):
projects.append(self._sqlmesh_loader.load(self, update_schemas))
projects = [
project
for loader, configs in [
(self._dbt_loader, self.dbt_configs),
(self._sqlmesh_loader, self.sqlmesh_configs),
]
for project in [
self._load_factory("project", loader, configs, update_schemas=update_schemas)
]
if isinstance(project, LoadedProject)
]

self._standalone_audits.clear()
self._audits.clear()
Expand Down Expand Up @@ -640,13 +643,8 @@ def run(

if not self._loaded:
# Signals should be loaded to run correctly.
if self._sqlmesh_loader:
with sys_path(*self.sqlmesh_configs):
self._sqlmesh_loader.load_signals(self)

if self._dbt_loader:
with sys_path(*self.dbt_configs):
self._dbt_loader.load_signals(self)
self._load_factory("signals", self._dbt_loader, self.dbt_configs)
self._load_factory("signals", self._sqlmesh_loader, self.sqlmesh_configs)

success = False
try:
Expand Down Expand Up @@ -2086,6 +2084,21 @@ def _register_notification_targets(self) -> None:
event_notifications, user_notification_targets, username=self.config.username
)

def _load_factory(
self,
method: str,
loader: t.Optional[Loader] = None,
configs: t.Optional[t.Dict[Path, C]] = None,
**kwargs: t.Any,
) -> LoadedProject | None:
if loader and configs:
with sys_path(*configs):
if method == "project":
return loader.load(self, **kwargs)
elif method == "signals":
loader.load_signals(self, **kwargs)
return None


class Context(GenericContext[Config]):
CONFIG_TYPE = Config

0 comments on commit 0328993

Please sign in to comment.