66from collections import defaultdict
77from dataclasses import dataclass , field
88from pathlib import Path
9- from typing import Any , Iterable , Literal , TypeAlias
9+ from typing import Any , Iterable , Literal , TypeAlias , Union
1010
1111from dbt_score .dbt_utils import dbt_ls
1212
@@ -179,6 +179,7 @@ class Model(HasColumnsMixin):
179179 tags: The list of tags attached to the model.
180180 tests: The list of tests attached to the model.
181181 depends_on: Dictionary of models/sources/macros that the model depends on.
182+ parents: The list of models, sources, and snapshots this model depends on.
182183 _raw_values: The raw values of the model (node) in the manifest.
183184 _raw_test_values: The raw test values of the model (node) in the manifest.
184185 """
@@ -204,6 +205,7 @@ class Model(HasColumnsMixin):
204205 tests : list [Test ] = field (default_factory = list )
205206 depends_on : dict [str , list [str ]] = field (default_factory = dict )
206207 constraints : list [Constraint ] = field (default_factory = list )
208+ parents : list [Union ["Model" , "Source" , "Snapshot" ]] = field (default_factory = list )
207209 _raw_values : dict [str , Any ] = field (default_factory = dict )
208210 _raw_test_values : list [dict [str , Any ]] = field (default_factory = list )
209211
@@ -416,6 +418,7 @@ class Snapshot(HasColumnsMixin):
416418 depends_on: Dictionary of models/sources/macros that the model depends on.
417419 strategy: The strategy of the snapshot.
418420 unique_key: The unique key of the snapshot.
421+ parents: The list of models, sources, and snapshots this snapshot depends on.
419422 _raw_values: The raw values of the snapshot (node) in the manifest.
420423 _raw_test_values: The raw test values of the snapshot (node) in the manifest.
421424 """
@@ -440,6 +443,7 @@ class Snapshot(HasColumnsMixin):
440443 depends_on : dict [str , list [str ]] = field (default_factory = dict )
441444 strategy : str | None = None
442445 unique_key : list [str ] | None = None
446+ parents : list [Union ["Model" , "Source" , "Snapshot" ]] = field (default_factory = list )
443447 _raw_values : dict [str , Any ] = field (default_factory = dict )
444448 _raw_test_values : list [dict [str , Any ]] = field (default_factory = list )
445449
@@ -508,15 +512,16 @@ def __init__(self, file_path: Path, select: Iterable[str] | None = None):
508512 if source_values ["package_name" ] == self .project_name
509513 }
510514
511- self .models : list [ Model ] = []
515+ self .models : dict [ str , Model ] = {}
512516 self .tests : dict [str , list [dict [str , Any ]]] = defaultdict (list )
513- self .sources : list [ Source ] = []
514- self .snapshots : list [ Snapshot ] = []
517+ self .sources : dict [ str , Source ] = {}
518+ self .snapshots : dict [ str , Snapshot ] = {}
515519
516520 self ._reindex_tests ()
517521 self ._load_models ()
518522 self ._load_sources ()
519523 self ._load_snapshots ()
524+ self ._populate_parents ()
520525
521526 if select :
522527 self ._filter_evaluables (select )
@@ -529,21 +534,21 @@ def _load_models(self) -> None:
529534 for node_id , node_values in self .raw_nodes .items ():
530535 if node_values .get ("resource_type" ) == "model" :
531536 model = Model .from_node (node_values , self .tests .get (node_id , []))
532- self .models . append ( model )
537+ self .models [ node_id ] = model
533538
534539 def _load_sources (self ) -> None :
535540 """Load the sources from the manifest."""
536541 for source_id , source_values in self .raw_sources .items ():
537542 if source_values .get ("resource_type" ) == "source" :
538543 source = Source .from_node (source_values , self .tests .get (source_id , []))
539- self .sources . append ( source )
544+ self .sources [ source_id ] = source
540545
541546 def _load_snapshots (self ) -> None :
542547 """Load the snapshots from the manifest."""
543548 for node_id , node_values in self .raw_nodes .items ():
544549 if node_values .get ("resource_type" ) == "snapshot" :
545550 snapshot = Snapshot .from_node (node_values , self .tests .get (node_id , []))
546- self .snapshots . append ( snapshot )
551+ self .snapshots [ node_id ] = snapshot
547552
548553 def _reindex_tests (self ) -> None :
549554 """Index tests based on their associated evaluable."""
@@ -561,6 +566,17 @@ def _reindex_tests(self) -> None:
561566 ):
562567 self .tests [node_unique_id ].append (node_values )
563568
569+ def _populate_parents (self ) -> None :
570+ """Populate `parents` for all models and snapshots."""
571+ for node in list (self .models .values ()) + list (self .snapshots .values ()):
572+ for parent_id in node .depends_on .get ("nodes" , []):
573+ if parent_id in self .models :
574+ node .parents .append (self .models [parent_id ])
575+ elif parent_id in self .snapshots :
576+ node .parents .append (self .snapshots [parent_id ])
577+ elif parent_id in self .sources :
578+ node .parents .append (self .sources [parent_id ])
579+
564580 def _filter_evaluables (self , select : Iterable [str ]) -> None :
565581 """Filter evaluables like dbt's --select."""
566582 single_model_select = re .compile (r"[a-zA-Z0-9_]+" )
@@ -573,6 +589,8 @@ def _filter_evaluables(self, select: Iterable[str]) -> None:
573589 # Use dbt's implementation of --select
574590 selected = dbt_ls (select )
575591
576- self .models = [m for m in self .models if m .name in selected ]
577- self .sources = [s for s in self .sources if s .selector_name in selected ]
578- self .snapshots = [s for s in self .snapshots if s .name in selected ]
592+ self .models = {k : m for k , m in self .models .items () if m .name in selected }
593+ self .sources = {
594+ k : s for k , s in self .sources .items () if s .selector_name in selected
595+ }
596+ self .snapshots = {k : s for k , s in self .snapshots .items () if s .name in selected }
0 commit comments