Skip to content

Commit f42a62e

Browse files
committed
Revert "remove get_embeddings test"
This reverts commit 416d0b2.
1 parent 26f5561 commit f42a62e

File tree

1 file changed

+10
-0
lines changed

1 file changed

+10
-0
lines changed

tests/models/dinov3_vit/test_modeling_dinov3_vit.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727

2828
if is_torch_available():
2929
import torch
30+
from torch import nn
3031

3132
from transformers import DINOv3ViTBackbone, DINOv3ViTForImageClassification, DINOv3ViTModel
3233

@@ -250,6 +251,15 @@ def test_training_gradient_checkpointing_use_reentrant(self):
250251
def test_training_gradient_checkpointing_use_reentrant_false(self):
251252
pass
252253

254+
def test_model_get_set_embeddings(self):
255+
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
256+
257+
for model_class in self.all_model_classes:
258+
model = model_class(config)
259+
self.assertIsInstance(model.get_input_embeddings(), (nn.Module))
260+
x = model.get_output_embeddings()
261+
self.assertTrue(x is None or isinstance(x, nn.Linear))
262+
253263
def test_for_image_classification(self):
254264
config_and_inputs = self.model_tester.prepare_config_and_inputs()
255265
self.model_tester.create_and_check_for_image_classification(*config_and_inputs)

0 commit comments

Comments
 (0)