Skip to content

Commit b1332e8

Browse files
committed
fix: load multi repo models earlier to ensure schema is correct
1 parent 129326d commit b1332e8

File tree

8 files changed

+71
-48
lines changed

8 files changed

+71
-48
lines changed

examples/multi/repo_1/config.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ gateways:
44
local:
55
connection:
66
type: duckdb
7-
database: db.db
7+
database: db.duckdb
88

99
memory:
1010
connection:

examples/multi/repo_2/config.yaml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ gateways:
44
local:
55
connection:
66
type: duckdb
7-
database: db.db
7+
database: db.duckdb
88

99
memory:
1010
connection:
@@ -13,4 +13,4 @@ gateways:
1313
default_gateway: local
1414

1515
model_defaults:
16-
dialect: 'duckdb'
16+
dialect: 'duckdb'

examples/multi/repo_2/models/e.sql

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
MODEL (
2+
name silver.e
3+
);
4+
5+
SELECT
6+
* EXCEPT(dup)
7+
FROM bronze.a

sqlmesh/core/context.py

Lines changed: 35 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -334,6 +334,7 @@ def __init__(
334334
self.configs = (
335335
config if isinstance(config, dict) else load_configs(config, self.CONFIG_TYPE, paths)
336336
)
337+
self._projects = {config.project for config in self.configs.values()}
337338
self.dag: DAG[str] = DAG()
338339
self._models: UniqueKeyDict[str, Model] = UniqueKeyDict("models")
339340
self._audits: UniqueKeyDict[str, ModelAudit] = UniqueKeyDict("audits")
@@ -574,15 +575,38 @@ def load(self, update_schemas: bool = True) -> GenericContext[C]:
574575
self._requirements.update(project.requirements)
575576
self._excluded_requirements.update(project.excluded_requirements)
576577

578+
if any(self._projects):
579+
prod = self.state_reader.get_environment(c.PROD)
580+
581+
if prod:
582+
for snapshot in self.state_reader.get_snapshots(prod.snapshots).values():
583+
if snapshot.node.project not in self._projects:
584+
store = self._standalone_audits if snapshot.is_audit else self._models
585+
store[snapshot.name] = snapshot.node # type: ignore
586+
577587
for model in self._models.values():
578588
self.dag.add(model.fqn, model.depends_on)
579589

580-
# This topologically sorts the DAG & caches the result in-memory for later;
581-
# we do it here to detect any cycles as early as possible and fail if needed
582-
self.dag.sorted
583-
584590
if update_schemas:
591+
for fqn in self.dag:
592+
model = self._models.get(fqn) # type: ignore
593+
594+
if not model or not model._data_hash:
595+
continue
596+
597+
# make a copy of remote models that depend on local models or in the downstream chain
598+
# without this, a SELECT * FROM local will not propagate properly because the downstream
599+
# model will get mutated (schema changes) but the object is the same as the remote cache
600+
if any(
601+
not self._models[dep]._data_hash
602+
for dep in model.depends_on
603+
if dep in self._models
604+
):
605+
self._models.update({fqn: model.copy(update={"mapping_schema": {}})})
606+
continue
607+
585608
update_model_schemas(self.dag, models=self._models, context_path=self.path)
609+
586610
for model in self.models.values():
587611
# The model definition can be validated correctly only after the schema is set.
588612
model.validate_definition()
@@ -2105,63 +2129,39 @@ def _get_engine_adapter(self, gateway: t.Optional[str] = None) -> EngineAdapter:
21052129
def _snapshots(
21062130
self, models_override: t.Optional[UniqueKeyDict[str, Model]] = None
21072131
) -> t.Dict[str, Snapshot]:
2108-
projects = {config.project for config in self.configs.values()}
2109-
2110-
if any(projects):
2111-
prod = self.state_reader.get_environment(c.PROD)
2112-
remote_snapshots = (
2113-
{
2114-
snapshot.name: snapshot
2115-
for snapshot in self.state_reader.get_snapshots(prod.snapshots).values()
2116-
}
2117-
if prod
2118-
else {}
2119-
)
2120-
else:
2121-
remote_snapshots = {}
2122-
2123-
local_nodes = {**(models_override or self._models), **self._standalone_audits}
2124-
nodes = local_nodes.copy()
2125-
2126-
for name, snapshot in remote_snapshots.items():
2127-
if name not in nodes and snapshot.node.project not in projects:
2128-
nodes[name] = snapshot.node
2129-
21302132
def _nodes_to_snapshots(nodes: t.Dict[str, Node]) -> t.Dict[str, Snapshot]:
21312133
snapshots: t.Dict[str, Snapshot] = {}
21322134
fingerprint_cache: t.Dict[str, SnapshotFingerprint] = {}
21332135

21342136
for node in nodes.values():
2135-
if node.fqn not in local_nodes and node.fqn in remote_snapshots:
2136-
ttl = remote_snapshots[node.fqn].ttl
2137-
else:
2138-
config = self.config_for_node(node)
2139-
ttl = config.snapshot_ttl
2137+
kwargs = {}
2138+
if node.project in self._projects:
2139+
kwargs["ttl"] = self.config_for_node(node).snapshot_ttl
21402140

21412141
snapshot = Snapshot.from_node(
21422142
node,
21432143
nodes=nodes,
21442144
cache=fingerprint_cache,
2145-
ttl=ttl,
2146-
config=self.config_for_node(node),
2145+
**kwargs,
21472146
)
21482147
snapshots[snapshot.name] = snapshot
21492148
return snapshots
21502149

2150+
nodes = {**(models_override or self._models), **self._standalone_audits}
21512151
snapshots = _nodes_to_snapshots(nodes)
21522152
stored_snapshots = self.state_reader.get_snapshots(snapshots.values())
21532153

21542154
unrestorable_snapshots = {
21552155
snapshot
21562156
for snapshot in stored_snapshots.values()
2157-
if snapshot.name in local_nodes and snapshot.unrestorable
2157+
if snapshot.name in nodes and snapshot.unrestorable
21582158
}
21592159
if unrestorable_snapshots:
21602160
for snapshot in unrestorable_snapshots:
21612161
logger.info(
21622162
"Found a unrestorable snapshot %s. Restamping the model...", snapshot.name
21632163
)
2164-
node = local_nodes[snapshot.name]
2164+
node = nodes[snapshot.name]
21652165
nodes[snapshot.name] = node.copy(
21662166
update={"stamp": f"revert to {snapshot.identifier}"}
21672167
)

sqlmesh/core/model/definition.py

Lines changed: 22 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -687,13 +687,23 @@ def text_diff(self, other: Node, rendered: bool = False) -> str:
687687
f"Cannot diff model '{self.name} against a non-model node '{other.name}'"
688688
)
689689

690-
return d.text_diff(
690+
text_diff = d.text_diff(
691691
self.render_definition(render_query=rendered),
692692
other.render_definition(render_query=rendered),
693693
self.dialect,
694694
other.dialect,
695695
).strip()
696696

697+
if not text_diff and not rendered:
698+
text_diff = d.text_diff(
699+
self.render_definition(render_query=True),
700+
other.render_definition(render_query=True),
701+
self.dialect,
702+
other.dialect,
703+
).strip()
704+
705+
return text_diff
706+
697707
def set_time_format(self, default_time_format: str = c.DEFAULT_TIME_COLUMN_FORMAT) -> None:
698708
"""Sets the default time format for a model.
699709
@@ -1256,7 +1266,7 @@ def columns_to_types(self) -> t.Optional[t.Dict[str, exp.DataType]]:
12561266
return None
12571267

12581268
self._columns_to_types = {
1259-
select.output_name: select.type or exp.DataType.build("unknown")
1269+
select.output_name: select.type.copy() or exp.DataType.build("unknown")
12601270
for select in query.selects
12611271
}
12621272

@@ -1351,9 +1361,16 @@ def is_breaking_change(self, previous: Model) -> t.Optional[bool]:
13511361
# Can't determine if there's a breaking change if we can't render the query.
13521362
return None
13531363

1354-
edits = diff(
1355-
previous_query, this_query, matchings=[(previous_query, this_query)], delta_only=True
1356-
)
1364+
if previous_query is this_query:
1365+
edits = []
1366+
else:
1367+
edits = diff(
1368+
previous_query,
1369+
this_query,
1370+
matchings=[(previous_query, this_query)],
1371+
delta_only=True,
1372+
copy=False,
1373+
)
13571374
inserted_expressions = {e.expression for e in edits if isinstance(e, Insert)}
13581375

13591376
for edit in edits:

sqlmesh/core/snapshot/definition.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,6 @@
4646
if t.TYPE_CHECKING:
4747
from sqlglot.dialects.dialect import DialectType
4848
from sqlmesh.core.environment import EnvironmentNamingInfo
49-
from sqlmesh.core.config import Config
5049

5150
Interval = t.Tuple[int, int]
5251
Intervals = t.List[Interval]
@@ -596,7 +595,6 @@ def from_node(
596595
ttl: str = c.DEFAULT_SNAPSHOT_TTL,
597596
version: t.Optional[str] = None,
598597
cache: t.Optional[t.Dict[str, SnapshotFingerprint]] = None,
599-
config: t.Optional[Config] = None,
600598
) -> Snapshot:
601599
"""Creates a new snapshot for a node.
602600

tests/core/test_context.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1119,7 +1119,7 @@ def test_wildcard(copy_to_temp_path: t.Callable):
11191119
parent_path = copy_to_temp_path("examples/multi")[0]
11201120

11211121
context = Context(paths=f"{parent_path}/*")
1122-
assert len(context.models) == 4
1122+
assert len(context.models) == 5
11231123

11241124

11251125
def test_duckdb_state_connection_automatic_multithreaded_mode(tmp_path):

tests/core/test_integration.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3837,7 +3837,7 @@ def test_multi(mocker):
38373837
)
38383838
context._new_state_sync().reset(default_catalog=context.default_catalog)
38393839
plan = context.plan_builder().build()
3840-
assert len(plan.new_snapshots) == 4
3840+
assert len(plan.new_snapshots) == 5
38413841
context.apply(plan)
38423842

38433843
adapter = context.engine_adapter
@@ -3856,12 +3856,13 @@ def test_multi(mocker):
38563856
assert set(snapshot.name for snapshot in plan.directly_modified) == {
38573857
'"memory"."bronze"."a"',
38583858
'"memory"."bronze"."b"',
3859+
'"memory"."silver"."e"',
38593860
}
38603861
assert sorted([x.name for x in list(plan.indirectly_modified.values())[0]]) == [
38613862
'"memory"."silver"."c"',
38623863
'"memory"."silver"."d"',
38633864
]
3864-
assert len(plan.missing_intervals) == 2
3865+
assert len(plan.missing_intervals) == 3
38653866
context.apply(plan)
38663867
validate_apply_basics(context, c.PROD, plan.snapshots.values())
38673868

0 commit comments

Comments
 (0)