Skip to content

Commit

Permalink
Feat: Use dbt manifest to load dbt projects (#821)
Browse files Browse the repository at this point in the history
  • Loading branch information
izeigerman authored May 8, 2023
1 parent b72f007 commit 2ce5f34
Show file tree
Hide file tree
Showing 23 changed files with 587 additions and 940 deletions.
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@
"google-cloud-bigquery-storage",
"black==22.6.0",
"dbt-core",
"dbt-duckdb",
"Faker",
"google-auth",
"isort==5.10.1",
Expand Down
20 changes: 8 additions & 12 deletions sqlmesh/dbt/adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,11 +23,10 @@ def __init__(
self,
jinja_macros: JinjaMacroRegistry,
jinja_globals: t.Optional[t.Dict[str, t.Any]] = None,
dialect: str = "",
):
self.jinja_macros = jinja_macros
self.jinja_globals = jinja_globals or {}
self.dialect = dialect
self.jinja_globals = jinja_globals.copy() if jinja_globals else {}
self.jinja_globals["adapter"] = self

@abc.abstractmethod
def get_relation(self, database: str, schema: str, identifier: str) -> t.Optional[BaseRelation]:
Expand Down Expand Up @@ -78,18 +77,15 @@ def quote(self, identifier: str) -> str:

def dispatch(self, name: str, package: t.Optional[str] = None) -> t.Callable:
"""Returns a dialect-specific version of a macro with the given name."""
dialect_name = f"{self.dialect}__{name}"
default_name = f"default__{name}"

target_type = self.jinja_globals["target"]["type"]
references_to_try = [
MacroReference(package=package, name=dialect_name),
MacroReference(package=package, name=default_name),
MacroReference(package=f"{package}_{target_type}", name=f"{target_type}__{name}"),
MacroReference(package=package, name=f"{target_type}__{name}"),
MacroReference(package=package, name=f"default__{name}"),
]

for reference in references_to_try:
macro_callable = self.jinja_macros.build_macro(
reference, **{**self.jinja_globals, "adapter": self}
)
macro_callable = self.jinja_macros.build_macro(reference, **self.jinja_globals)
if macro_callable is not None:
return macro_callable

Expand Down Expand Up @@ -141,7 +137,7 @@ def __init__(
):
from dbt.adapters.base.relation import Policy

super().__init__(jinja_macros, jinja_globals=jinja_globals, dialect=engine_adapter.dialect)
super().__init__(jinja_macros, jinja_globals=jinja_globals)

self.engine_adapter = engine_adapter
# All engines quote by default except Snowflake
Expand Down
176 changes: 7 additions & 169 deletions sqlmesh/dbt/basemodel.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,33 +5,29 @@
from enum import Enum
from pathlib import Path

from dbt.adapters.base import BaseRelation
from dbt.contracts.relation import RelationType
from jinja2 import nodes
from jinja2.exceptions import UndefinedError
from pydantic import Field, validator
from sqlglot.helper import ensure_list

from sqlmesh.core import constants as c
from sqlmesh.core import dialect as d
from sqlmesh.core.config.base import UpdateStrategy
from sqlmesh.core.model import Model
from sqlmesh.dbt.adapter import ParsetimeAdapter
from sqlmesh.dbt.column import (
ColumnConfig,
column_descriptions_to_sqlmesh,
column_types_to_sqlmesh,
yaml_to_columns,
)
from sqlmesh.dbt.common import DbtConfig, GeneralConfig, QuotingConfig, SqlStr
from sqlmesh.dbt.context import DbtContext
from sqlmesh.utils import AttributeDict
from sqlmesh.utils.conversions import ensure_bool
from sqlmesh.utils.date import date_dict
from sqlmesh.utils.errors import ConfigError
from sqlmesh.utils.jinja import MacroReference, extract_macro_references
from sqlmesh.utils.jinja import MacroReference
from sqlmesh.utils.pydantic import PydanticModel

if t.TYPE_CHECKING:
from sqlmesh.dbt.context import DbtContext


BMC = t.TypeVar("BMC", bound="BaseModelConfig")


Expand All @@ -43,21 +39,17 @@ class Dependencies(PydanticModel):
macros: The references to macros
sources: The "source_name.table_name" for source tables used
refs: The table_name for models used
variables: The names of variables used, mapped to a flag that indicates whether their
definition is optional or not.
"""

macros: t.Set[MacroReference] = set()
sources: t.Set[str] = set()
refs: t.Set[str] = set()
variables: t.Set[str] = set()

def union(self, other: Dependencies) -> Dependencies:
    """Combine two dependency collections into a brand new one.

    Args:
        other: The dependencies to merge with this instance.

    Returns:
        A fresh Dependencies object whose macros, sources, refs and variables
        are the set unions of the corresponding fields of ``self`` and ``other``.
        Neither input is modified.
    """
    merged = Dependencies()
    # Every field is a set, so the same "|" merge applies to each of them in turn.
    for field_name in ("macros", "sources", "refs", "variables"):
        setattr(merged, field_name, getattr(self, field_name) | getattr(other, field_name))

    return merged

Expand Down Expand Up @@ -101,7 +93,6 @@ class BaseModelConfig(GeneralConfig):
storage_format: The storage format used to store the physical table, only applicable in certain engines.
(e.g. 'parquet')
path: The file path of the model
target_schema: The schema for the profile target
dependencies: The macro, source, var, and ref dependencies used to execute the model and its hooks
database: Database the model is stored in
schema: Custom schema name added to the model schema name
Expand All @@ -119,12 +110,11 @@ class BaseModelConfig(GeneralConfig):
stamp: t.Optional[str] = None
storage_format: t.Optional[str] = None
path: Path = Path()
target_schema: str = ""
dependencies: Dependencies = Dependencies()

# DBT configuration fields
schema_: str = Field("", alias="schema")
database: t.Optional[str] = None
schema_: t.Optional[str] = Field(None, alias="schema")
alias: t.Optional[str] = None
pre_hook: t.List[Hook] = Field([], alias="pre-hook")
post_hook: t.List[Hook] = Field([], alias="post-hook")
Expand Down Expand Up @@ -156,13 +146,6 @@ def _validate_bool(cls, v: str) -> bool:
def _validate_grants(cls, v: t.Dict[str, str]) -> t.Dict[str, t.List[str]]:
    """Return a copy of the mapping in which every value has been passed through ``ensure_list``.

    Keys are kept as they are; only the values are converted.
    """
    listified: t.Dict[str, t.List[str]] = {}
    for grant_key, grant_value in v.items():
        listified[grant_key] = ensure_list(grant_value)
    return listified

@validator("columns", pre=True)
def _validate_columns(cls, v: t.Any) -> t.Dict[str, ColumnConfig]:
    """Accept either ready-made column configs or a raw value to be converted.

    A dict whose values are all ColumnConfig instances (an empty dict included)
    is returned untouched; anything else is handed to ``yaml_to_columns``.
    """
    already_parsed = isinstance(v, dict) and all(
        isinstance(entry, ColumnConfig) for entry in v.values()
    )
    return v if already_parsed else yaml_to_columns(v)

_FIELD_UPDATE_STRATEGY: t.ClassVar[t.Dict[str, UpdateStrategy]] = {
**GeneralConfig._FIELD_UPDATE_STRATEGY,
**{
Expand Down Expand Up @@ -197,7 +180,7 @@ def table_schema(self) -> str:
"""
Get the full schema name
"""
return "_".join(part for part in (self.target_schema, self.schema_) if part)
return self.schema_

@property
def table_name(self) -> str:
Expand Down Expand Up @@ -293,21 +276,6 @@ def sqlmesh_model_kwargs(self, model_context: DbtContext) -> t.Dict[str, t.Any]:
**optional_kwargs,
}

def render_config(self: BMC, context: DbtContext) -> BMC:
    """Render this config and fold in the dependencies of ephemeral models it references.

    The config produced by the parent implementation is enriched through
    ModelSqlRenderer. Afterwards, every ref that resolves (via ``context.models``)
    to a model materialized as EPHEMERAL is replaced by that model's own rendered
    dependencies.

    Args:
        context: The dbt context used for rendering and for looking up models.

    Returns:
        The rendered and enriched config.
    """
    base_config = super().render_config(context)
    result = ModelSqlRenderer(context, base_config).enriched_config

    # Iterate over the refs set captured before any merging. union() assigns
    # freshly built sets to a new Dependencies object, so the discard() below
    # never mutates the set being iterated.
    initial_refs = result.dependencies.refs
    for ref_name in initial_refs:
        upstream = context.models.get(ref_name)
        if not upstream or upstream.materialized != Materialization.EPHEMERAL:
            continue

        upstream_dependencies = upstream.render_config(context).dependencies
        result.dependencies = result.dependencies.union(upstream_dependencies)
        result.dependencies.refs.discard(ref_name)

    return result

@abstractmethod
def to_sqlmesh(self, context: DbtContext) -> Model:
"""Convert DBT model into sqlmesh Model"""
Expand Down Expand Up @@ -338,135 +306,5 @@ def _context_for_dependencies(
model_context.sources = sources
model_context.seeds = seeds
model_context.models = models
model_context.variables = {
name: value
for name, value in context.variables.items()
if name in dependencies.variables
}

return model_context


class ModelSqlRenderer(t.Generic[BMC]):
    """Renders a model's SQL once in order to capture its config and dependencies.

    Rendering happens with tracking versions of ``ref``, ``var``, ``source`` and
    the ``adapter`` global, which record what the SQL touches in
    ``_captured_dependencies``. ``config(...)`` calls found in the SQL are applied
    to a copy of the supplied config.
    """

    def __init__(self, context: DbtContext, config: BMC):
        """Prepare the jinja globals and environment used for rendering.

        Args:
            context: The dbt context providing macros, globals and variables.
            config: The model config to render; it is copied, not modified.
        """
        # Imported here rather than at module level.
        # NOTE(review): presumably to avoid a circular import with sqlmesh.dbt.builtin - confirm.
        from sqlmesh.dbt.builtin import create_builtin_globals

        self.context = context
        self.config = config

        self._captured_dependencies: Dependencies = Dependencies()
        # None means "not rendered yet"; render() fills it in exactly once.
        self._rendered_sql: t.Optional[str] = None
        self._enriched_config: BMC = config.copy()

        self._jinja_globals = create_builtin_globals(
            jinja_macros=context.jinja_macros,
            jinja_globals={
                **context.jinja_globals,
                **date_dict(c.EPOCH, c.EPOCH, c.EPOCH),
                # config() renders to an empty string; its kwargs are extracted
                # separately by _update_with_sql_config.
                "config": lambda *args, **kwargs: "",
                "ref": self._ref,
                "var": self._var,
                "source": self._source,
                "this": self.config.relation_info,
                "model": self.config.model_function(),
                "schema": self.config.table_schema,
            },
            engine_adapter=None,
        )

        # Set the adapter separately since it requires jinja globals to be passed into it.
        self._jinja_globals["adapter"] = ModelSqlRenderer.TrackingAdapter(
            self,
            context.jinja_macros,
            jinja_globals=self._jinja_globals,
            dialect=context.engine_adapter.dialect if context.engine_adapter else "",
        )

        self.jinja_env = self.context.jinja_macros.build_environment(**self._jinja_globals)

    @property
    def enriched_config(self) -> BMC:
        """The config updated with in-SQL ``config(...)`` values and captured dependencies.

        The enrichment runs only while nothing has been rendered yet; later
        accesses return the already enriched config.
        """
        if self._rendered_sql is None:
            self._enriched_config = self._update_with_sql_config(self._enriched_config)
            # Start from the macro references found statically in the SQL ...
            self._enriched_config.dependencies = Dependencies(
                macros=extract_macro_references(self._enriched_config.all_sql)
            )
            # ... then render so the tracking callbacks record what is actually used ...
            self.render()
            # ... and merge both sets of dependencies.
            self._enriched_config.dependencies = self._enriched_config.dependencies.union(
                self._captured_dependencies
            )
        return self._enriched_config

    def render(self) -> str:
        """Render the config's ``all_sql`` and cache the result.

        Returns:
            The rendered SQL string.

        Raises:
            ConfigError: If rendering hits an undefined jinja name. The original
                UndefinedError is attached as the cause.
        """
        if self._rendered_sql is None:
            try:
                self._rendered_sql = self.jinja_env.from_string(
                    self._enriched_config.all_sql
                ).render()
            except UndefinedError as e:
                # Chain the jinja error so the original traceback is not lost.
                raise ConfigError(e.message) from e
        return self._rendered_sql

    def _update_with_sql_config(self, config: BMC) -> BMC:
        """Apply the keyword arguments of every ``config(...)`` call in the SQL to ``config``.

        Args:
            config: The config to update.

        Returns:
            The config after applying ``update_with`` once per ``config(...)`` call found.
        """

        def _extract_value(node: t.Any) -> t.Any:
            """Convert a jinja AST node into a plain Python value."""
            if not isinstance(node, nodes.Node):
                return node
            if isinstance(node, nodes.Const):
                return _extract_value(node.value)
            if isinstance(node, nodes.TemplateData):
                return _extract_value(node.data)
            if isinstance(node, nodes.List):
                return [_extract_value(val) for val in node.items]
            if isinstance(node, nodes.Dict):
                return {_extract_value(pair.key): _extract_value(pair.value) for pair in node.items}
            if isinstance(node, nodes.Tuple):
                return tuple(_extract_value(val) for val in node.items)

            # Any other node is an expression: render it through the jinja
            # environment and use the resulting string.
            return self.jinja_env.from_string(nodes.Template([nodes.Output([node])])).render()

        for call in self.jinja_env.parse(self._enriched_config.sql_embedded_config).find_all(
            nodes.Call
        ):
            # Only plain calls to a name spelled "config" are of interest.
            if not isinstance(call.node, nodes.Name) or call.node.name != "config":
                continue
            config = config.update_with(
                {kwarg.key: _extract_value(kwarg.value) for kwarg in call.kwargs}
            )

        return config

    def _ref(self, package_name: str, model_name: t.Optional[str] = None) -> BaseRelation:
        """Tracking replacement for ``ref``: records the reference and returns an empty relation."""
        # NOTE(review): only the first argument is recorded, even when model_name
        # is supplied - confirm this is intended for two-argument ref() calls.
        self._captured_dependencies.refs.add(package_name)
        return BaseRelation.create()

    def _var(self, name: str, default: t.Optional[str] = None) -> t.Any:
        """Tracking replacement for ``var``: records the variable name and returns its value.

        Raises:
            ConfigError: If the variable is not defined in the context and no default was given.
        """
        if default is None and name not in self.context.variables:
            raise ConfigError(
                f"Variable '{name}' was not found for model '{self.config.table_name}'."
            )
        self._captured_dependencies.variables.add(name)
        return self.context.variables.get(name, default)

    def _source(self, source_name: str, table_name: str) -> BaseRelation:
        """Tracking replacement for ``source``: records "source_name.table_name"."""
        full_name = ".".join([source_name, table_name])
        self._captured_dependencies.sources.add(full_name)
        return BaseRelation.create()

    class TrackingAdapter(ParsetimeAdapter):
        """Parse-time adapter that records macros which a dispatch call may resolve to."""

        def __init__(self, outer_self: ModelSqlRenderer, *args: t.Any, **kwargs: t.Any):
            """Keep a handle on the owning renderer; other arguments go to ParsetimeAdapter."""
            super().__init__(*args, **kwargs)
            self.outer_self = outer_self
            self.context = outer_self.context

        def dispatch(self, name: str, package: t.Optional[str] = None) -> t.Callable:
            """Record every ``*__<name>`` macro in the package, then dispatch as usual.

            When ``package`` is None the root macros are searched instead of a
            named package.
            """
            macros = (
                self.context.jinja_macros.packages.get(package, {})
                if package is not None
                else self.context.jinja_macros.root_macros
            )
            for target_name in macros:
                if target_name.endswith(f"__{name}"):
                    self.outer_self._captured_dependencies.macros.add(
                        MacroReference(package=package, name=target_name)
                    )
            return super().dispatch(name, package=package)
36 changes: 3 additions & 33 deletions sqlmesh/dbt/builtin.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,21 +2,18 @@

import json
import os
import sys
import typing as t
from ast import literal_eval
from pathlib import Path

import agate
import jinja2
from dbt import version
from dbt.adapters.base import BaseRelation
from dbt.contracts.relation import Policy
from ruamel.yaml import YAMLError

from sqlmesh.core.engine_adapter import EngineAdapter
from sqlmesh.dbt.adapter import ParsetimeAdapter, RuntimeAdapter
from sqlmesh.dbt.context import DbtContext
from sqlmesh.dbt.package import PackageLoader
from sqlmesh.utils import AttributeDict, yaml
from sqlmesh.utils.errors import ConfigError, MacroEvalError
from sqlmesh.utils.jinja import JinjaMacroRegistry, MacroReturnVal
Expand Down Expand Up @@ -250,30 +247,9 @@ def _try_literal_eval(value: str) -> t.Any:
return value


def _dbt_macro_registry() -> JinjaMacroRegistry:
    """Build a registry from the macros of the dbt projects shipped with the installed dbt.

    The first ``sys.path`` entry that contains "site-packages" and has a ``dbt``
    child is searched for ``dbt/include/*/dbt_project.yml`` files. Each project
    found, except ``starter_project``, is loaded and its macros are registered
    under the ``dbt`` package name.

    Returns:
        The populated registry, or an empty one when no suitable site-packages
        directory could be located.
    """
    registry = JinjaMacroRegistry()

    try:
        site_packages = next(
            p for p in sys.path if "site-packages" in p and Path(p, "dbt").exists()
        )
    except Exception:
        # Best effort: without a usable dbt installation (StopIteration) or with an
        # unreadable or odd sys.path entry, fall back to an empty registry.
        # Deliberately not a bare "except:" so that KeyboardInterrupt and
        # SystemExit still propagate.
        return registry

    for project_file in Path(site_packages).glob("dbt/include/*/dbt_project.yml"):
        if project_file.parent.stem == "starter_project":
            continue
        context = DbtContext(project_root=project_file.parent, jinja_macros=JinjaMacroRegistry())
        package = PackageLoader(context).load()
        registry.add_macros(package.macro_infos, package="dbt")

    return registry


DBT_MACRO_REGISTRY = _dbt_macro_registry()

BUILTIN_GLOBALS = {
"api": Api(),
"dbt_version": version.__version__,
"env_var": env_var,
"exceptions": Exceptions(),
"flags": Flags(),
Expand Down Expand Up @@ -367,13 +343,7 @@ def create_builtin_globals(
}
)

builtin_globals.update(jinja_globals)
if "dbt" not in builtin_globals:
builtin_globals["dbt"] = DBT_MACRO_REGISTRY.build_environment(
**builtin_globals
).globals.get("dbt", {})

return builtin_globals
return {**builtin_globals, **jinja_globals}


def create_builtin_filters() -> t.Dict[str, t.Callable]:
Expand Down
Loading

0 comments on commit 2ce5f34

Please sign in to comment.