Commit ebb610b

Add support for DINOv3 with classification head

- Implements DINOv3ViTForImageClassification class
- Implements unit tests
- Updates docs
Parent: 3edd804

5 files changed: 156 additions, 8 deletions
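For orientation before the diffs, here is a minimal usage sketch of the class this commit adds. The checkpoint name is an assumption (any DINOv3 ViT checkpoint should work the same way), and the classification head is freshly initialized, so the logits are only meaningful after fine-tuning:

import torch
from PIL import Image
from transformers import AutoImageProcessor, DINOv3ViTForImageClassification

checkpoint = "facebook/dinov3-vits16-pretrain-lvd1689m"  # assumed checkpoint name
processor = AutoImageProcessor.from_pretrained(checkpoint)
# num_labels sizes the new, randomly initialized classification head
model = DINOv3ViTForImageClassification.from_pretrained(checkpoint, num_labels=2)

image = Image.new("RGB", (224, 224))  # placeholder input image
inputs = processor(images=image, return_tensors="pt")

with torch.no_grad():
    logits = model(**inputs).logits  # shape: (1, 2)
print(logits.argmax(dim=-1))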

docs/source/en/model_doc/dinov3.md (5 additions, 0 deletions)

@@ -169,6 +169,11 @@ print("Pooled output shape:", pooled_output.shape)
 [[autodoc]] DINOv3ViTModel
     - forward

+## DINOv3ViTForImageClassification
+
+[[autodoc]] DINOv3ViTForImageClassification
+    - forward
+
 ## DINOv3ConvNextModel

 [[autodoc]] DINOv3ConvNextModel

src/transformers/models/auto/modeling_auto.py (1 addition, 0 deletions)

@@ -867,6 +867,7 @@ class _BaseModelWithGenerate(PreTrainedModel, GenerationMixin):
         ("dinat", "DinatForImageClassification"),
         ("dinov2", "Dinov2ForImageClassification"),
         ("dinov2_with_registers", "Dinov2WithRegistersForImageClassification"),
+        ("dinov3_vit", "DINOv3ViTForImageClassification"),
         ("donut-swin", "DonutSwinForImageClassification"),
         (
             "efficientformer",

src/transformers/models/dinov3_vit/modeling_dinov3_vit.py (63 additions, 3 deletions)

@@ -28,12 +28,12 @@

 from ...activations import ACT2FN
 from ...modeling_layers import GradientCheckpointingLayer
-from ...modeling_outputs import BaseModelOutputWithPooling
+from ...modeling_outputs import BaseModelOutputWithPooling, ImageClassifierOutput
 from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
 from ...processing_utils import Unpack
 from ...pytorch_utils import compile_compatible_method_lru_cache
 from ...utils import TransformersKwargs, auto_docstring
-from ...utils.generic import check_model_inputs
+from ...utils.generic import can_return_tuple, check_model_inputs
 from .configuration_dinov3_vit import DINOv3ViTConfig


@@ -530,4 +530,64 @@ def forward(
         )


-__all__ = ["DINOv3ViTModel", "DINOv3ViTPreTrainedModel"]
+@auto_docstring(
+    custom_intro="""
+    DINOv3ViT Model transformer with an image classification head on top (a linear layer on top of the concatenated
+    [CLS] token and mean-pooled patch tokens) e.g. for ImageNet.
+    """
+)
+class DINOv3ViTForImageClassification(DINOv3ViTPreTrainedModel):
+    def __init__(self, config: DINOv3ViTConfig) -> None:
+        super().__init__(config)
+
+        self.num_labels = config.num_labels
+        self.dinov3 = DINOv3ViTModel(config)
+
+        # Classifier head
+        self.classifier = (
+            nn.Linear(config.hidden_size * 2, config.num_labels) if config.num_labels > 0 else nn.Identity()
+        )
+
+        # Initialize weights and apply final processing
+        self.post_init()
+
+    def get_input_embeddings(self):
+        return self.dinov3.embeddings.patch_embeddings
+
+    @can_return_tuple
+    @auto_docstring
+    def forward(
+        self,
+        pixel_values: Optional[torch.Tensor] = None,
+        head_mask: Optional[torch.Tensor] = None,
+        labels: Optional[torch.Tensor] = None,
+        **kwargs: Unpack[TransformersKwargs],
+    ) -> ImageClassifierOutput:
+        r"""
+        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
+            Labels for computing the image classification/regression loss. Indices should be in `[0, ...,
+            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss). If
+            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
+        """
+        outputs: BaseModelOutputWithPooling = self.dinov3(pixel_values, head_mask=head_mask, **kwargs)
+
+        sequence_output = outputs.last_hidden_state  # batch_size, sequence_length, hidden_size
+        cls_token = sequence_output[:, 0]
+        patch_tokens = sequence_output[:, 1:]
+
+        linear_input = torch.cat([cls_token, patch_tokens.mean(dim=1)], dim=1)
+        logits = self.classifier(linear_input)
+
+        loss = None
+        if labels is not None:
+            loss = self.loss_function(labels, logits, self.config, **kwargs)
+
+        return ImageClassifierOutput(
+            loss=loss,
+            logits=logits,
+            hidden_states=outputs.hidden_states,
+            attentions=outputs.attentions,
+        )
+
+
+__all__ = ["DINOv3ViTForImageClassification", "DINOv3ViTModel", "DINOv3ViTPreTrainedModel"]
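Note why the linear layer is sized `config.hidden_size * 2`: the forward pass concatenates the [CLS] token with the mean of the remaining tokens before the classifier. A shape-only sketch with illustrative sizes:

import torch

batch_size, num_patches, hidden_size = 2, 196, 384  # illustrative sizes
sequence_output = torch.randn(batch_size, 1 + num_patches, hidden_size)  # [CLS] + patch tokens

cls_token = sequence_output[:, 0]      # (2, 384)
patch_tokens = sequence_output[:, 1:]  # (2, 196, 384)
linear_input = torch.cat([cls_token, patch_tokens.mean(dim=1)], dim=1)
print(linear_input.shape)  # torch.Size([2, 768]), i.e. (batch_size, 2 * hidden_size)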

src/transformers/models/dinov3_vit/modular_dinov3_vit.py (63 additions, 3 deletions)

@@ -32,12 +32,12 @@
 from transformers.models.pixtral.modeling_pixtral import PixtralAttention, rotate_half

 from ...modeling_layers import GradientCheckpointingLayer
-from ...modeling_outputs import BaseModelOutputWithPooling
+from ...modeling_outputs import BaseModelOutputWithPooling, ImageClassifierOutput
 from ...modeling_utils import ALL_ATTENTION_FUNCTIONS
 from ...processing_utils import Unpack
 from ...pytorch_utils import compile_compatible_method_lru_cache
 from ...utils import TransformersKwargs, auto_docstring, logging
-from ...utils.generic import check_model_inputs
+from ...utils.generic import can_return_tuple, check_model_inputs
 from .configuration_dinov3_vit import DINOv3ViTConfig


@@ -425,4 +425,64 @@ def forward(
         )


-__all__ = ["DINOv3ViTModel", "DINOv3ViTPreTrainedModel"]
+@auto_docstring(
+    custom_intro="""
+    DINOv3ViT Model transformer with an image classification head on top (a linear layer on top of the concatenated
+    [CLS] token and mean-pooled patch tokens) e.g. for ImageNet.
+    """
+)
+class DINOv3ViTForImageClassification(DINOv3ViTPreTrainedModel):
+    def __init__(self, config: DINOv3ViTConfig) -> None:
+        super().__init__(config)
+
+        self.num_labels = config.num_labels
+        self.dinov3 = DINOv3ViTModel(config)
+
+        # Classifier head
+        self.classifier = (
+            nn.Linear(config.hidden_size * 2, config.num_labels) if config.num_labels > 0 else nn.Identity()
+        )
+
+        # Initialize weights and apply final processing
+        self.post_init()
+
+    def get_input_embeddings(self):
+        return self.dinov3.embeddings.patch_embeddings
+
+    @can_return_tuple
+    @auto_docstring
+    def forward(
+        self,
+        pixel_values: Optional[torch.Tensor] = None,
+        head_mask: Optional[torch.Tensor] = None,
+        labels: Optional[torch.Tensor] = None,
+        **kwargs: Unpack[TransformersKwargs],
+    ) -> ImageClassifierOutput:
+        r"""
+        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
+            Labels for computing the image classification/regression loss. Indices should be in `[0, ...,
+            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss). If
+            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
+        """
+        outputs: BaseModelOutputWithPooling = self.dinov3(pixel_values, head_mask=head_mask, **kwargs)
+
+        sequence_output = outputs.last_hidden_state  # batch_size, sequence_length, hidden_size
+        cls_token = sequence_output[:, 0]
+        patch_tokens = sequence_output[:, 1:]
+
+        linear_input = torch.cat([cls_token, patch_tokens.mean(dim=1)], dim=1)
+        logits = self.classifier(linear_input)
+
+        loss = None
+        if labels is not None:
+            loss = self.loss_function(labels, logits, self.config, **kwargs)
+
+        return ImageClassifierOutput(
+            loss=loss,
+            logits=logits,
+            hidden_states=outputs.hidden_states,
+            attentions=outputs.attentions,
+        )
+
+
+__all__ = ["DINOv3ViTForImageClassification", "DINOv3ViTModel", "DINOv3ViTPreTrainedModel"]
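This file intentionally duplicates the class above: in Transformers' modular system, modeling_dinov3_vit.py is generated from modular_dinov3_vit.py, so the definition has to land in both. After editing the modular file, the generated file is refreshed with the repo's converter script (invocation as described in the modular-transformers contributor guide):

python utils/modular_model_converter.py --files_to_parse src/transformers/models/dinov3_vit/modular_dinov3_vit.py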

tests/models/dinov3_vit/test_modeling_dinov3_vit.py (24 additions, 2 deletions)

@@ -29,7 +29,7 @@
     import torch
     from torch import nn

-    from transformers import DINOv3ViTModel
+    from transformers import DINOv3ViTForImageClassification, DINOv3ViTModel


 if is_vision_available():
@@ -124,6 +124,24 @@ def create_and_check_model(self, config, pixel_values, labels):
             (self.batch_size, self.seq_length, self.hidden_size),
         )

+    def create_and_check_for_image_classification(self, config, pixel_values, labels):
+        config.num_labels = self.type_sequence_label_size
+        model = DINOv3ViTForImageClassification(config)
+        model.to(torch_device)
+        model.eval()
+        result = model(pixel_values, labels=labels)
+        self.parent.assertEqual(result.logits.shape, (self.batch_size, self.type_sequence_label_size))
+
+        # test greyscale images
+        config.num_channels = 1
+        model = DINOv3ViTForImageClassification(config)
+        model.to(torch_device)
+        model.eval()
+
+        pixel_values = floats_tensor([self.batch_size, 1, self.image_size, self.image_size])
+        result = model(pixel_values)
+        self.parent.assertEqual(result.logits.shape, (self.batch_size, self.type_sequence_label_size))
+
     def prepare_config_and_inputs_for_common(self):
         config_and_inputs = self.prepare_config_and_inputs()
         (
@@ -142,7 +160,7 @@ class Dinov3ModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
     attention_mask and seq_length.
     """

-    all_model_classes = (DINOv3ViTModel,) if is_torch_available() else ()
+    all_model_classes = (DINOv3ViTModel, DINOv3ViTForImageClassification) if is_torch_available() else ()
     pipeline_model_mapping = (
         {
             "image-feature-extraction": DINOv3ViTModel,
@@ -218,6 +236,10 @@ def test_model_get_set_embeddings(self):
         x = model.get_output_embeddings()
         self.assertTrue(x is None or isinstance(x, nn.Linear))

+    def test_for_image_classification(self):
+        config_and_inputs = self.model_tester.prepare_config_and_inputs()
+        self.model_tester.create_and_check_for_image_classification(*config_and_inputs)
+
     def test_model(self):
         config_and_inputs = self.model_tester.prepare_config_and_inputs()
         self.model_tester.create_and_check_model(*config_and_inputs)
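
The new test can be run in isolation with a standard pytest keyword filter:

python -m pytest tests/models/dinov3_vit/test_modeling_dinov3_vit.py -k test_for_image_classification -v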
