diff --git a/examples/multi/repo_1/config.yaml b/examples/multi/repo_1/config.yaml index f4e111275..6cce77d27 100644 --- a/examples/multi/repo_1/config.yaml +++ b/examples/multi/repo_1/config.yaml @@ -4,7 +4,7 @@ gateways: local: connection: type: duckdb - database: db.db + database: db.duckdb memory: connection: diff --git a/examples/multi/repo_2/config.yaml b/examples/multi/repo_2/config.yaml index 6bd2063a8..0a127b2e7 100644 --- a/examples/multi/repo_2/config.yaml +++ b/examples/multi/repo_2/config.yaml @@ -4,7 +4,7 @@ gateways: local: connection: type: duckdb - database: db.db + database: db.duckdb memory: connection: @@ -13,4 +13,4 @@ gateways: default_gateway: local model_defaults: - dialect: 'duckdb' \ No newline at end of file + dialect: 'duckdb' diff --git a/examples/multi/repo_2/models/e.sql b/examples/multi/repo_2/models/e.sql new file mode 100644 index 000000000..34d079332 --- /dev/null +++ b/examples/multi/repo_2/models/e.sql @@ -0,0 +1,7 @@ +MODEL ( + name silver.e +); + +SELECT + * EXCEPT(dup) +FROM bronze.a diff --git a/sqlmesh/core/context.py b/sqlmesh/core/context.py index 1a1beadaa..120ee8f49 100644 --- a/sqlmesh/core/context.py +++ b/sqlmesh/core/context.py @@ -334,6 +334,7 @@ def __init__( self.configs = ( config if isinstance(config, dict) else load_configs(config, self.CONFIG_TYPE, paths) ) + self._projects = {config.project for config in self.configs.values()} self.dag: DAG[str] = DAG() self._models: UniqueKeyDict[str, Model] = UniqueKeyDict("models") self._audits: UniqueKeyDict[str, ModelAudit] = UniqueKeyDict("audits") @@ -574,15 +575,38 @@ def load(self, update_schemas: bool = True) -> GenericContext[C]: self._requirements.update(project.requirements) self._excluded_requirements.update(project.excluded_requirements) + if any(self._projects): + prod = self.state_reader.get_environment(c.PROD) + + if prod: + for snapshot in self.state_reader.get_snapshots(prod.snapshots).values(): + if snapshot.node.project not in self._projects: + store = 
self._standalone_audits if snapshot.is_audit else self._models + store[snapshot.name] = snapshot.node # type: ignore + for model in self._models.values(): self.dag.add(model.fqn, model.depends_on) - # This topologically sorts the DAG & caches the result in-memory for later; - # we do it here to detect any cycles as early as possible and fail if needed - self.dag.sorted - if update_schemas: + for fqn in self.dag: + model = self._models.get(fqn) # type: ignore + + if not model or not model._data_hash: + continue + + # make a copy of remote models that depend on local models or in the downstream chain + # without this, a SELECT * FROM local will not propagate properly because the downstream + # model will get mutated (schema changes) but the object is the same as the remote cache + if any( + not self._models[dep]._data_hash + for dep in model.depends_on + if dep in self._models + ): + self._models.update({fqn: model.copy(update={"mapping_schema": {}})}) + continue + update_model_schemas(self.dag, models=self._models, context_path=self.path) + for model in self.models.values(): # The model definition can be validated correctly only after the schema is set. 
model.validate_definition() @@ -2105,63 +2129,39 @@ def _get_engine_adapter(self, gateway: t.Optional[str] = None) -> EngineAdapter: def _snapshots( self, models_override: t.Optional[UniqueKeyDict[str, Model]] = None ) -> t.Dict[str, Snapshot]: - projects = {config.project for config in self.configs.values()} - - if any(projects): - prod = self.state_reader.get_environment(c.PROD) - remote_snapshots = ( - { - snapshot.name: snapshot - for snapshot in self.state_reader.get_snapshots(prod.snapshots).values() - } - if prod - else {} - ) - else: - remote_snapshots = {} - - local_nodes = {**(models_override or self._models), **self._standalone_audits} - nodes = local_nodes.copy() - - for name, snapshot in remote_snapshots.items(): - if name not in nodes and snapshot.node.project not in projects: - nodes[name] = snapshot.node - def _nodes_to_snapshots(nodes: t.Dict[str, Node]) -> t.Dict[str, Snapshot]: snapshots: t.Dict[str, Snapshot] = {} fingerprint_cache: t.Dict[str, SnapshotFingerprint] = {} for node in nodes.values(): - if node.fqn not in local_nodes and node.fqn in remote_snapshots: - ttl = remote_snapshots[node.fqn].ttl - else: - config = self.config_for_node(node) - ttl = config.snapshot_ttl + kwargs = {} + if node.project in self._projects: + kwargs["ttl"] = self.config_for_node(node).snapshot_ttl snapshot = Snapshot.from_node( node, nodes=nodes, cache=fingerprint_cache, - ttl=ttl, - config=self.config_for_node(node), + **kwargs, ) snapshots[snapshot.name] = snapshot return snapshots + nodes = {**(models_override or self._models), **self._standalone_audits} snapshots = _nodes_to_snapshots(nodes) stored_snapshots = self.state_reader.get_snapshots(snapshots.values()) unrestorable_snapshots = { snapshot for snapshot in stored_snapshots.values() - if snapshot.name in local_nodes and snapshot.unrestorable + if snapshot.name in nodes and snapshot.unrestorable } if unrestorable_snapshots: for snapshot in unrestorable_snapshots: logger.info( "Found a unrestorable 
snapshot %s. Restamping the model...", snapshot.name ) - node = local_nodes[snapshot.name] + node = nodes[snapshot.name] nodes[snapshot.name] = node.copy( update={"stamp": f"revert to {snapshot.identifier}"} ) diff --git a/sqlmesh/core/model/definition.py b/sqlmesh/core/model/definition.py index f0925de34..3269b2240 100644 --- a/sqlmesh/core/model/definition.py +++ b/sqlmesh/core/model/definition.py @@ -687,13 +687,23 @@ def text_diff(self, other: Node, rendered: bool = False) -> str: f"Cannot diff model '{self.name} against a non-model node '{other.name}'" ) - return d.text_diff( + text_diff = d.text_diff( self.render_definition(render_query=rendered), other.render_definition(render_query=rendered), self.dialect, other.dialect, ).strip() + if not text_diff and not rendered: + text_diff = d.text_diff( + self.render_definition(render_query=True), + other.render_definition(render_query=True), + self.dialect, + other.dialect, + ).strip() + + return text_diff + def set_time_format(self, default_time_format: str = c.DEFAULT_TIME_COLUMN_FORMAT) -> None: """Sets the default time format for a model. @@ -1256,7 +1266,7 @@ def columns_to_types(self) -> t.Optional[t.Dict[str, exp.DataType]]: return None self._columns_to_types = { - select.output_name: select.type or exp.DataType.build("unknown") + select.output_name: select.type.copy() or exp.DataType.build("unknown") for select in query.selects } @@ -1351,9 +1361,16 @@ def is_breaking_change(self, previous: Model) -> t.Optional[bool]: # Can't determine if there's a breaking change if we can't render the query. 
return None - edits = diff( - previous_query, this_query, matchings=[(previous_query, this_query)], delta_only=True - ) + if previous_query is this_query: + edits = [] + else: + edits = diff( + previous_query, + this_query, + matchings=[(previous_query, this_query)], + delta_only=True, + copy=False, + ) inserted_expressions = {e.expression for e in edits if isinstance(e, Insert)} for edit in edits: diff --git a/sqlmesh/core/snapshot/definition.py b/sqlmesh/core/snapshot/definition.py index 24ba8379b..3d58ef31d 100644 --- a/sqlmesh/core/snapshot/definition.py +++ b/sqlmesh/core/snapshot/definition.py @@ -46,7 +46,6 @@ if t.TYPE_CHECKING: from sqlglot.dialects.dialect import DialectType from sqlmesh.core.environment import EnvironmentNamingInfo - from sqlmesh.core.config import Config Interval = t.Tuple[int, int] Intervals = t.List[Interval] @@ -596,7 +595,6 @@ def from_node( ttl: str = c.DEFAULT_SNAPSHOT_TTL, version: t.Optional[str] = None, cache: t.Optional[t.Dict[str, SnapshotFingerprint]] = None, - config: t.Optional[Config] = None, ) -> Snapshot: """Creates a new snapshot for a node. 
diff --git a/tests/core/test_context.py b/tests/core/test_context.py index c7b21df93..8bd0ca342 100644 --- a/tests/core/test_context.py +++ b/tests/core/test_context.py @@ -1119,7 +1119,7 @@ def test_wildcard(copy_to_temp_path: t.Callable): parent_path = copy_to_temp_path("examples/multi")[0] context = Context(paths=f"{parent_path}/*") - assert len(context.models) == 4 + assert len(context.models) == 5 def test_duckdb_state_connection_automatic_multithreaded_mode(tmp_path): diff --git a/tests/core/test_integration.py b/tests/core/test_integration.py index 2537e03cb..ae1863064 100644 --- a/tests/core/test_integration.py +++ b/tests/core/test_integration.py @@ -3837,7 +3837,7 @@ def test_multi(mocker): ) context._new_state_sync().reset(default_catalog=context.default_catalog) plan = context.plan_builder().build() - assert len(plan.new_snapshots) == 4 + assert len(plan.new_snapshots) == 5 context.apply(plan) adapter = context.engine_adapter @@ -3856,12 +3856,13 @@ def test_multi(mocker): assert set(snapshot.name for snapshot in plan.directly_modified) == { '"memory"."bronze"."a"', '"memory"."bronze"."b"', + '"memory"."silver"."e"', } assert sorted([x.name for x in list(plan.indirectly_modified.values())[0]]) == [ '"memory"."silver"."c"', '"memory"."silver"."d"', ] - assert len(plan.missing_intervals) == 2 + assert len(plan.missing_intervals) == 3 context.apply(plan) validate_apply_basics(context, c.PROD, plan.snapshots.values())