Commit 0183a95

Flax Authors committed
Merge pull request #4999 from samanklesaria:issues/4997
PiperOrigin-RevId: 824643496
2 parents 66508ee + 32a64d2

File tree: 2 files changed, +298 -0

Lines changed: 297 additions & 0 deletions
@@ -0,0 +1,297 @@
Convert PyTorch models to Flax
==============================

.. testsetup::

  import numpy as np
  import jax
  from jax import random, numpy as jnp
  from flax import nnx

  import torch


This guide shows how to convert PyTorch models to Flax. We cover fully connected (fc) layers,
convolutions, batch norm, average pooling, and transposed convolutions.


FC Layers
--------------------------------

Let's start with fc layers. The only thing to be aware of here is that the PyTorch kernel has shape [outC, inC]
and the Flax kernel has shape [inC, outC]. Transposing the kernel will do the trick.

.. testcode::

  t_fc = torch.nn.Linear(in_features=3, out_features=4)

  kernel = t_fc.weight.detach().cpu().numpy()
  bias = t_fc.bias.detach().cpu().numpy()

  # [outC, inC] -> [inC, outC]
  kernel = jnp.transpose(kernel, (1, 0))

  key = random.key(0)
  x = random.normal(key, (1, 3))

  j_fc = nnx.Linear(in_features=3, out_features=4, rngs=nnx.Rngs(0))
  j_fc.kernel.value = kernel
  j_fc.bias.value = jnp.array(bias)
  j_out = j_fc(x)

  t_x = torch.from_numpy(np.array(x))
  t_out = t_fc(t_x)
  t_out = t_out.detach().cpu().numpy()

  np.testing.assert_almost_equal(j_out, t_out, decimal=6)


Convolutions
--------------------------------

Let's now look at 2D convolutions. PyTorch uses the NCHW format and Flax uses NHWC.
Consequently, the kernels will have different shapes. The kernel in PyTorch has shape [outC, inC, kH, kW]
and the Flax kernel has shape [kH, kW, inC, outC]. Transposing the kernel will do the trick.

.. testcode::

  t_conv = torch.nn.Conv2d(in_channels=3, out_channels=4, kernel_size=2, padding='valid')

  kernel = t_conv.weight.detach().cpu().numpy()
  bias = t_conv.bias.detach().cpu().numpy()

  # [outC, inC, kH, kW] -> [kH, kW, inC, outC]
  kernel = jnp.transpose(kernel, (2, 3, 1, 0))

  key = random.key(0)
  x = random.normal(key, (1, 6, 6, 3))

  j_conv = nnx.Conv(3, 4, kernel_size=(2, 2), padding='valid', rngs=nnx.Rngs(0))
  j_conv.kernel.value = kernel
  j_conv.bias.value = bias
  j_out = j_conv(x)

  # [N, H, W, C] -> [N, C, H, W]
  t_x = torch.from_numpy(np.transpose(np.array(x), (0, 3, 1, 2)))
  t_out = t_conv(t_x)
  # [N, C, H, W] -> [N, H, W, C]
  t_out = np.transpose(t_out.detach().cpu().numpy(), (0, 2, 3, 1))

  np.testing.assert_almost_equal(j_out, t_out, decimal=6)


Convolutions and FC Layers
--------------------------------

We have to be careful when a model uses convolutions followed by fc layers (ResNet, VGG, etc.).
In PyTorch, the activations have shape [N, C, H, W] after the convolutions and are then
reshaped to [N, C * H * W] before being fed to the fc layers.
When we port the weights to Flax, the activations after the convolutions have shape [N, H, W, C],
so we have to transpose them to [N, C, H, W] before reshaping them for the fc layers.

Consider this PyTorch model:

.. testcode::

  class TModel(torch.nn.Module):

    def __init__(self):
      super(TModel, self).__init__()
      self.conv = torch.nn.Conv2d(in_channels=3, out_channels=4, kernel_size=2, padding='valid')
      self.fc = torch.nn.Linear(in_features=100, out_features=2)

    def forward(self, x):
      x = self.conv(x)
      x = x.reshape(x.shape[0], -1)
      x = self.fc(x)
      return x


  t_model = TModel()


Now, if you want to use the weights from this model in Flax, the corresponding Flax model has to look like this:

.. testcode::

  class JModel(nnx.Module):
    def __init__(self, rngs):
      self.conv = nnx.Conv(3, 4, kernel_size=(2, 2), padding='valid', rngs=rngs)
      self.linear = nnx.Linear(100, 2, rngs=rngs)

    def __call__(self, x):
      x = self.conv(x)
      # [N, H, W, C] -> [N, C, H, W]
      x = jnp.transpose(x, (0, 3, 1, 2))
      x = jnp.reshape(x, (x.shape[0], -1))
      x = self.linear(x)
      return x

  j_model = JModel(nnx.Rngs(0))


The model looks very similar to the PyTorch model, except that we included a transpose operation before
reshaping the activations for the fc layer.
We can omit the transpose operation if we apply pooling before reshaping such that the spatial dimensions are 1x1.

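For example, with global average pooling the spatial dimensions collapse to 1x1, and flattening then yields the
channel values in the same order in both layouts. Here is a minimal sketch of that case (not part of the test code
above, using a plain ``jnp.mean`` as the global pooling):

.. code-block:: python

  # Global average pooling collapses H and W to 1, so flattening
  # [N, 1, 1, C] (Flax/NHWC) and [N, C, 1, 1] (PyTorch/NCHW) both
  # yield an [N, C] array with channels in the same order.
  x = random.normal(random.key(0), (1, 6, 6, 3))      # NHWC activations

  pooled = jnp.mean(x, axis=(1, 2), keepdims=True)    # [N, 1, 1, C]
  flat = jnp.reshape(pooled, (pooled.shape[0], -1))   # [N, C], no transpose required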

Other than the transpose operation before reshaping, we can convert the weights the same way as we did before:


.. testcode::

  conv_kernel = t_model.state_dict()['conv.weight'].detach().cpu().numpy()
  conv_bias = t_model.state_dict()['conv.bias'].detach().cpu().numpy()
  fc_kernel = t_model.state_dict()['fc.weight'].detach().cpu().numpy()
  fc_bias = t_model.state_dict()['fc.bias'].detach().cpu().numpy()

  # [outC, inC, kH, kW] -> [kH, kW, inC, outC]
  conv_kernel = jnp.transpose(conv_kernel, (2, 3, 1, 0))

  # [outC, inC] -> [inC, outC]
  fc_kernel = jnp.transpose(fc_kernel, (1, 0))

  j_model.conv.kernel.value = conv_kernel
  j_model.conv.bias.value = conv_bias
  j_model.linear.kernel.value = fc_kernel
  j_model.linear.bias.value = fc_bias

  key = random.key(0)
  x = random.normal(key, (1, 6, 6, 3))
  j_out = j_model(x)

  t_x = torch.from_numpy(np.transpose(np.array(x), (0, 3, 1, 2)))
  t_out = t_model(t_x)
  t_out = t_out.detach().cpu().numpy()

  np.testing.assert_almost_equal(j_out, t_out, decimal=6)


Batch Norm
--------------------------------

``torch.nn.BatchNorm2d`` uses ``0.1`` as the default value for the ``momentum`` parameter while
|nnx.BatchNorm|_ uses ``0.9``. However, this corresponds to the same computation, because PyTorch multiplies
the estimated statistic with ``(1 − momentum)`` and the new observed value with ``momentum``,
while Flax multiplies the estimated statistic with ``momentum`` and the new observed value with ``(1 − momentum)``.

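Written out for a running statistic :math:`\hat{x}` and a batch statistic :math:`x`, both conventions give the same
exponential moving average:

.. math::

  \text{PyTorch}\ (m = 0.1): \quad \hat{x} \leftarrow (1 - m)\,\hat{x} + m\,x = 0.9\,\hat{x} + 0.1\,x

  \text{Flax}\ (m = 0.9): \quad \hat{x} \leftarrow m\,\hat{x} + (1 - m)\,x = 0.9\,\hat{x} + 0.1\,x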

.. |nnx.BatchNorm| replace:: ``nnx.BatchNorm``
.. _nnx.BatchNorm: https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/nn/normalization.html#flax.nnx.BatchNorm

.. testcode::

  t_bn = torch.nn.BatchNorm2d(num_features=3, momentum=0.1)
  t_bn.eval()

  key = random.key(0)
  x = random.normal(key, (1, 6, 6, 3))

  j_bn = nnx.BatchNorm(num_features=3, momentum=0.9, use_running_average=True, rngs=nnx.Rngs(0))

  j_out = j_bn(x)

  # [N, H, W, C] -> [N, C, H, W]
  t_x = torch.from_numpy(np.transpose(np.array(x), (0, 3, 1, 2)))
  t_out = t_bn(t_x)
  # [N, C, H, W] -> [N, H, W, C]
  t_out = np.transpose(t_out.detach().cpu().numpy(), (0, 2, 3, 1))

  np.testing.assert_almost_equal(j_out, t_out, decimal=6)
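
The example above only compares freshly initialized statistics. For a trained ``torch.nn.BatchNorm2d``, the affine
parameters and running statistics have to be copied as well. Here is a minimal sketch, assuming ``nnx.BatchNorm``
exposes them under the attribute names ``scale``, ``bias``, ``mean``, and ``var``:

.. code-block:: python

  # Sketch: copy trained BatchNorm parameters from PyTorch to Flax.
  # All four are 1D arrays over the channel dimension, so no transpose is needed.
  j_bn.scale.value = jnp.array(t_bn.weight.detach().cpu().numpy())       # learnable gain
  j_bn.bias.value = jnp.array(t_bn.bias.detach().cpu().numpy())          # learnable shift
  j_bn.mean.value = jnp.array(t_bn.running_mean.detach().cpu().numpy())  # running mean
  j_bn.var.value = jnp.array(t_bn.running_var.detach().cpu().numpy())    # running variance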


Average Pooling
--------------------------------

``torch.nn.AvgPool2d`` and |nnx.avg_pool()|_ are compatible when using default parameters.
However, ``torch.nn.AvgPool2d`` has a ``count_include_pad`` parameter. When ``count_include_pad=False``,
the zero padding is not counted in the average. There is no similar parameter for |nnx.avg_pool()|_,
but we can easily implement a wrapper around the core pooling operation: ``nnx.pool()`` is the function
behind |nnx.avg_pool()|_ and |nnx.max_pool()|_.

.. |nnx.avg_pool()| replace:: ``nnx.avg_pool()``
.. _nnx.avg_pool(): https://flax.readthedocs.io/en/latest/api_reference/flax.linen/layers.html#flax.linen.avg_pool

.. |nnx.max_pool()| replace:: ``nnx.max_pool()``
.. _nnx.max_pool(): https://flax.readthedocs.io/en/latest/api_reference/flax.linen/layers.html#flax.linen.max_pool


.. testcode::

  def avg_pool(inputs, window_shape, strides=None, padding='VALID'):
    """
    Pools the input by taking the average over a window.
    In contrast to nnx.avg_pool(), this pooling operation does not
    count the padded zeros in the average computation.
    """
    assert len(window_shape) == 2

    y = nnx.pool(inputs, 0., jax.lax.add, window_shape, strides, padding)
    counts = nnx.pool(jnp.ones_like(inputs), 0., jax.lax.add, window_shape, strides, padding)
    y = y / counts
    return y


  key = random.key(0)
  x = random.normal(key, (1, 6, 6, 3))

  j_out = avg_pool(x, window_shape=(2, 2), strides=(1, 1), padding=((1, 1), (1, 1)))
  t_pool = torch.nn.AvgPool2d(kernel_size=2, stride=1, padding=1, count_include_pad=False)

  # [N, H, W, C] -> [N, C, H, W]
  t_x = torch.from_numpy(np.transpose(np.array(x), (0, 3, 1, 2)))
  t_out = t_pool(t_x)
  # [N, C, H, W] -> [N, H, W, C]
  t_out = np.transpose(t_out.detach().cpu().numpy(), (0, 2, 3, 1))

  np.testing.assert_almost_equal(j_out, t_out, decimal=6)


Transposed Convolutions
--------------------------------

``torch.nn.ConvTranspose2d`` and |nnx.ConvTranspose|_ are not compatible.
|nnx.ConvTranspose|_ is a wrapper around |jax.lax.conv_transpose|_, which computes a fractionally strided convolution,
while ``torch.nn.ConvTranspose2d`` computes a gradient-based transposed convolution. Currently, there is no
implementation of a gradient-based transposed convolution in JAX. However, there is a pending `pull request`_
that contains an implementation.

To load ``torch.nn.ConvTranspose2d`` parameters into Flax, we need to use the ``transpose_kernel`` argument of Flax's
``nnx.ConvTranspose`` layer.

.. testcode::

  t_conv = torch.nn.ConvTranspose2d(in_channels=3, out_channels=4, kernel_size=2, padding=0)

  kernel = t_conv.weight.detach().cpu().numpy()
  bias = t_conv.bias.detach().cpu().numpy()

  # [inC, outC, kH, kW] -> [kH, kW, outC, inC]
  kernel = jnp.transpose(kernel, (2, 3, 1, 0))

  key = random.key(0)
  x = random.normal(key, (1, 6, 6, 3))

  # ConvTranspose expects the kernel to be [kH, kW, inC, outC],
  # but with `transpose_kernel=True`, it expects [kH, kW, outC, inC] instead
  j_conv = nnx.ConvTranspose(3, 4, kernel_size=(2, 2), padding='VALID', transpose_kernel=True, rngs=nnx.Rngs(0))
  j_conv.kernel.value = kernel
  j_conv.bias.value = bias
  j_out = j_conv(x)

  # [N, H, W, C] -> [N, C, H, W]
  t_x = torch.from_numpy(np.transpose(np.array(x), (0, 3, 1, 2)))
  t_out = t_conv(t_x)
  # [N, C, H, W] -> [N, H, W, C]
  t_out = np.transpose(t_out.detach().cpu().numpy(), (0, 2, 3, 1))
  np.testing.assert_almost_equal(j_out, t_out, decimal=6)

.. _`pull request`: https://github.com/jax-ml/jax/pull/5772

.. |nnx.ConvTranspose| replace:: ``nnx.ConvTranspose``
.. _nnx.ConvTranspose: https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/nn/linear.html#flax.nnx.ConvTranspose

.. |jax.lax.conv_transpose| replace:: ``jax.lax.conv_transpose``
.. _jax.lax.conv_transpose: https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.conv_transpose.html

docs_nnx/migrating/index.rst

Lines changed: 1 addition & 0 deletions

@@ -4,6 +4,7 @@ Migrating
 .. toctree::
   :maxdepth: 2

+  convert_pytorch_to_flax
   nnx_010_to_nnx_011
   linen_to_nnx
   haiku_to_flax
