
Commit dfe2bfa

Reduce numerical difference in llama4 vision encoder
1 parent 5ec9248 commit dfe2bfa

File tree

2 files changed: +34 −8 lines


MaxText/layers/llama4.py

Lines changed: 28 additions & 5 deletions
@@ -61,7 +61,11 @@ def setup(self):
     cfg = self.config
     # Linear projection layer using DenseGeneral
     self.linear = linears.DenseGeneral(
-        features=cfg.hidden_size_for_vit, dtype=cfg.dtype_mm, name="vit_unfold_linear", use_bias=False
+        features=cfg.hidden_size_for_vit,
+        dtype=cfg.dtype_mm,
+        name="vit_unfold_linear",
+        use_bias=False,
+        matmul_precision=cfg.matmul_precision,
     )

   def __call__(self, inputs: Array) -> Array:
@@ -137,10 +141,18 @@ class Llama4VisionMLP(nn.Module):
   def setup(self):
     cfg = self.config
     self.fc1 = linears.DenseGeneral(
-        features=cfg.intermediate_size_for_vit, dtype=cfg.dtype_mm, name="vit_encoder_layer_mlp_fc1", use_bias=True
+        features=cfg.intermediate_size_for_vit,
+        dtype=cfg.dtype_mm,
+        name="vit_encoder_layer_mlp_fc1",
+        use_bias=True,
+        matmul_precision=cfg.matmul_precision,
     )
     self.fc2 = linears.DenseGeneral(
-        features=cfg.hidden_size_for_vit, dtype=cfg.dtype_mm, name="vit_encoder_layer_mlp_fc2", use_bias=True
+        features=cfg.hidden_size_for_vit,
+        dtype=cfg.dtype_mm,
+        name="vit_encoder_layer_mlp_fc2",
+        use_bias=True,
+        matmul_precision=cfg.matmul_precision,
     )

   def __call__(self, hidden_states: Array) -> Array:
@@ -170,10 +182,18 @@ class Llama4VisionMLP2(nn.Module):
   def setup(self):
     cfg = self.config
     self.fc1 = linears.DenseGeneral(
-        features=cfg.projector_input_dim_for_vit, dtype=cfg.dtype_mm, name="vit_pixel_shuffle_mlp_fc1", use_bias=False
+        features=cfg.projector_input_dim_for_vit,
+        dtype=cfg.dtype_mm,
+        name="vit_pixel_shuffle_mlp_fc1",
+        use_bias=False,
+        matmul_precision=cfg.matmul_precision,
     )
     self.fc2 = linears.DenseGeneral(
-        features=cfg.projector_output_dim_for_vit, dtype=cfg.dtype_mm, name="vit_pixel_shuffle_mlp_fc2", use_bias=False
+        features=cfg.projector_output_dim_for_vit,
+        dtype=cfg.dtype_mm,
+        name="vit_pixel_shuffle_mlp_fc2",
+        use_bias=False,
+        matmul_precision=cfg.matmul_precision,
     )
     self.dropout = nn.Dropout(rate=cfg.projector_dropout_for_vit)

@@ -252,6 +272,7 @@ def setup(self):
         dtype=cfg.dtype_mm,
         name="vit_multi_modal_projector",
         use_bias=False,
+        matmul_precision=cfg.matmul_precision,
     )

   def __call__(self, image_features: Array) -> Array:
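Note on the four hunks above: every DenseGeneral in the vision encoder now receives cfg.matmul_precision, so its matmuls run at the configured precision (float32 in the updated test config) instead of the backend default. A minimal standalone sketch of the effect this has (not MaxText code; it assumes an accelerator backend such as TPU whose default float32 matmul precision is lower than full float32):

import jax
import jax.numpy as jnp

# Two identical float32 matmuls that differ only in the requested precision.
a = jax.random.normal(jax.random.PRNGKey(0), (256, 256), dtype=jnp.float32)
b = jax.random.normal(jax.random.PRNGKey(1), (256, 256), dtype=jnp.float32)

default_prec = jnp.dot(a, b, precision=jax.lax.Precision.DEFAULT)  # backend default (bf16 passes on TPU)
highest_prec = jnp.dot(a, b, precision=jax.lax.Precision.HIGHEST)  # full float32 accumulation

# On TPU the default result drifts from the full-precision one; on CPU both match.
print(jnp.max(jnp.abs(default_prec - highest_prec)))

Pinning the precision on the JAX side is what lets the encoder output track the float32 PyTorch reference more closely.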
@@ -579,6 +600,8 @@ def __call__(
         head_dim=self.config.hidden_size_for_vit // self.config.num_attention_heads_for_vit,
         max_target_length=(self.config.image_size_for_vit // self.config.patch_size_for_vit) ** 2 + 1,
         attention_kernel="dot_product",
+        float32_qk_product=self.config.float32_qk_product,
+        float32_logits=self.config.float32_logits,
         mesh=self.mesh,
         dropout_rate=0,
         name="self_attention_vision",

MaxText/tests/check_llama4_layers.py

Lines changed: 6 additions & 3 deletions
@@ -842,13 +842,15 @@ class Config(NamedTuple):
     attention_dropout: int = 0

   config_arguments = {
-      "per_device_batch_size": 4.0,
       "run_name": "test",
       "enable_checkpointing": False,
       "model_name": "llama4-17b-16e",
       "scan_layers": False,
-      "num_hidden_layers_for_vit": 6,
+      "num_hidden_layers_for_vit": 34,
       "dtype": "float32",
+      "matmul_precision": "float32",
+      "float32_qk_product": True,
+      "float32_logits": True,
   }

   def setUp(self):
@@ -885,6 +887,7 @@ def test_vision_encoder(self):
     # Create test input using config dimensions
     batch_size = 4
     inputs = jnp.ones((batch_size, self.seq_len_for_vit, self.cfg.hidden_size_for_vit), dtype=jnp.float32)
+    inputs /= 10

     # Initialize JAX parameters
     params = jax_model.init(self.rng, inputs, deterministic=True)
@@ -909,7 +912,7 @@ def test_vision_encoder(self):
     jax_outputs = jax_model.apply(params, inputs, deterministic=True)

     # Compare outputs
-    np.testing.assert_allclose(jax_outputs, to_jax(pt_outputs), rtol=1e-3, atol=0.05)
+    np.testing.assert_allclose(jax_outputs, to_jax(pt_outputs), rtol=0.01, atol=0.05)


 if __name__ == "__main__":
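For reference, np.testing.assert_allclose passes when |actual − desired| ≤ atol + rtol·|desired| holds elementwise, so relaxing rtol from 1e-3 to 0.01 while keeping atol=0.05 mainly widens the allowed error on large-magnitude outputs; the test also now runs the full 34-layer encoder and scales the inputs down by 10. A small self-contained example of what the new tolerances accept (hypothetical values, not the test's actual outputs):

import numpy as np

desired = np.array([5.0, 0.001])
actual = desired + np.array([0.09, 0.04])

# Allowed elementwise error: atol + rtol * |desired|
#   5.0   -> 0.05 + 0.01 * 5.0   = 0.10   (deviation 0.09 is accepted)
#   0.001 -> 0.05 + 0.01 * 0.001 ~ 0.05   (deviation 0.04 is accepted)
np.testing.assert_allclose(actual, desired, rtol=0.01, atol=0.05)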
