Skip to content

Commit

Permalink
fix: load multi repo models earlier to ensure schema is correct
Browse files · Browse the repository at this point in the history
  • Loading branch information
tobymao committed Feb 4, 2025
1 parent 129326d commit b1332e8
Show file tree
Hide file tree
Showing 8 changed files with 71 additions and 48 deletions.
2 changes: 1 addition & 1 deletion examples/multi/repo_1/config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ gateways:
local:
connection:
type: duckdb
database: db.db
database: db.duckdb

memory:
connection:
Expand Down
4 changes: 2 additions & 2 deletions examples/multi/repo_2/config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ gateways:
local:
connection:
type: duckdb
database: db.db
database: db.duckdb

memory:
connection:
Expand All @@ -13,4 +13,4 @@ gateways:
default_gateway: local

model_defaults:
dialect: 'duckdb'
dialect: 'duckdb'
7 changes: 7 additions & 0 deletions examples/multi/repo_2/models/e.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
-- SQLMesh model definition for silver.e (added in repo_2 of the multi-repo example).
-- Selects every column from bronze.a except `dup`, using the star EXCEPT(...)
-- modifier (DuckDB dialect per repo_2's model_defaults) to drop the duplicated column.
MODEL (
name silver.e
);

SELECT
* EXCEPT(dup)
FROM bronze.a
70 changes: 35 additions & 35 deletions sqlmesh/core/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -334,6 +334,7 @@ def __init__(
self.configs = (
config if isinstance(config, dict) else load_configs(config, self.CONFIG_TYPE, paths)
)
self._projects = {config.project for config in self.configs.values()}
self.dag: DAG[str] = DAG()
self._models: UniqueKeyDict[str, Model] = UniqueKeyDict("models")
self._audits: UniqueKeyDict[str, ModelAudit] = UniqueKeyDict("audits")
Expand Down Expand Up @@ -574,15 +575,38 @@ def load(self, update_schemas: bool = True) -> GenericContext[C]:
self._requirements.update(project.requirements)
self._excluded_requirements.update(project.excluded_requirements)

if any(self._projects):
prod = self.state_reader.get_environment(c.PROD)

if prod:
for snapshot in self.state_reader.get_snapshots(prod.snapshots).values():
if snapshot.node.project not in self._projects:
store = self._standalone_audits if snapshot.is_audit else self._models
store[snapshot.name] = snapshot.node # type: ignore

for model in self._models.values():
self.dag.add(model.fqn, model.depends_on)

# This topologically sorts the DAG & caches the result in-memory for later;
# we do it here to detect any cycles as early as possible and fail if needed
self.dag.sorted

if update_schemas:
for fqn in self.dag:
model = self._models.get(fqn) # type: ignore

if not model or not model._data_hash:
continue

# make a copy of remote models that depend on local models or in the downstream chain
# without this, a SELECT * FROM local will not propagate properly because the downstream
# model will get mutated (schema changes) but the object is the same as the remote cache
if any(
not self._models[dep]._data_hash
for dep in model.depends_on
if dep in self._models
):
self._models.update({fqn: model.copy(update={"mapping_schema": {}})})
continue

update_model_schemas(self.dag, models=self._models, context_path=self.path)

for model in self.models.values():
# The model definition can be validated correctly only after the schema is set.
model.validate_definition()
Expand Down Expand Up @@ -2105,63 +2129,39 @@ def _get_engine_adapter(self, gateway: t.Optional[str] = None) -> EngineAdapter:
def _snapshots(
self, models_override: t.Optional[UniqueKeyDict[str, Model]] = None
) -> t.Dict[str, Snapshot]:
projects = {config.project for config in self.configs.values()}

if any(projects):
prod = self.state_reader.get_environment(c.PROD)
remote_snapshots = (
{
snapshot.name: snapshot
for snapshot in self.state_reader.get_snapshots(prod.snapshots).values()
}
if prod
else {}
)
else:
remote_snapshots = {}

local_nodes = {**(models_override or self._models), **self._standalone_audits}
nodes = local_nodes.copy()

for name, snapshot in remote_snapshots.items():
if name not in nodes and snapshot.node.project not in projects:
nodes[name] = snapshot.node

def _nodes_to_snapshots(nodes: t.Dict[str, Node]) -> t.Dict[str, Snapshot]:
snapshots: t.Dict[str, Snapshot] = {}
fingerprint_cache: t.Dict[str, SnapshotFingerprint] = {}

for node in nodes.values():
if node.fqn not in local_nodes and node.fqn in remote_snapshots:
ttl = remote_snapshots[node.fqn].ttl
else:
config = self.config_for_node(node)
ttl = config.snapshot_ttl
kwargs = {}
if node.project in self._projects:
kwargs["ttl"] = self.config_for_node(node).snapshot_ttl

snapshot = Snapshot.from_node(
node,
nodes=nodes,
cache=fingerprint_cache,
ttl=ttl,
config=self.config_for_node(node),
**kwargs,
)
snapshots[snapshot.name] = snapshot
return snapshots

nodes = {**(models_override or self._models), **self._standalone_audits}
snapshots = _nodes_to_snapshots(nodes)
stored_snapshots = self.state_reader.get_snapshots(snapshots.values())

unrestorable_snapshots = {
snapshot
for snapshot in stored_snapshots.values()
if snapshot.name in local_nodes and snapshot.unrestorable
if snapshot.name in nodes and snapshot.unrestorable
}
if unrestorable_snapshots:
for snapshot in unrestorable_snapshots:
logger.info(
"Found a unrestorable snapshot %s. Restamping the model...", snapshot.name
)
node = local_nodes[snapshot.name]
node = nodes[snapshot.name]
nodes[snapshot.name] = node.copy(
update={"stamp": f"revert to {snapshot.identifier}"}
)
Expand Down
27 changes: 22 additions & 5 deletions sqlmesh/core/model/definition.py
Original file line number Diff line number Diff line change
Expand Up @@ -687,13 +687,23 @@ def text_diff(self, other: Node, rendered: bool = False) -> str:
f"Cannot diff model '{self.name} against a non-model node '{other.name}'"
)

return d.text_diff(
text_diff = d.text_diff(
self.render_definition(render_query=rendered),
other.render_definition(render_query=rendered),
self.dialect,
other.dialect,
).strip()

if not text_diff and not rendered:
text_diff = d.text_diff(
self.render_definition(render_query=True),
other.render_definition(render_query=True),
self.dialect,
other.dialect,
).strip()

return text_diff

def set_time_format(self, default_time_format: str = c.DEFAULT_TIME_COLUMN_FORMAT) -> None:
"""Sets the default time format for a model.
Expand Down Expand Up @@ -1256,7 +1266,7 @@ def columns_to_types(self) -> t.Optional[t.Dict[str, exp.DataType]]:
return None

self._columns_to_types = {
select.output_name: select.type or exp.DataType.build("unknown")
select.output_name: select.type.copy() or exp.DataType.build("unknown")
for select in query.selects
}

Expand Down Expand Up @@ -1351,9 +1361,16 @@ def is_breaking_change(self, previous: Model) -> t.Optional[bool]:
# Can't determine if there's a breaking change if we can't render the query.
return None

edits = diff(
previous_query, this_query, matchings=[(previous_query, this_query)], delta_only=True
)
if previous_query is this_query:
edits = []
else:
edits = diff(
previous_query,
this_query,
matchings=[(previous_query, this_query)],
delta_only=True,
copy=False,
)
inserted_expressions = {e.expression for e in edits if isinstance(e, Insert)}

for edit in edits:
Expand Down
2 changes: 0 additions & 2 deletions sqlmesh/core/snapshot/definition.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,6 @@
if t.TYPE_CHECKING:
from sqlglot.dialects.dialect import DialectType
from sqlmesh.core.environment import EnvironmentNamingInfo
from sqlmesh.core.config import Config

Interval = t.Tuple[int, int]
Intervals = t.List[Interval]
Expand Down Expand Up @@ -596,7 +595,6 @@ def from_node(
ttl: str = c.DEFAULT_SNAPSHOT_TTL,
version: t.Optional[str] = None,
cache: t.Optional[t.Dict[str, SnapshotFingerprint]] = None,
config: t.Optional[Config] = None,
) -> Snapshot:
"""Creates a new snapshot for a node.
Expand Down
2 changes: 1 addition & 1 deletion tests/core/test_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -1119,7 +1119,7 @@ def test_wildcard(copy_to_temp_path: t.Callable):
parent_path = copy_to_temp_path("examples/multi")[0]

context = Context(paths=f"{parent_path}/*")
assert len(context.models) == 4
assert len(context.models) == 5


def test_duckdb_state_connection_automatic_multithreaded_mode(tmp_path):
Expand Down
5 changes: 3 additions & 2 deletions tests/core/test_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -3837,7 +3837,7 @@ def test_multi(mocker):
)
context._new_state_sync().reset(default_catalog=context.default_catalog)
plan = context.plan_builder().build()
assert len(plan.new_snapshots) == 4
assert len(plan.new_snapshots) == 5
context.apply(plan)

adapter = context.engine_adapter
Expand All @@ -3856,12 +3856,13 @@ def test_multi(mocker):
assert set(snapshot.name for snapshot in plan.directly_modified) == {
'"memory"."bronze"."a"',
'"memory"."bronze"."b"',
'"memory"."silver"."e"',
}
assert sorted([x.name for x in list(plan.indirectly_modified.values())[0]]) == [
'"memory"."silver"."c"',
'"memory"."silver"."d"',
]
assert len(plan.missing_intervals) == 2
assert len(plan.missing_intervals) == 3
context.apply(plan)
validate_apply_basics(context, c.PROD, plan.snapshots.values())

Expand Down

0 comments on commit b1332e8

Please sign in to comment.