Skip to content

Commit f12c3fd

Browse files
committed
fix: typeguard
1 parent 475603a commit f12c3fd

File tree

3 files changed

+70
-5
lines changed

3 files changed

+70
-5
lines changed

fgpyo/util/inspect.py

-2
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,6 @@ def get_attr_fields_dict(cls: type) -> Dict[str, dataclasses.Field]: # type: ig
7171
if TYPE_CHECKING: # pragma: no cover
7272
from _typeshed import DataclassInstance
7373
else:
74-
7574
# https://github.com/python/typeshed/blob/727f3c4320d2af3af2f16695e24dd78e79b7c070/stdlib/_typeshed/__init__.pyi#L348
7675
# TODO: update the hint to `Field[Any]` when we drop support for 3.8
7776
class DataclassInstance(Protocol):
@@ -81,7 +80,6 @@ class DataclassInstance(Protocol):
8180
if TYPE_CHECKING and _use_attr: # pragma: no cover
8281
from attr import AttrsInstance
8382
else:
84-
8583
# https://github.com/python-attrs/attrs/blob/f7f317ae4c3790f23ae027db626593d50e8a4e88/src/attr/_typing_compat.pyi#L9
8684
class AttrsInstance(Protocol): # type: ignore[no-redef]
8785
__attrs_attrs__: ClassVar[Any]

fgpyo/util/metric.py

+46-3
Original file line numberDiff line numberDiff line change
@@ -119,13 +119,15 @@
119119
import dataclasses
120120
from abc import ABC
121121
from enum import Enum
122+
from inspect import isclass
122123
from pathlib import Path
123124
from typing import Any
124125
from typing import Callable
125126
from typing import Dict
126127
from typing import Generic
127128
from typing import Iterator
128129
from typing import List
130+
from typing import TypeGuard
129131
from typing import TypeVar
130132

131133
import attr
@@ -339,12 +341,53 @@ def fast_concat(*inputs: Path, output: Path) -> None:
339341
)
340342

341343

344+
def is_dataclass_instance(metric: Metric) -> TypeGuard[inspect.DataclassInstance]:
345+
"""
346+
Test if the given metric is a dataclass instance.
347+
348+
NB: `dataclasses.is_dataclass` returns True for both dataclass instances and class objects, and
349+
we need to override the built-in function's `TypeGuard`.
350+
351+
Args:
352+
metric: An instance of a Metric.
353+
354+
Returns:
355+
True if the given metric is an instance of a dataclass-decorated Metric.
356+
False otherwise.
357+
"""
358+
return not isclass(metric) and dataclasses.is_dataclass(metric)
359+
360+
361+
def is_attrs_instance(metric: Metric) -> TypeGuard[inspect.AttrsInstance]:
362+
"""
363+
Test if the given metric is an attr.s instance.
364+
365+
NB: `attr.has` does not provide a type guard, which we need to use other `attr` methods such as
366+
`asdict()`, so we implement one here.
367+
368+
Args:
369+
metric: An instance of a Metric.
370+
371+
Returns:
372+
True if the given metric is an instance of an attr.s-decorated Metric.
373+
False otherwise.
374+
"""
375+
return not isclass(metric) and attr.has(metric.__class__)
376+
377+
342378
def asdict(metric: Metric) -> dict[str, Any]:
343-
"""Convert a Metric instance to a dictionary."""
379+
"""
380+
Convert a Metric instance to a dictionary.
344381
345-
if dataclasses.is_dataclass(metric):
382+
Args:
383+
metric: An instance of a Metric.
384+
385+
Returns:
386+
A dictionary representation of the given metric.
387+
"""
388+
if is_dataclass_instance(metric):
346389
return dataclasses.asdict(metric)
347-
elif attr.has(metric):
390+
elif is_attrs_instance(metric):
348391
return attr.asdict(metric)
349392
else:
350393
assert False, "Unreachable"

fgpyo/util/tests/test_metric.py

+24
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,9 @@
2929
from fgpyo.util.inspect import is_attr_class
3030
from fgpyo.util.inspect import is_dataclasses_class
3131
from fgpyo.util.metric import Metric
32+
from fgpyo.util.metric import asdict
33+
from fgpyo.util.metric import is_attrs_instance
34+
from fgpyo.util.metric import is_dataclass_instance
3235

3336

3437
class EnumTest(enum.Enum):
@@ -519,3 +522,24 @@ def test_metric_columns_out_of_order(tmp_path: Path, data_and_classes: DataBuild
519522
names = list(NameMetric.read(path=path))
520523
assert len(names) == 1
521524
assert names[0] == name
525+
526+
527+
def test_is_dataclass_instance() -> None:
528+
"""Test that is_dataclass_instance works as expected."""
529+
530+
assert is_dataclass_instance(dataclasses_data_and_classes.Person(name="name", age=42))
531+
assert not is_dataclass_instance(attr_data_and_classes.Person(name="name", age=42))
532+
533+
534+
def test_is_attrs_instance() -> None:
535+
"""Test that is_attrs_instance works as expected."""
536+
537+
assert not is_attrs_instance(dataclasses_data_and_classes.Person(name="name", age=42))
538+
assert is_attrs_instance(attr_data_and_classes.Person(name="name", age=42))
539+
540+
541+
@pytest.mark.parametrize("data_and_classes", (attr_data_and_classes, dataclasses_data_and_classes))
542+
def test_asdict(data_and_classes: DataBuilder) -> None:
543+
"""Test that is_dataclass_instance works as expected."""
544+
545+
assert asdict(data_and_classes.Person(name="name", age=42)) == {"name": "name", "age": 42}

0 commit comments

Comments
 (0)