@@ -61,7 +61,11 @@ def setup(self):
61
61
cfg = self .config
62
62
# Linear projection layer using DenseGeneral
63
63
self .linear = linears .DenseGeneral (
64
- features = cfg .hidden_size_for_vit , dtype = cfg .dtype_mm , name = "vit_unfold_linear" , use_bias = False
64
+ features = cfg .hidden_size_for_vit ,
65
+ dtype = cfg .dtype_mm ,
66
+ name = "vit_unfold_linear" ,
67
+ use_bias = False ,
68
+ matmul_precision = cfg .matmul_precision ,
65
69
)
66
70
67
71
def __call__ (self , inputs : Array ) -> Array :
@@ -137,10 +141,18 @@ class Llama4VisionMLP(nn.Module):
137
141
def setup (self ):
138
142
cfg = self .config
139
143
self .fc1 = linears .DenseGeneral (
140
- features = cfg .intermediate_size_for_vit , dtype = cfg .dtype_mm , name = "vit_encoder_layer_mlp_fc1" , use_bias = True
144
+ features = cfg .intermediate_size_for_vit ,
145
+ dtype = cfg .dtype_mm ,
146
+ name = "vit_encoder_layer_mlp_fc1" ,
147
+ use_bias = True ,
148
+ matmul_precision = cfg .matmul_precision ,
141
149
)
142
150
self .fc2 = linears .DenseGeneral (
143
- features = cfg .hidden_size_for_vit , dtype = cfg .dtype_mm , name = "vit_encoder_layer_mlp_fc2" , use_bias = True
151
+ features = cfg .hidden_size_for_vit ,
152
+ dtype = cfg .dtype_mm ,
153
+ name = "vit_encoder_layer_mlp_fc2" ,
154
+ use_bias = True ,
155
+ matmul_precision = cfg .matmul_precision ,
144
156
)
145
157
146
158
def __call__ (self , hidden_states : Array ) -> Array :
@@ -170,10 +182,18 @@ class Llama4VisionMLP2(nn.Module):
170
182
def setup (self ):
171
183
cfg = self .config
172
184
self .fc1 = linears .DenseGeneral (
173
- features = cfg .projector_input_dim_for_vit , dtype = cfg .dtype_mm , name = "vit_pixel_shuffle_mlp_fc1" , use_bias = False
185
+ features = cfg .projector_input_dim_for_vit ,
186
+ dtype = cfg .dtype_mm ,
187
+ name = "vit_pixel_shuffle_mlp_fc1" ,
188
+ use_bias = False ,
189
+ matmul_precision = cfg .matmul_precision ,
174
190
)
175
191
self .fc2 = linears .DenseGeneral (
176
- features = cfg .projector_output_dim_for_vit , dtype = cfg .dtype_mm , name = "vit_pixel_shuffle_mlp_fc2" , use_bias = False
192
+ features = cfg .projector_output_dim_for_vit ,
193
+ dtype = cfg .dtype_mm ,
194
+ name = "vit_pixel_shuffle_mlp_fc2" ,
195
+ use_bias = False ,
196
+ matmul_precision = cfg .matmul_precision ,
177
197
)
178
198
self .dropout = nn .Dropout (rate = cfg .projector_dropout_for_vit )
179
199
@@ -252,6 +272,7 @@ def setup(self):
252
272
dtype = cfg .dtype_mm ,
253
273
name = "vit_multi_modal_projector" ,
254
274
use_bias = False ,
275
+ matmul_precision = cfg .matmul_precision ,
255
276
)
256
277
257
278
def __call__ (self , image_features : Array ) -> Array :
@@ -579,6 +600,8 @@ def __call__(
579
600
head_dim = self .config .hidden_size_for_vit // self .config .num_attention_heads_for_vit ,
580
601
max_target_length = (self .config .image_size_for_vit // self .config .patch_size_for_vit ) ** 2 + 1 ,
581
602
attention_kernel = "dot_product" ,
603
+ float32_qk_product = self .config .float32_qk_product ,
604
+ float32_logits = self .config .float32_logits ,
582
605
mesh = self .mesh ,
583
606
dropout_rate = 0 ,
584
607
name = "self_attention_vision" ,
0 commit comments