diff --git a/art/attacks/evasion/adversarial_patch/adversarial_patch_pytorch.py b/art/attacks/evasion/adversarial_patch/adversarial_patch_pytorch.py index a43ce14687..1a88f4a875 100644 --- a/art/attacks/evasion/adversarial_patch/adversarial_patch_pytorch.py +++ b/art/attacks/evasion/adversarial_patch/adversarial_patch_pytorch.py @@ -41,7 +41,7 @@ import torch - from art.utils import CLASSIFIER_NEURALNETWORK_TYPE + from art.utils import CLASSIFIER_NEURALNETWORK_TYPE, PYTORCH_OBJECT_DETECTOR_TYPE logger = logging.getLogger(__name__) @@ -72,7 +72,7 @@ class AdversarialPatchPyTorch(EvasionAttack): def __init__( self, - estimator: "CLASSIFIER_NEURALNETWORK_TYPE", + estimator: "CLASSIFIER_NEURALNETWORK_TYPE | PYTORCH_OBJECT_DETECTOR_TYPE", rotation_max: float = 22.5, scale_min: float = 0.1, scale_max: float = 1.0, @@ -91,7 +91,7 @@ def __init__( """ Create an instance of the :class:`.AdversarialPatchPyTorch`. - :param estimator: A trained estimator. + :param estimator: A trained PyTorch estimator for classification or object detection. :param rotation_max: The maximum rotation applied to random patches. The value is expected to be in the range `[0, 180]`. :param scale_min: The minimum scaling applied to random patches. The value should be in the range `[0, 1]`,