Porting PyTorch layer norm to Flax #2197
-
I'm trying to port the layer norm module from PyTorch to Flax. I transformed the state dict from PyTorch to Flax, yet the layer norm is still not producing the same results. I tried this on macOS 12.4 (M1), Ubuntu 20.04, and Google Colab, and in all of them the outputs aren't equal. I created a minimal example in a Colab notebook: https://colab.research.google.com/drive/1wTIbWbM9LBjzlKC14aegLjWHrkoXdapF#scrollTo=hv0tS0BJ2_EE Thanks in advance for the help.
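(For readers without access to the notebook: the failing port presumably looked roughly like the sketch below. This is a hypothetical reconstruction based on the diagnosis in the reply further down, not the actual Colab code.)

import flax
import torch
import jax.numpy as jnp
import numpy as np
torch_layernorm = torch.nn.LayerNorm(12)
# Hypothetical bug: in flax.linen.LayerNorm the first positional argument is epsilon, not the shape
flax_layernorm = flax.linen.LayerNorm(12)
torch_state_dict = torch_layernorm.state_dict()
torch_state_dict["scale"] = jnp.array(np.array(torch_state_dict.pop("weight")))
torch_state_dict["bias"] = jnp.array(np.array(torch_state_dict.pop("bias")))
x = torch.randn((8, 12))
x_flax = jnp.array(np.array(x))
torch_out = torch_layernorm(x)
flax_out = flax_layernorm.apply({"params": torch_state_dict}, x_flax)
# torch_out and flax_out do not match here, because the Flax module is effectively using epsilon=12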
Replies: 2 comments 1 reply
-
Flax uses shape inference, so you do not have to provide the normalized shape as an input to the constructor for LayerNorm. Because you do provide an argument (12), this means epsilon will be set to 12. Also, it is a good idea to use a manual seed so you get reproducible runs. Finally, I'd recommend using np.testing.assert_allclose, since it will give you more information if your outputs don't match (absolute and relative tolerances). Putting this together gives the following code:

import flax
import torch
import jax.numpy as jnp
import numpy as np
torch.manual_seed(0)
torch_layernorm = torch.nn.LayerNorm(12)
flax_layernorm = flax.linen.LayerNorm()
torch_state_dict = torch_layernorm.state_dict()
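# PyTorch names this parameter "weight"; Flax's LayerNorm expects "scale" (both use "bias")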
torch_state_dict["scale"] = jnp.array(np.array(torch_state_dict.pop("weight")))
torch_state_dict["bias"] = jnp.array(np.array(torch_state_dict.pop("bias")))
x = torch.randn((8, 12))
x_flax = jnp.array(np.array(x))
torch_out = torch_layernorm(x)
flax_out = flax_layernorm.apply(variables={"params": torch_state_dict}, x=x_flax)
np.testing.assert_allclose(torch_out.detach().numpy(), flax_out, rtol=1e-5)
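One more source of small discrepancies worth noting: the epsilon defaults differ between the two libraries (eps=1e-5 in torch.nn.LayerNorm vs epsilon=1e-6 in flax.linen.LayerNorm, as far as I can tell; verify against your installed versions). For an even tighter match you can set it explicitly, roughly like this:

# Assumed defaults: torch.nn.LayerNorm eps=1e-5, flax.linen.LayerNorm epsilon=1e-6
flax_layernorm = flax.linen.LayerNorm(epsilon=1e-5)
flax_out = flax_layernorm.apply({"params": torch_state_dict}, x_flax)
np.testing.assert_allclose(torch_out.detach().numpy(), flax_out, rtol=1e-5)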
-
Hi, I've been trying to do the same (porting pretrained ViT-based PyTorch models to Flax) and am facing a similar issue. I noticed the following (the code below raises an assertion error), and also this issue. Is there a way to enforce PyTorch-style LayerNorm computation?

import flax
import torch
import jax.numpy as jnp
import numpy as np
torch.manual_seed(0)
torch_layernorm = torch.nn.LayerNorm(768)
flax_layernorm = flax.linen.LayerNorm(use_fast_variance=False)
torch_state_dict = torch_layernorm.state_dict()
torch_state_dict["scale"] = jnp.array(np.array(torch_state_dict.pop("weight")))
torch_state_dict["bias"] = jnp.array(np.array(torch_state_dict.pop("bias")))
x = torch.randn((1, 197, 768))
x_flax = jnp.array(np.array(x))
torch_out = torch_layernorm(x)
flax_out = flax_layernorm.apply(variables={"params": torch_state_dict}, x=x_flax)
np.testing.assert_almost_equal(torch_out.detach().numpy(), flax_out, decimal=5)

While the error here is low, it seems to add up and give larger errors for later outputs. Thanks a lot in advance for any help!
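(A hedged sketch, not from the thread: on top of use_fast_variance=False, one thing to try is matching PyTorch's eps explicitly, assuming the defaults are eps=1e-5 in PyTorch and epsilon=1e-6 in Flax; please verify against your installed versions.)

# Sketch only: align epsilon with PyTorch's assumed default and keep the non-fused variance path
flax_layernorm = flax.linen.LayerNorm(epsilon=1e-5, use_fast_variance=False)
flax_out = flax_layernorm.apply({"params": torch_state_dict}, x_flax)
np.testing.assert_allclose(torch_out.detach().numpy(), flax_out, rtol=1e-5, atol=1e-6)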