From 309b817344176539f6b0196dc798cda9880cfefc Mon Sep 17 00:00:00 2001 From: Paul Elliott Date: Wed, 21 Aug 2024 14:30:40 -0400 Subject: [PATCH] fix(nrtk_transforms): avoid runtime error from type hints --- src/nrtk_explorer/library/nrtk_transforms.py | 17 +++++++++++------ 1 file changed, 11 insertions(+), 6 deletions(-) diff --git a/src/nrtk_explorer/library/nrtk_transforms.py b/src/nrtk_explorer/library/nrtk_transforms.py index a3ccea85..036580e4 100644 --- a/src/nrtk_explorer/library/nrtk_transforms.py +++ b/src/nrtk_explorer/library/nrtk_transforms.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, Union +from typing import Any, Dict, Optional, TYPE_CHECKING import numpy as np import logging @@ -15,22 +15,27 @@ from pybsm.otf import darkCurrentFromDensity from nrtk.impls.perturb_image.generic.cv2.blur import GaussianBlurPerturber from nrtk.impls.perturb_image.pybsm.perturber import PybsmPerturber, PybsmSensor, PybsmScenario - - GaussianBlurPerturberType = Union[GaussianBlurPerturber, None] - PybsmPerturberType = Union[PybsmPerturber, None] except ImportError: logger.info("Disabling NRTK transforms due to missing library/failing imports") ENABLED_NRTK_TRANSFORMS = False + +if TYPE_CHECKING: + GaussianBlurPerturberType = GaussianBlurPerturber + PybsmPerturberType = PybsmPerturber +else: GaussianBlurPerturberType = None PybsmPerturberType = None +GaussianBlurPerturberArg = Optional[GaussianBlurPerturberType] +PybsmPerturberArg = Optional[PybsmPerturberType] + def nrtk_transforms_available(): return ENABLED_NRTK_TRANSFORMS class NrtkGaussianBlurTransform(ImageTransform): - def __init__(self, perturber: GaussianBlurPerturberType = None): + def __init__(self, perturber: GaussianBlurPerturberArg = None): if perturber is None: perturber = GaussianBlurPerturber() @@ -160,7 +165,7 @@ def createSampleSensorAndScenario(): class NrtkPybsmTransform(ImageTransform): - def __init__(self, perturber: PybsmPerturberType = None): + def __init__(self, perturber: PybsmPerturberArg = None): if perturber is None: sensor, scenario = createSampleSensorAndScenario() perturber = PybsmPerturber(sensor=sensor, scenario=scenario)