Skip to content

Commit

Permalink
Update estimator typing in AdversarialPatchPyTorch
Browse files Browse the repository at this point in the history
Signed-off-by: Beat Buesser <[email protected]>
  • Loading branch information
beat-buesser committed Jan 16, 2025
1 parent 2afa66a commit 0aa91b5
Showing 1 changed file with 3 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@

import torch

from art.utils import CLASSIFIER_NEURALNETWORK_TYPE
from art.utils import CLASSIFIER_NEURALNETWORK_TYPE, PYTORCH_OBJECT_DETECTOR_TYPE

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -72,7 +72,7 @@ class AdversarialPatchPyTorch(EvasionAttack):

def __init__(
self,
estimator: "CLASSIFIER_NEURALNETWORK_TYPE",
estimator: "CLASSIFIER_NEURALNETWORK_TYPE | PYTORCH_OBJECT_DETECTOR_TYPE",
rotation_max: float = 22.5,
scale_min: float = 0.1,
scale_max: float = 1.0,
Expand All @@ -91,7 +91,7 @@ def __init__(
"""
Create an instance of the :class:`.AdversarialPatchPyTorch`.
:param estimator: A trained estimator.
:param estimator: A trained PyTorch estimator for classification or object detection.
:param rotation_max: The maximum rotation applied to random patches. The value is expected to be in the
range `[0, 180]`.
:param scale_min: The minimum scaling applied to random patches. The value should be in the range `[0, 1]`,
Expand Down

0 comments on commit 0aa91b5

Please sign in to comment.