Skip to content
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@ and this project adheres to

## [Unreleased]

- Add `parents` to models and snapshots, allowing access to parent nodes. (#109)

## [0.11.0] - 2025-04-04

- Improve documentation on rule filters. (#93)
Expand Down
6 changes: 3 additions & 3 deletions src/dbt_score/evaluation.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,9 +61,9 @@ def evaluate(self) -> None:
rules = self._rule_registry.rules.values()

for evaluable in chain(
self._manifest_loader.models,
self._manifest_loader.sources,
self._manifest_loader.snapshots,
self._manifest_loader.models.values(),
self._manifest_loader.sources.values(),
self._manifest_loader.snapshots.values(),
):
# type inference on elements from `chain` is wonky
# and resolves to superclass HasColumnsMixin
Expand Down
38 changes: 28 additions & 10 deletions src/dbt_score/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from collections import defaultdict
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, Iterable, Literal, TypeAlias
from typing import Any, Iterable, Literal, TypeAlias, Union

from dbt_score.dbt_utils import dbt_ls

Expand Down Expand Up @@ -179,6 +179,7 @@ class Model(HasColumnsMixin):
tags: The list of tags attached to the model.
tests: The list of tests attached to the model.
depends_on: Dictionary of models/sources/macros that the model depends on.
parents: The list of models, sources, and snapshots this model depends on.
_raw_values: The raw values of the model (node) in the manifest.
_raw_test_values: The raw test values of the model (node) in the manifest.
"""
Expand All @@ -204,6 +205,7 @@ class Model(HasColumnsMixin):
tests: list[Test] = field(default_factory=list)
depends_on: dict[str, list[str]] = field(default_factory=dict)
constraints: list[Constraint] = field(default_factory=list)
parents: list[Union["Model", "Source", "Snapshot"]] = field(default_factory=list)
_raw_values: dict[str, Any] = field(default_factory=dict)
_raw_test_values: list[dict[str, Any]] = field(default_factory=list)

Expand Down Expand Up @@ -416,6 +418,7 @@ class Snapshot(HasColumnsMixin):
depends_on: Dictionary of models/sources/macros that the model depends on.
strategy: The strategy of the snapshot.
unique_key: The unique key of the snapshot.
parents: The list of models, sources, and snapshots this snapshot depends on.
_raw_values: The raw values of the snapshot (node) in the manifest.
_raw_test_values: The raw test values of the snapshot (node) in the manifest.
"""
Expand All @@ -440,6 +443,7 @@ class Snapshot(HasColumnsMixin):
depends_on: dict[str, list[str]] = field(default_factory=dict)
strategy: str | None = None
unique_key: list[str] | None = None
parents: list[Union["Model", "Source", "Snapshot"]] = field(default_factory=list)
_raw_values: dict[str, Any] = field(default_factory=dict)
_raw_test_values: list[dict[str, Any]] = field(default_factory=list)

Expand Down Expand Up @@ -508,15 +512,16 @@ def __init__(self, file_path: Path, select: Iterable[str] | None = None):
if source_values["package_name"] == self.project_name
}

self.models: list[Model] = []
self.models: dict[str, Model] = {}
self.tests: dict[str, list[dict[str, Any]]] = defaultdict(list)
self.sources: list[Source] = []
self.snapshots: list[Snapshot] = []
self.sources: dict[str, Source] = {}
self.snapshots: dict[str, Snapshot] = {}

self._reindex_tests()
self._load_models()
self._load_sources()
self._load_snapshots()
self._populate_parents()

if select:
self._filter_evaluables(select)
Expand All @@ -529,21 +534,21 @@ def _load_models(self) -> None:
for node_id, node_values in self.raw_nodes.items():
if node_values.get("resource_type") == "model":
model = Model.from_node(node_values, self.tests.get(node_id, []))
self.models.append(model)
self.models[node_id] = model

def _load_sources(self) -> None:
"""Load the sources from the manifest."""
for source_id, source_values in self.raw_sources.items():
if source_values.get("resource_type") == "source":
source = Source.from_node(source_values, self.tests.get(source_id, []))
self.sources.append(source)
self.sources[source_id] = source

def _load_snapshots(self) -> None:
"""Load the snapshots from the manifest."""
for node_id, node_values in self.raw_nodes.items():
if node_values.get("resource_type") == "snapshot":
snapshot = Snapshot.from_node(node_values, self.tests.get(node_id, []))
self.snapshots.append(snapshot)
self.snapshots[node_id] = snapshot

def _reindex_tests(self) -> None:
"""Index tests based on their associated evaluable."""
Expand All @@ -561,6 +566,17 @@ def _reindex_tests(self) -> None:
):
self.tests[node_unique_id].append(node_values)

def _populate_parents(self) -> None:
"""Populate `parents` for all models and snapshots."""
for node in list(self.models.values()) + list(self.snapshots.values()):
for parent_id in node.depends_on.get("nodes", []):
if parent_id in self.models:
node.parents.append(self.models[parent_id])
elif parent_id in self.snapshots:
node.parents.append(self.snapshots[parent_id])
elif parent_id in self.sources:
node.parents.append(self.sources[parent_id])
Comment on lines +569 to +578
Copy link
Contributor

@sercancicek sercancicek Apr 24, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
def _populate_parents(self) -> None:
"""Populate `parents` for all models and snapshots."""
for node in list(self.models.values()) + list(self.snapshots.values()):
for parent_id in node.depends_on.get("nodes", []):
if parent_id in self.models:
node.parents.append(self.models[parent_id])
elif parent_id in self.snapshots:
node.parents.append(self.snapshots[parent_id])
elif parent_id in self.sources:
node.parents.append(self.sources[parent_id])
def _populate_parents(self) -> None:
"""Populate `parents` for all models and snapshots."""
all_parents = {**self.models, **self.snapshots, **self.sources}
for node in list(self.models.values()) + list(self.snapshots.values()):
for parent_id in node.depends_on.get("nodes", []):
if parent := all_parents.get(parent_id):
node.parents.append(parent)

IMO, this looks easier to read and maintain but just a suggestion :)


def _filter_evaluables(self, select: Iterable[str]) -> None:
"""Filter evaluables like dbt's --select."""
single_model_select = re.compile(r"[a-zA-Z0-9_]+")
Expand All @@ -573,6 +589,8 @@ def _filter_evaluables(self, select: Iterable[str]) -> None:
# Use dbt's implementation of --select
selected = dbt_ls(select)

self.models = [m for m in self.models if m.name in selected]
self.sources = [s for s in self.sources if s.selector_name in selected]
self.snapshots = [s for s in self.snapshots if s.name in selected]
self.models = {k: m for k, m in self.models.items() if m.name in selected}
self.sources = {
k: s for k, s in self.sources.items() if s.selector_name in selected
}
self.snapshots = {k: s for k, s in self.snapshots.items() if s.name in selected}
12 changes: 9 additions & 3 deletions tests/resources/manifest.json
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
"alias": "snapshot1_alias",
"patch_path": "/path/to/snapshot1.yml",
"tags": [],
"depends_on": {},
"depends_on": { "nodes": ["model.package.model1"] },
"language": "sql",
"access": "protected"
},
Expand Down Expand Up @@ -61,7 +61,7 @@
"alias": "snapshot2_alias",
"patch_path": "/path/to/snapshot2.yml",
"tags": [],
"depends_on": {},
"depends_on": { "nodes": ["source.package.my_source.table1"] },
"language": "sql",
"access": "protected"
},
Expand Down Expand Up @@ -96,7 +96,13 @@
"alias": "model1_alias",
"patch_path": "/path/to/model1.yml",
"tags": [],
"depends_on": {},
"depends_on": {
"nodes": [
"model.package.model2",
"source.package.my_source.table1",
"snapshot.package.snapshot2"
]
},
"language": "sql",
"access": "protected",
"group": null
Expand Down
40 changes: 20 additions & 20 deletions tests/test_evaluation.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,8 @@ def test_evaluation_low_medium_high(
)
evaluation.evaluate()

model1 = manifest_loader.models[0]
model2 = manifest_loader.models[1]
model1 = manifest_loader.models["model.package.model1"]
model2 = manifest_loader.models["model.package.model2"]

assert evaluation.results[model1][rule_severity_low] is None
assert evaluation.results[model1][rule_severity_medium] is None
Expand Down Expand Up @@ -85,7 +85,7 @@ def test_evaluation_critical(

evaluation.evaluate()

model2 = manifest_loader.models[1]
model2 = manifest_loader.models["model.package.model2"]

assert isinstance(evaluation.results[model2][rule_severity_critical], RuleViolation)

Expand Down Expand Up @@ -157,8 +157,8 @@ def test_evaluation_rule_with_config(
):
"""Test rule evaluation with parameters."""
manifest_loader = ManifestLoader(manifest_path)
model1 = manifest_loader.models[0]
model2 = manifest_loader.models[1]
model1 = manifest_loader.models["model.package.model1"]
model2 = manifest_loader.models["model.package.model2"]

config = Config()
config._load_toml_file(str(valid_config_path))
Expand Down Expand Up @@ -216,12 +216,12 @@ def test_evaluation_with_filter(
)
evaluation.evaluate()

model1 = manifest_loader.models[0]
model2 = manifest_loader.models[1]
source1 = manifest_loader.sources[0]
source2 = manifest_loader.sources[1]
snapshot1 = manifest_loader.snapshots[0]
snapshot2 = manifest_loader.snapshots[1]
model1 = manifest_loader.models["model.package.model1"]
model2 = manifest_loader.models["model.package.model2"]
source1 = manifest_loader.sources["source.package.my_source.table1"]
source2 = manifest_loader.sources["source.package.my_source.table2"]
snapshot1 = manifest_loader.snapshots["snapshot.package.snapshot1"]
snapshot2 = manifest_loader.snapshots["snapshot.package.snapshot2"]

assert model_rule_with_filter not in evaluation.results[model1]
assert isinstance(evaluation.results[model2][model_rule_with_filter], RuleViolation)
Expand Down Expand Up @@ -266,12 +266,12 @@ def test_evaluation_with_class_filter(
)
evaluation.evaluate()

model1 = manifest_loader.models[0]
model2 = manifest_loader.models[1]
source1 = manifest_loader.sources[0]
source2 = manifest_loader.sources[1]
snapshot1 = manifest_loader.snapshots[0]
snapshot2 = manifest_loader.snapshots[1]
model1 = manifest_loader.models["model.package.model1"]
model2 = manifest_loader.models["model.package.model2"]
source1 = manifest_loader.sources["source.package.my_source.table1"]
source2 = manifest_loader.sources["source.package.my_source.table2"]
snapshot1 = manifest_loader.snapshots["snapshot.package.snapshot1"]
snapshot2 = manifest_loader.snapshots["snapshot.package.snapshot2"]

assert model_class_rule_with_filter not in evaluation.results[model1]
assert isinstance(
Expand Down Expand Up @@ -318,9 +318,9 @@ def test_evaluation_with_models_and_sources(
)
evaluation.evaluate()

model1 = manifest_loader.models[0]
source1 = manifest_loader.sources[0]
snapshot1 = manifest_loader.snapshots[0]
model1 = manifest_loader.models["model.package.model1"]
source1 = manifest_loader.sources["source.package.my_source.table1"]
snapshot1 = manifest_loader.snapshots["snapshot.package.snapshot1"]

assert decorator_rule in evaluation.results[model1]
assert decorator_rule_source not in evaluation.results[model1]
Expand Down
28 changes: 22 additions & 6 deletions tests/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,9 @@ def test_manifest_load(mock_read_text, raw_manifest):
and node["package_name"] == raw_manifest["metadata"]["project_name"]
]
)
assert loader.models[0].tests[0].name == "test2"
assert loader.models[0].tests[1].name == "test4"
assert loader.models[0].columns[0].tests[0].name == "test1"
assert loader.models["model.package.model1"].tests[0].name == "test2"
assert loader.models["model.package.model1"].tests[1].name == "test4"
assert loader.models["model.package.model1"].columns[0].tests[0].name == "test1"

assert len(loader.sources) == len(
[
Expand All @@ -30,7 +30,23 @@ def test_manifest_load(mock_read_text, raw_manifest):
if source["package_name"] == raw_manifest["metadata"]["project_name"]
]
)
assert loader.sources[0].tests[0].name == "source_test1"
assert (
loader.sources["source.package.my_source.table1"].tests[0].name
== "source_test1"
)

assert loader.snapshots["snapshot.package.snapshot1"].parents == [
loader.models["model.package.model1"]
]
assert loader.models["model.package.model1"].parents == [
loader.models["model.package.model2"],
loader.sources["source.package.my_source.table1"],
loader.snapshots["snapshot.package.snapshot2"],
]
assert loader.models["model.package.model2"].parents == []
assert loader.snapshots["snapshot.package.snapshot2"].parents == [
loader.sources["source.package.my_source.table1"]
]


@patch("dbt_score.models.Path.read_text")
Expand All @@ -39,7 +55,7 @@ def test_manifest_select_models_simple(mock_read_text, raw_manifest):
with patch("dbt_score.models.json.loads", return_value=raw_manifest):
manifest_loader = ManifestLoader(Path("some.json"), select=["model1"])

assert [x.name for x in manifest_loader.models] == ["model1"]
assert [x.name for x in manifest_loader.models.values()] == ["model1"]


@patch("dbt_score.models.Path.read_text")
Expand All @@ -50,7 +66,7 @@ def test_manifest_select_models_dbt_ls(mock_dbt_ls, mock_read_text, raw_manifest
with patch("dbt_score.models.json.loads", return_value=raw_manifest):
manifest_loader = ManifestLoader(Path("some.json"), select=["+model1"])

assert [x.name for x in manifest_loader.models] == ["model1"]
assert [x.name for x in manifest_loader.models.values()] == ["model1"]
mock_dbt_ls.assert_called_once_with(["+model1"])


Expand Down