Add Mistral3 vision-language model support (For Flux2 Migration)#3246

Open
SpenserCai wants to merge 13 commits into huggingface:main from SpenserCai:mistralai3_support

Conversation

@SpenserCai
Contributor

@SpenserCai SpenserCai commented Dec 16, 2025

Summary

This PR adds support for the Mistral3 (Mistral-Small-3.x) vision-language model to candle-transformers. Mistral3 combines the Pixtral vision encoder with the Mistral language model, enabling multimodal image-text understanding.

Note: This PR is a preparatory step for the upcoming Flux2 model migration, as Flux2 shares similar multimodal architecture patterns with Mistral3.

Changes

New files in candle-transformers/src/models/mistral3/:

  • mod.rs - Module exports and documentation
  • config.rs - Mistral3Config with vision, text, and projector settings
  • model.rs - Mistral3Model and Mistral3ForConditionalGeneration
  • patch_merger.rs - PatchMerger for reducing image tokens
  • projector.rs - MultiModalProjector (RMSNorm + PatchMerger + MLP)

Modified files:

  • candle-transformers/src/models/mod.rs - Added mistral3 module export
  • candle-transformers/src/models/pixtral/vision_model.rs - Added forward_with_hidden_states() and VisionModelOutput struct
  • candle-transformers/src/models/mistral.rs - Added forward_embeds_hidden() for multimodal integration

Architecture

Mistral3ForConditionalGeneration
├── Mistral3Model
│   ├── vision_tower (Pixtral Vision Model, 24 layers)
│   ├── multi_modal_projector
│   │   ├── norm (RMSNorm)
│   │   ├── patch_merger (spatial_merge_size=2, reduces tokens by 4x)
│   │   ├── linear_1
│   │   ├── act (GELU)
│   │   └── linear_2
│   └── language_model (Mistral, 40 layers)
└── lm_head
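
The tree above ends with projected image features being spliced back into the text embedding sequence before the language model runs. A minimal sketch of that step in plain Rust (hypothetical names and a simplified Vec-of-Vec layout; the real implementation operates on candle Tensors):

```rust
// Illustrative sketch only: splice image embeddings into the text embedding
// sequence at the positions of the image placeholder token, mimicking
// PyTorch's masked_scatter. `image_token_id` and the Vec-of-Vec embedding
// layout are simplifications for this example.
fn replace_image_tokens(
    embeds: &mut [Vec<f32>],   // one embedding vector per input token
    input_ids: &[u32],         // token ids, some equal to image_token_id
    image_token_id: u32,
    image_embeds: &[Vec<f32>], // projected vision features, in order
) {
    let mut next = image_embeds.iter();
    for (pos, &id) in input_ids.iter().enumerate() {
        if id == image_token_id {
            let img = next.next().expect("fewer image embeddings than image tokens");
            embeds[pos] = img.clone();
        }
    }
}
```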

Key Implementation Details

  1. PatchMerger: Uses reshape + permute to implement PyTorch's unfold operation (kernel_size == stride, no overlap), merging 2x2 patches into one.

  2. Image Token Replacement: Implements replace_image_tokens() as Candle equivalent of PyTorch's masked_scatter.

  3. Vision Tower Integration: Uses forward_with_hidden_states() to get batch-dimension-preserved output matching PyTorch Transformers behavior.
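
Point 1 above can be sketched in plain Rust on a flat row-major buffer (hypothetical helper, not the candle Tensor code in the PR): merging each non-overlapping 2x2 block of patch embeddings into one concatenated vector, which is exactly what unfold with kernel_size == stride produces.

```rust
// Illustrative sketch: merge 2x2 patch blocks, mimicking PyTorch's unfold
// with kernel_size == stride (no overlap). A (h, w, d) grid of patch
// embeddings stored row-major as a flat slice becomes (h/2 * w/2) rows of
// length 4*d: each output row concatenates one 2x2 block of neighbours.
fn merge_2x2_patches(grid: &[f32], h: usize, w: usize, d: usize) -> Vec<f32> {
    assert_eq!(grid.len(), h * w * d);
    assert!(h % 2 == 0 && w % 2 == 0);
    let mut out = Vec::with_capacity(grid.len());
    for by in 0..h / 2 {
        for bx in 0..w / 2 {
            // Concatenate the 4 patches of the block in row-major order.
            for dy in 0..2 {
                for dx in 0..2 {
                    let y = 2 * by + dy;
                    let x = 2 * bx + dx;
                    let start = (y * w + x) * d;
                    out.extend_from_slice(&grid[start..start + d]);
                }
            }
        }
    }
    out
}
```

The candle version expresses the same index permutation with reshape and permute instead of explicit loops, so no data-dependent gather is needed.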

Supported Models

Differences from Pixtral LLaVA

| Feature | Pixtral LLaVA | Mistral3 |
|---|---|---|
| PatchMerger | ❌ | ✅ (spatial_merge_size=2) |
| Projector RMSNorm | ❌ | ✅ |
| Projector bias | ✅ | ❌ |
| Image token reduction | 1x | 4x |
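
As a rough sketch of where the 4x in the table comes from (the patch and image sizes below are assumed values for illustration, not read from the config): spatial_merge_size = 2 merges each 2x2 block of patches into one token, shrinking the patch grid by merge_size².

```rust
// Hypothetical token-count arithmetic: an image is cut into a grid of
// patch_size x patch_size patches, then PatchMerger collapses each
// merge_size x merge_size block into a single token.
fn image_token_count(height: usize, width: usize, patch_size: usize, merge_size: usize) -> usize {
    let grid_h = height / patch_size;
    let grid_w = width / patch_size;
    (grid_h / merge_size) * (grid_w / merge_size)
}
```

For a 512x512 image with an assumed patch_size of 16: 32x32 = 1024 patches without merging, 16x16 = 256 tokens with spatial_merge_size = 2, i.e. a 4x reduction.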

Usage

use candle_transformers::models::mistral3::{Mistral3Config, Mistral3ForConditionalGeneration};

let config: Mistral3Config = serde_json::from_str(&config_str)?;
let model = Mistral3ForConditionalGeneration::new(&config, vb)?;
let logits = model.forward(&input_ids, Some(&pixel_values), Some(&image_sizes), 0)?;

Verification

The implementation has been verified against PyTorch Transformers reference:

  • Vision Tower: avg_diff = 2.29e-4
  • MultiModal Projector: avg_diff = 3.61e-8
  • Full Forward Pass: Predicted token matches (token ID: 1784 "The")

Checklist

  • New model implementation follows existing patterns in candle-transformers
  • Configuration uses serde for JSON deserialization
  • Reuses existing components (Pixtral vision, Mistral language model)
  • Documentation comments included
  • Verified against PyTorch reference implementation

@SpenserCai
Contributor Author


Mistral3 examples added!

@SpenserCai
Contributor Author

Fixed clippy and fmt.

@ivarflakstad
Member

Is 24B the smallest model for this scenario? It's pretty slow on Metal.

@SpenserCai
Contributor Author

> Is 24B the smallest model for this scenario? It's pretty slow on Metal.

Yes, the quantized version is out, and that is the smallest model available.

@SpenserCai
Contributor Author

> Is 24B the smallest model for this scenario? It's pretty slow on Metal.
>
> Yes, the quantized version is out, and that is the smallest model available.

@ivarflakstad Is there anything I need to change about this PR? 🤔

@ivarflakstad
Member

Not sure. Focusing on stabilizing the next release.
The example seems a bit slow to me, but it could be the model size. I haven't dug deeply into the implementation.

