Skip to content

Commit e19badc

Browse files
committed
address mypy errors
1 parent 5ef7cae commit e19badc

File tree

5 files changed

+77
-24
lines changed

5 files changed

+77
-24
lines changed

src/dbt_score/evaluation.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from __future__ import annotations
44

55
from itertools import chain
6-
from typing import Type
6+
from typing import Type, cast
77

88
from dbt_score.formatters import Formatter
99
from dbt_score.models import Evaluable, ManifestLoader
@@ -57,6 +57,9 @@ def evaluate(self) -> None:
5757
for evaluable in chain(
5858
self._manifest_loader.models, self._manifest_loader.sources
5959
):
60+
# type inference on elements from `chain` is wonky
61+
# and resolves to superclass HasColumnsMixin
62+
evaluable = cast(Evaluable, evaluable)
6063
self.results[evaluable] = {}
6164
for rule in rules:
6265
try:

src/dbt_score/more_itertools.py

Lines changed: 31 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,37 @@
11
"""Vendored utility functions from https://github.com/more-itertools/more-itertools."""
2+
from typing import (
3+
Callable,
4+
Iterable,
5+
Optional,
6+
TypeVar,
7+
overload,
8+
)
29

10+
_T = TypeVar("_T")
11+
_U = TypeVar("_U")
312

4-
def first_true(iterable, default=None, pred=None):
13+
14+
@overload
15+
def first_true(
16+
iterable: Iterable[_T], *, pred: Callable[[_T], object] | None = ...
17+
) -> _T | None:
18+
...
19+
20+
21+
@overload
22+
def first_true(
23+
iterable: Iterable[_T],
24+
default: _U,
25+
pred: Callable[[_T], object] | None = ...,
26+
) -> _T | _U:
27+
...
28+
29+
30+
def first_true(
31+
iterable: Iterable[_T],
32+
default: Optional[_U] = None,
33+
pred: Optional[Callable[[_T], object]] = None,
34+
) -> _T | _U | None:
535
"""Returns the first true value in the iterable.
636
737
If no true value is found, returns *default*

src/dbt_score/rule.py

Lines changed: 25 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,17 @@
44
import typing
55
from dataclasses import dataclass, field
66
from enum import Enum
7-
from typing import Any, Callable, Iterable, Type, TypeAlias, overload
8-
9-
from dbt_score.models import Evaluable
7+
from typing import (
8+
Any,
9+
Callable,
10+
Iterable,
11+
Type,
12+
TypeAlias,
13+
cast,
14+
overload,
15+
)
16+
17+
from dbt_score.models import Evaluable, Model, Source
1018
from dbt_score.more_itertools import first_true
1119
from dbt_score.rule_filter import RuleFilter
1220

@@ -55,7 +63,9 @@ class RuleViolation:
5563
message: str | None = None
5664

5765

58-
RuleEvaluationType: TypeAlias = Callable[[Evaluable], RuleViolation | None]
66+
ModelRuleEvaluationType: TypeAlias = Callable[[Model], RuleViolation | None]
67+
SourceRuleEvaluationType: TypeAlias = Callable[[Source], RuleViolation | None]
68+
RuleEvaluationType: TypeAlias = ModelRuleEvaluationType | SourceRuleEvaluationType
5969

6070

6171
class Rule:
@@ -66,7 +76,7 @@ class Rule:
6676
rule_filter_names: list[str]
6777
rule_filters: frozenset[RuleFilter] = frozenset()
6878
default_config: typing.ClassVar[dict[str, Any]] = {}
69-
resource_type: typing.ClassVar[Evaluable]
79+
resource_type: typing.ClassVar[type[Evaluable]]
7080

7181
def __init__(self, rule_config: RuleConfig | None = None) -> None:
7282
"""Initialize the rule."""
@@ -85,7 +95,7 @@ def __init_subclass__(cls, **kwargs) -> None: # type: ignore
8595
cls._validate_rule_filters()
8696

8797
@classmethod
88-
def _validate_rule_filters(cls):
98+
def _validate_rule_filters(cls) -> None:
8999
for rule_filter in cls.rule_filters:
90100
if rule_filter.resource_type != cls.resource_type:
91101
raise TypeError(
@@ -111,7 +121,8 @@ def _introspect_resource_type(cls) -> Type[Evaluable]:
111121
"annotated Model or Source argument."
112122
)
113123

114-
return resource_type_argument.annotation
124+
resource_type = cast(type[Evaluable], resource_type_argument.annotation)
125+
return resource_type
115126

116127
def process_config(self, rule_config: RuleConfig) -> None:
117128
"""Process the rule config."""
@@ -178,7 +189,12 @@ def __hash__(self) -> int:
178189

179190

180191
@overload
181-
def rule(__func: RuleEvaluationType) -> Type[Rule]:
192+
def rule(__func: ModelRuleEvaluationType) -> Type[Rule]:
193+
...
194+
195+
196+
@overload
197+
def rule(__func: SourceRuleEvaluationType) -> Type[Rule]:
182198
...
183199

184200

@@ -214,9 +230,7 @@ def rule(
214230
rule_filters: Set of RuleFilter that filters the items that the rule applies to.
215231
"""
216232

217-
def decorator_rule(
218-
func: RuleEvaluationType,
219-
) -> Type[Rule]:
233+
def decorator_rule(func: RuleEvaluationType) -> Type[Rule]:
220234
"""Decorator function."""
221235
if func.__doc__ is None and description is None:
222236
raise AttributeError("Rule must define `description` or `func.__doc__`.")

src/dbt_score/rule_filter.py

Lines changed: 15 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -2,19 +2,21 @@
22

33
import inspect
44
import typing
5-
from typing import Any, Callable, Type, TypeAlias, overload
5+
from typing import Any, Callable, Type, TypeAlias, cast, overload
66

7-
from dbt_score.models import Evaluable
7+
from dbt_score.models import Evaluable, Model, Source
88
from dbt_score.more_itertools import first_true
99

10-
FilterEvaluationType: TypeAlias = Callable[[Evaluable], bool]
10+
ModelFilterEvaluationType: TypeAlias = Callable[[Model], bool]
11+
SourceFilterEvaluationType: TypeAlias = Callable[[Source], bool]
12+
FilterEvaluationType: TypeAlias = ModelFilterEvaluationType | SourceFilterEvaluationType
1113

1214

1315
class RuleFilter:
1416
"""The Filter base class."""
1517

1618
description: str
17-
resource_type: typing.ClassVar[Evaluable]
19+
resource_type: typing.ClassVar[type[Evaluable]]
1820

1921
def __init__(self) -> None:
2022
"""Initialize the filter."""
@@ -44,7 +46,8 @@ def _introspect_resource_type(cls) -> Type[Evaluable]:
4446
"annotated Model or Source argument."
4547
)
4648

47-
return resource_type_argument.annotation
49+
resource_type = cast(type[Evaluable], resource_type_argument.annotation)
50+
return resource_type
4851

4952
def evaluate(self, evaluable: Evaluable) -> bool:
5053
"""Evaluates the filter."""
@@ -65,7 +68,12 @@ def __hash__(self) -> int:
6568

6669

6770
@overload
68-
def rule_filter(__func: FilterEvaluationType) -> Type[RuleFilter]:
71+
def rule_filter(__func: ModelFilterEvaluationType) -> Type[RuleFilter]:
72+
...
73+
74+
75+
@overload
76+
def rule_filter(__func: SourceFilterEvaluationType) -> Type[RuleFilter]:
6977
...
7078

7179

@@ -96,9 +104,7 @@ def rule_filter(
96104
description: The description of the filter.
97105
"""
98106

99-
def decorator_filter(
100-
func: FilterEvaluationType,
101-
) -> Type[RuleFilter]:
107+
def decorator_filter(func: FilterEvaluationType) -> Type[RuleFilter]:
102108
"""Decorator function."""
103109
if func.__doc__ is None and description is None:
104110
raise AttributeError(

tests/conftest.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -77,15 +77,15 @@ def model2(raw_manifest) -> Model:
7777

7878

7979
@fixture
80-
def source1(raw_manifest) -> Model:
80+
def source1(raw_manifest) -> Source:
8181
"""Source 1."""
8282
return Source.from_node(
8383
raw_manifest["sources"]["source.package.my_source.table1"], []
8484
)
8585

8686

8787
@fixture
88-
def source2(raw_manifest) -> Model:
88+
def source2(raw_manifest) -> Source:
8989
"""Source 2."""
9090
return Source.from_node(
9191
raw_manifest["sources"]["source.package.my_source.table2"], []

0 commit comments

Comments
 (0)