Convert PyTorch models to Flax
==============================

.. testsetup::

  import numpy as np
  import jax
  from jax import random, numpy as jnp
  from flax import nnx

  import torch

We will show how to convert PyTorch models to Flax. We will cover fc (fully connected) layers, convolutions, batch norm, average pooling, and transposed convolutions.


FC Layers
--------------------------------

Let's start with fc layers. The only thing to be aware of here is that the PyTorch kernel has shape [outC, inC]
and the Flax kernel has shape [inC, outC]. Transposing the kernel will do the trick.

.. testcode::

  t_fc = torch.nn.Linear(in_features=3, out_features=4)

  kernel = t_fc.weight.detach().cpu().numpy()
  bias = t_fc.bias.detach().cpu().numpy()

  # [outC, inC] -> [inC, outC]
  kernel = jnp.transpose(kernel, (1, 0))

  key = random.key(0)
  x = random.normal(key, (1, 3))

  j_fc = nnx.Linear(in_features=3, out_features=4, rngs=nnx.Rngs(0))
  j_fc.kernel.value = kernel
  j_fc.bias.value = jnp.array(bias)
  j_out = j_fc(x)

  t_x = torch.from_numpy(np.array(x))
  t_out = t_fc(t_x)
  t_out = t_out.detach().cpu().numpy()

  np.testing.assert_almost_equal(j_out, t_out, decimal=6)


Convolutions
--------------------------------

Let's now look at 2D convolutions. PyTorch uses the NCHW format and Flax uses NHWC.
Consequently, the kernels will have different shapes. The kernel in PyTorch has shape [outC, inC, kH, kW]
and the Flax kernel has shape [kH, kW, inC, outC]. Transposing the kernel will do the trick.

.. testcode::

  t_conv = torch.nn.Conv2d(in_channels=3, out_channels=4, kernel_size=2, padding='valid')

  kernel = t_conv.weight.detach().cpu().numpy()
  bias = t_conv.bias.detach().cpu().numpy()

  # [outC, inC, kH, kW] -> [kH, kW, inC, outC]
  kernel = jnp.transpose(kernel, (2, 3, 1, 0))

  key = random.key(0)
  x = random.normal(key, (1, 6, 6, 3))

  j_conv = nnx.Conv(3, 4, kernel_size=(2, 2), padding='valid', rngs=nnx.Rngs(0))
  j_conv.kernel.value = kernel
  j_conv.bias.value = bias
  j_out = j_conv(x)

  # [N, H, W, C] -> [N, C, H, W]
  t_x = torch.from_numpy(np.transpose(np.array(x), (0, 3, 1, 2)))
  t_out = t_conv(t_x)
  # [N, C, H, W] -> [N, H, W, C]
  t_out = np.transpose(t_out.detach().cpu().numpy(), (0, 2, 3, 1))

  np.testing.assert_almost_equal(j_out, t_out, decimal=6)



Convolutions and FC Layers
--------------------------------

We have to be careful when we have a model that uses convolutions followed by fc layers (ResNet, VGG, etc.).
In PyTorch, the activations will have shape [N, C, H, W] after the convolutions and are then
reshaped to [N, C * H * W] before being fed to the fc layers.
When we port our weights from PyTorch to Flax, the activations after the convolutions will have shape [N, H, W, C] in Flax.
Before we reshape the activations for the fc layers, we have to transpose them to [N, C, H, W].

Consider this PyTorch model:

.. testcode::

  class TModel(torch.nn.Module):

    def __init__(self):
      super(TModel, self).__init__()
      self.conv = torch.nn.Conv2d(in_channels=3, out_channels=4, kernel_size=2, padding='valid')
      self.fc = torch.nn.Linear(in_features=100, out_features=2)

    def forward(self, x):
      x = self.conv(x)
      x = x.reshape(x.shape[0], -1)
      x = self.fc(x)
      return x


  t_model = TModel()



Now, if you want to use the weights from this model in Flax, the corresponding Flax model has to look like this:


.. testcode::

  class JModel(nnx.Module):
    def __init__(self, rngs):
      self.conv = nnx.Conv(3, 4, kernel_size=(2, 2), padding='valid', rngs=rngs)
      self.linear = nnx.Linear(100, 2, rngs=rngs)

    def __call__(self, x):
      x = self.conv(x)
      # [N, H, W, C] -> [N, C, H, W]
      x = jnp.transpose(x, (0, 3, 1, 2))
      x = jnp.reshape(x, (x.shape[0], -1))
      x = self.linear(x)
      return x

  j_model = JModel(nnx.Rngs(0))



The model looks very similar to the PyTorch model, except that it includes a transpose operation before
reshaping the activations for the fc layer.
We can omit the transpose if we apply pooling before reshaping so that the spatial dimensions are reduced to 1x1, as in the sketch below.
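
For instance, here is an illustrative sketch (not one of the models above): after global average pooling the spatial dimensions collapse to 1x1, so flattening yields the same [N, C] layout in both frameworks and no transpose is needed.

.. testcode::

  # Illustrative shapes only: pool H and W down to a single value per channel,
  # then flatten. The channel ordering then matches in both frameworks.
  x_pool = random.normal(random.key(1), (1, 5, 5, 4))

  # Flax / NHWC: average over H and W, then flatten to [N, C].
  j_flat = jnp.mean(x_pool, axis=(1, 2)).reshape(x_pool.shape[0], -1)

  # PyTorch / NCHW: average over H and W, then flatten to [N, C].
  t_pool_x = torch.from_numpy(np.transpose(np.array(x_pool), (0, 3, 1, 2)))
  t_flat = torch.mean(t_pool_x, dim=(2, 3)).reshape(t_pool_x.shape[0], -1)

  np.testing.assert_almost_equal(j_flat, t_flat.numpy(), decimal=6)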

Other than the transpose operation before reshaping, we can convert the weights the same way as we did before:


.. testcode::

  conv_kernel = t_model.state_dict()['conv.weight'].detach().cpu().numpy()
  conv_bias = t_model.state_dict()['conv.bias'].detach().cpu().numpy()
  fc_kernel = t_model.state_dict()['fc.weight'].detach().cpu().numpy()
  fc_bias = t_model.state_dict()['fc.bias'].detach().cpu().numpy()

  # [outC, inC, kH, kW] -> [kH, kW, inC, outC]
  conv_kernel = jnp.transpose(conv_kernel, (2, 3, 1, 0))

  # [outC, inC] -> [inC, outC]
  fc_kernel = jnp.transpose(fc_kernel, (1, 0))

  j_model.conv.kernel.value = conv_kernel
  j_model.conv.bias.value = conv_bias
  j_model.linear.kernel.value = fc_kernel
  j_model.linear.bias.value = fc_bias

  key = random.key(0)
  x = random.normal(key, (1, 6, 6, 3))
  j_out = j_model(x)

  t_x = torch.from_numpy(np.transpose(np.array(x), (0, 3, 1, 2)))
  t_out = t_model(t_x)
  t_out = t_out.detach().cpu().numpy()

  np.testing.assert_almost_equal(j_out, t_out, decimal=6)



Batch Norm
--------------------------------

``torch.nn.BatchNorm2d`` uses ``0.1`` as the default value for the ``momentum`` parameter while
|nnx.BatchNorm|_ uses ``0.9``. However, this corresponds to the same computation, because PyTorch multiplies
the estimated statistic with ``(1 − momentum)`` and the new observed value with ``momentum``,
while Flax multiplies the estimated statistic with ``momentum`` and the new observed value with ``(1 − momentum)``.
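
To make this concrete, here is a single running-average update with arbitrary numbers (illustrative only, not tied to any layer in this guide):

.. testcode::

  old_stat, observed = 1.0, 3.0

  # PyTorch: new = (1 - momentum) * old + momentum * observed, with momentum=0.1
  t_new = (1 - 0.1) * old_stat + 0.1 * observed

  # Flax: new = momentum * old + (1 - momentum) * observed, with momentum=0.9
  j_new = 0.9 * old_stat + (1 - 0.9) * observed

  np.testing.assert_almost_equal(t_new, j_new, decimal=6)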

.. |nnx.BatchNorm| replace:: ``nnx.BatchNorm``
.. _nnx.BatchNorm: https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/nn/normalization.html#flax.nnx.BatchNorm

.. testcode::

  t_bn = torch.nn.BatchNorm2d(num_features=3, momentum=0.1)
  t_bn.eval()

  key = random.key(0)
  x = random.normal(key, (1, 6, 6, 3))

  j_bn = nnx.BatchNorm(num_features=3, momentum=0.9, use_running_average=True, rngs=nnx.Rngs(0))

  j_out = j_bn(x)

  # [N, H, W, C] -> [N, C, H, W]
  t_x = torch.from_numpy(np.transpose(np.array(x), (0, 3, 1, 2)))
  t_out = t_bn(t_x)
  # [N, C, H, W] -> [N, H, W, C]
  t_out = np.transpose(t_out.detach().cpu().numpy(), (0, 2, 3, 1))

  np.testing.assert_almost_equal(j_out, t_out, decimal=6)
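
When converting a trained model, the learned affine parameters and the running statistics also have to be copied over. Here is a minimal sketch, assuming the ``nnx.BatchNorm`` layer above exposes ``scale``, ``bias``, ``mean`` and ``var`` attributes; for the freshly initialized layers in this example it only copies the default values, so the check still passes:

.. testcode::

  # PyTorch's weight/bias are the learnable scale (gamma) and shift (beta);
  # running_mean/running_var are the estimated statistics used in eval mode.
  j_bn.scale.value = jnp.array(t_bn.weight.detach().cpu().numpy())
  j_bn.bias.value = jnp.array(t_bn.bias.detach().cpu().numpy())
  j_bn.mean.value = jnp.array(t_bn.running_mean.cpu().numpy())
  j_bn.var.value = jnp.array(t_bn.running_var.cpu().numpy())

  j_out = j_bn(x)
  np.testing.assert_almost_equal(j_out, t_out, decimal=6)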


Average Pooling
--------------------------------

``torch.nn.AvgPool2d`` and |nnx.avg_pool()|_ are compatible when using default parameters.
However, ``torch.nn.AvgPool2d`` has a ``count_include_pad`` parameter: when ``count_include_pad=False``,
the padded zeros are not included in the average calculation. |nnx.avg_pool()|_ has no such
parameter, but we can easily implement a wrapper around the pooling operation, since ``nnx.pool()``
is the core function behind |nnx.avg_pool()|_ and |nnx.max_pool()|_.

.. |nnx.avg_pool()| replace:: ``nnx.avg_pool()``
.. _nnx.avg_pool(): https://flax.readthedocs.io/en/latest/api_reference/flax.linen/layers.html#flax.linen.avg_pool

.. |nnx.max_pool()| replace:: ``nnx.max_pool()``
.. _nnx.max_pool(): https://flax.readthedocs.io/en/latest/api_reference/flax.linen/layers.html#flax.linen.max_pool


.. testcode::

  def avg_pool(inputs, window_shape, strides=None, padding='VALID'):
    """
    Pools the input by taking the average over a window.
    In comparison to nnx.avg_pool(), this pooling operation does not
    consider the padded zeros for the average computation.
    """
    assert len(window_shape) == 2

    y = nnx.pool(inputs, 0., jax.lax.add, window_shape, strides, padding)
    # Count only the valid (non-padded) entries in each window so that the
    # padded zeros do not contribute to the average.
    counts = nnx.pool(jnp.ones_like(inputs), 0., jax.lax.add, window_shape, strides, padding)
    y = y / counts
    return y


  key = random.key(0)
  x = random.normal(key, (1, 6, 6, 3))

  j_out = avg_pool(x, window_shape=(2, 2), strides=(1, 1), padding=((1, 1), (1, 1)))
  t_pool = torch.nn.AvgPool2d(kernel_size=2, stride=1, padding=1, count_include_pad=False)

  # [N, H, W, C] -> [N, C, H, W]
  t_x = torch.from_numpy(np.transpose(np.array(x), (0, 3, 1, 2)))
  t_out = t_pool(t_x)
  # [N, C, H, W] -> [N, H, W, C]
  t_out = np.transpose(t_out.detach().cpu().numpy(), (0, 2, 3, 1))

  np.testing.assert_almost_equal(j_out, t_out, decimal=6)



Transposed Convolutions
--------------------------------

``torch.nn.ConvTranspose2d`` and |nnx.ConvTranspose|_ are not compatible.
|nnx.ConvTranspose|_ is a wrapper around |jax.lax.conv_transpose|_, which computes a fractionally strided convolution,
while ``torch.nn.ConvTranspose2d`` computes a gradient-based transposed convolution. Currently, there is no
implementation of a gradient-based transposed convolution in ``JAX``. However, there is a pending `pull request`_
that contains an implementation.

To load ``torch.nn.ConvTranspose2d`` parameters into Flax, we need to use the ``transpose_kernel`` argument of Flax's
``nnx.ConvTranspose`` layer.

.. testcode::

  t_conv = torch.nn.ConvTranspose2d(in_channels=3, out_channels=4, kernel_size=2, padding=0)

  kernel = t_conv.weight.detach().cpu().numpy()
  bias = t_conv.bias.detach().cpu().numpy()

  # [inC, outC, kH, kW] -> [kH, kW, outC, inC]
  kernel = jnp.transpose(kernel, (2, 3, 1, 0))

  key = random.key(0)
  x = random.normal(key, (1, 6, 6, 3))

  # ConvTranspose expects the kernel to be [kH, kW, inC, outC],
  # but with `transpose_kernel=True`, it expects [kH, kW, outC, inC] instead
  j_conv = nnx.ConvTranspose(3, 4, kernel_size=(2, 2), padding='VALID', transpose_kernel=True, rngs=nnx.Rngs(0))
  j_conv.kernel.value = kernel
  j_conv.bias.value = bias
  j_out = j_conv(x)

  # [N, H, W, C] -> [N, C, H, W]
  t_x = torch.from_numpy(np.transpose(np.array(x), (0, 3, 1, 2)))
  t_out = t_conv(t_x)
  # [N, C, H, W] -> [N, H, W, C]
  t_out = np.transpose(t_out.detach().cpu().numpy(), (0, 2, 3, 1))
  np.testing.assert_almost_equal(j_out, t_out, decimal=6)

.. _`pull request`: https://github.com/jax-ml/jax/pull/5772

.. |nnx.ConvTranspose| replace:: ``nnx.ConvTranspose``
.. _nnx.ConvTranspose: https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/nn/linear.html#flax.nnx.ConvTranspose

.. |jax.lax.conv_transpose| replace:: ``jax.lax.conv_transpose``
.. _jax.lax.conv_transpose: https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.conv_transpose.html