
Commit 2a404d2 ("update tests")
Parent: d6d670f

File tree: 2 files changed (+18, -11 lines)


src/transformers/models/dinov3_vit/modular_dinov3_vit.py

Lines changed: 2 additions & 2 deletions
@@ -39,7 +39,7 @@
 from ...pytorch_utils import compile_compatible_method_lru_cache
 from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, logging
 from ...utils.backbone_utils import BackboneMixin
-from ...utils.generic import can_return_tuple, check_model_inputs
+from ...utils.generic import check_model_inputs
 from .configuration_dinov3_vit import DINOv3ViTConfig


@@ -555,4 +555,4 @@ def forward(
         )


-__all__ = ["DINOv3ViTModel", "DINOv3ViTPreTrainedModel", "DINOv3ViTBackbone", "DINOv3ViTForImageClassification"]
+__all__ = ["DINOv3ViTModel", "DINOv3ViTPreTrainedModel", "DINOv3ViTBackbone", "DINOv3ViTForImageClassification"]

tests/models/dinov3_vit/test_modeling_dinov3_vit.py

Lines changed: 16 additions & 9 deletions
@@ -205,7 +205,9 @@ class Dinov3ModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
     attention_mask and seq_length.
     """

-    all_model_classes = (DINOv3ViTModel, DINOv3ViTBackbone, DINOv3ViTForImageClassification) if is_torch_available() else ()
+    all_model_classes = (
+        (DINOv3ViTModel, DINOv3ViTBackbone, DINOv3ViTForImageClassification) if is_torch_available() else ()
+    )
     pipeline_model_mapping = (
         {
             "image-feature-extraction": DINOv3ViTModel,
@@ -277,13 +279,14 @@ def test_model_from_pretrained(self):
         model_name = "facebook/dinov3-vits16-pretrain-lvd1689m"
         model = DINOv3ViTModel.from_pretrained(model_name)
         self.assertIsNotNone(model)
-
+
     @slow
     def test_model_for_image_classification_from_pretrained(self):
         model_name = "dimidagd/dinov3-vit7b16-pretrain-lvd1689m-imagenet1k-lc"
         model = DINOv3ViTForImageClassification.from_pretrained(model_name)
         self.assertIsNotNone(model)

+
 # We will verify our results on an image of cute cats
 def prepare_img():
     image = Image.open("./tests/fixtures/tests_samples/COCO/000000039769.png")
@@ -301,12 +304,8 @@ def default_image_processor(self):
             else None
         )

-
     @slow
     def test_inference_lc_head_imagenet(self):
-        # tensor = torch.ones(1,3,224,224).to(model.device)
-        # expected_output_std = 0.7570638656616211
-        # expected_output_mean = 6.4013e-03
         model = DINOv3ViTModel.from_pretrained("facebook/dinov3-vits16-pretrain-lvd1689m").to(torch_device)

         image_processor = self.default_image_processor
@@ -316,11 +315,19 @@ def test_inference_lc_head_imagenet(self):
         # forward pass
         with torch.no_grad():
             outputs = model(**inputs)
+        predicted_class_idx = outputs.logits.argmax(-1).item()
+        # 283 is cat
+        self.assertEqual(predicted_class_idx, 283)
+
+        test_tensor = torch.ones(1, 3, 224, 224).to(model.device)
+        with torch.no_grad():
+            outputs = model(test_tensor)

-
-        # self.assertAlmostEqual(outputs.logits.std().item(), expected_output_std, places=4)
-        # self.assertAlmostEqual(outputs.logits.mean().item(), expected_output_mean, places=4)
+        expected_output_std = 0.7570638656616211
+        expected_output_mean = 6.4013e-03

+        self.assertAlmostEqual(outputs.logits.std().item(), expected_output_std, places=4)
+        self.assertAlmostEqual(outputs.logits.mean().item(), expected_output_mean, places=4)

     @slow
     def test_inference_no_head(self):
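
For readers who want to exercise the checkpoints referenced above outside the test harness, here is a minimal standalone sketch of the classification check. Assumptions: the checkpoint name and the expected class index 283 are copied from the tests in this diff, the checkpoint's availability on the Hub is not guaranteed, and AutoImageProcessor stands in for the test's default_image_processor fixture.

# Standalone sketch of the classification check exercised by the updated tests.
# Assumes a transformers build with DINOv3 support, plus torch and Pillow installed.
import torch
from PIL import Image
from transformers import AutoImageProcessor, DINOv3ViTForImageClassification

# Checkpoint name taken from the test above; Hub availability not guaranteed.
checkpoint = "dimidagd/dinov3-vit7b16-pretrain-lvd1689m-imagenet1k-lc"
processor = AutoImageProcessor.from_pretrained(checkpoint)
model = DINOv3ViTForImageClassification.from_pretrained(checkpoint)
model.eval()

# Same cats image the test fixture points at.
image = Image.open("./tests/fixtures/tests_samples/COCO/000000039769.png")
inputs = processor(images=image, return_tensors="pt")

with torch.no_grad():
    logits = model(**inputs).logits

predicted_class_idx = logits.argmax(-1).item()
print(predicted_class_idx)  # the test asserts 283 ("cat" per the comment in the diff)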
