|
119 | 119 | import dataclasses
|
120 | 120 | from abc import ABC
|
121 | 121 | from enum import Enum
|
| 122 | +from inspect import isclass |
122 | 123 | from pathlib import Path
|
123 | 124 | from typing import Any
|
124 | 125 | from typing import Callable
|
125 | 126 | from typing import Dict
|
126 | 127 | from typing import Generic
|
127 | 128 | from typing import Iterator
|
128 | 129 | from typing import List
|
| 130 | +from typing import TypeGuard |
129 | 131 | from typing import TypeVar
|
130 | 132 |
|
131 | 133 | import attr
|
@@ -339,12 +341,53 @@ def fast_concat(*inputs: Path, output: Path) -> None:
|
339 | 341 | )
|
340 | 342 |
|
341 | 343 |
|
| 344 | +def is_dataclass_instance(metric: Metric) -> TypeGuard[inspect.DataclassInstance]: |
| 345 | + """ |
| 346 | + Test if the given metric is a dataclass instance. |
| 347 | +
|
| 348 | + NB: `dataclasses.is_dataclass` returns True for both dataclass instances and class objects, and |
| 349 | + we need to override the built-in function's `TypeGuard`. |
| 350 | +
|
| 351 | + Args: |
| 352 | + metric: An instance of a Metric. |
| 353 | +
|
| 354 | + Returns: |
| 355 | + True if the given metric is an instance of a dataclass-decorated Metric. |
| 356 | + False otherwise. |
| 357 | + """ |
| 358 | + return not isclass(metric) and dataclasses.is_dataclass(metric) |
| 359 | + |
| 360 | + |
| 361 | +def is_attrs_instance(metric: Metric) -> TypeGuard[inspect.AttrsInstance]: |
| 362 | + """ |
| 363 | + Test if the given metric is an attr.s instance. |
| 364 | +
|
| 365 | + NB: `attr.has` does not provide a type guard, which we need to use other `attr` methods such as |
| 366 | + `asdict()`, so we implement one here. |
| 367 | +
|
| 368 | + Args: |
| 369 | + metric: An instance of a Metric. |
| 370 | +
|
| 371 | + Returns: |
| 372 | + True if the given metric is an instance of an attr.s-decorated Metric. |
| 373 | + False otherwise. |
| 374 | + """ |
| 375 | + return not isclass(metric) and attr.has(metric.__class__) |
| 376 | + |
| 377 | + |
342 | 378 | def asdict(metric: Metric) -> dict[str, Any]:
|
343 |
| - """Convert a Metric instance to a dictionary.""" |
| 379 | + """ |
| 380 | + Convert a Metric instance to a dictionary. |
344 | 381 |
|
345 |
| - if dataclasses.is_dataclass(metric): |
| 382 | + Args: |
| 383 | + metric: An instance of a Metric. |
| 384 | +
|
| 385 | + Returns: |
| 386 | + A dictionary representation of the given metric. |
| 387 | + """ |
| 388 | + if is_dataclass_instance(metric): |
346 | 389 | return dataclasses.asdict(metric)
|
347 |
| - elif attr.has(metric): |
| 390 | + elif is_attrs_instance(metric): |
348 | 391 | return attr.asdict(metric)
|
349 | 392 | else:
|
350 | 393 | assert False, "Unreachable"
|
0 commit comments