Commit 9df202b

Author: Emrick Sinitambirivoutin
Add some details about the implementation in PyTorch
1 parent 9e6a043 commit 9df202b

File tree: 1 file changed, +15 -1 lines


Diff for: batch_normalization/notes/batch_normalization.md

@@ -14,7 +14,7 @@ Batch normalization [1] overcomes this issue and make the training more efficien

## 1. Reduce internal covariance shift via mini-batch statistics

-One way to reduce the ill effects of the internal covariance shift within a Neural Network is to normalize layer inputs. This operation not only enforces inputs to have the same distribution but also whitens each of them. This method is motivated by some studies [3,4] showing that the network training converges faster if its inputs are whitened; as a consequence, enforcing the whitening of the inputs of each layer is a desirable property for the network.
+One way to reduce the ill effects of the internal covariance shift within a Neural Network is to normalize layer inputs. This operation not only enforces inputs to have the same distribution but also whitens each of them. This method is motivated by some studies [2] showing that the network training converges faster if its inputs are whitened; as a consequence, enforcing the whitening of the inputs of each layer is a desirable property for the network.

However, the full whitening of each layer’s inputs is costly and not fully differentiable. Batch normalization overcomes this issue by considering two assumptions:

@@ -47,6 +47,10 @@ $$

#### Fully connected layers

+The implementation for fully connected layers is pretty simple. We just need to compute the mean and the variance of each batch and then scale and shift the feature map with the gamma and beta parameters presented earlier.
+
+During the backward pass, we will use backpropagation in order to update these two parameters.
+
```python
mean = torch.mean(X, axis=0)
variance = torch.mean((X-mean)**2, axis=0)
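# The diff cuts this snippet off here; the remaining lines presumably normalize X and
# apply the learned scale and shift. A minimal sketch, where eps is assumed to be a
# small constant added for numerical stability (it is not shown in this hunk):
X_hat = (X - mean) / torch.sqrt(variance + eps)
out = gamma * X_hat + beta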
@@ -56,6 +60,8 @@ out = gamma * X_hat + beta

#### Convolutional layers

+The implementation for convolutional layers is almost the same as before. We just need to perform some reshaping in order to adapt to the `(N, C, H, W)` input that we get from the previous layer.
+
```python
N, C, H, W = X.shape
mean = torch.mean(X, axis = (0, 2, 3))
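# The line(s) elided between this hunk and the next presumably compute the per-channel
# variance; a sketch of how that could look (an assumption, not the commit's code):
variance = torch.mean((X - mean.reshape((1, C, 1, 1))) ** 2, axis=(0, 2, 3))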
@@ -64,7 +70,9 @@ X_hat = (X - mean.reshape((1, C, 1, 1))) * 1.0 / torch.sqrt(variance.reshape((1,
out = gamma.reshape((1, C, 1, 1)) * X_hat + beta.reshape((1, C, 1, 1))
```

+In PyTorch, the backpropagation is very easy to handle. One important thing here is to specify that our gamma and beta are learnable parameters so that they are updated during the backward phase.

+To do so, we will declare them as `nn.Parameter()` in our layer and we will initialize them with random values.
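
A minimal, self-contained sketch of that declaration (the class name `BatchNormParams` and the `num_features` argument are assumptions made for illustration, not the commit's code):

```python
import torch
import torch.nn as nn

class BatchNormParams(nn.Module):
    # Hypothetical illustration of registering gamma and beta as learnable parameters.
    def __init__(self, num_features):
        super().__init__()
        # Declared as nn.Parameter so the optimizer updates them during the backward phase;
        # initialized with random values, as described above.
        self.gamma = nn.Parameter(torch.rand(num_features))
        self.beta = nn.Parameter(torch.rand(num_features))
```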

### During inference

@@ -81,8 +89,14 @@ $$

This moving average is stored in a global variable that is updated during the training phase.

+In order to store this moving average in our layer during training, we can use buffers. We initialize these buffers when we instantiate our layer with PyTorch's `register_buffer()` method.
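
A minimal sketch of that mechanism (the buffer names, the `momentum` value, and the update rule are assumptions for illustration, not the commit's exact code):

```python
import torch
import torch.nn as nn

class RunningStats(nn.Module):
    # Hypothetical illustration of storing the moving averages as buffers.
    def __init__(self, num_features, momentum=0.9):
        super().__init__()
        self.momentum = momentum
        # Buffers are saved in the module's state_dict but are not updated by the optimizer.
        self.register_buffer("running_mean", torch.zeros(num_features))
        self.register_buffer("running_var", torch.ones(num_features))

    def update(self, batch_mean, batch_var):
        # Exponential moving average of the mini-batch statistics, updated during training.
        self.running_mean = self.momentum * self.running_mean + (1 - self.momentum) * batch_mean.detach()
        self.running_var = self.momentum * self.running_var + (1 - self.momentum) * batch_var.detach()
```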
+
### Final module

+The final module is then composed of all of the blocks described earlier. We add a condition on the shape of the input data in order to know whether we are dealing with a fully connected layer or a convolutional layer.
+
+One important thing to notice here is that we only need to implement the `forward()` method: as our class inherits from `nn.Module`, the backward pass is handled automatically by PyTorch's autograd (thank you PyTorch ❤️).
+
```python
class CustomBatchNorm(nn.Module):
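    # The diff shows only the class declaration; the lines below are a sketch of how the
    # module could be assembled from the pieces above. The constructor signature, the
    # buffer names, and the eps/momentum defaults are assumptions, not the commit's code.
    def __init__(self, num_features, eps=1e-5, momentum=0.9):
        super().__init__()
        self.eps = eps
        self.momentum = momentum
        # gamma and beta are learnable parameters, updated by backpropagation
        self.gamma = nn.Parameter(torch.rand(num_features))
        self.beta = nn.Parameter(torch.rand(num_features))
        # running statistics stored as buffers, used at inference time
        self.register_buffer("running_mean", torch.zeros(num_features))
        self.register_buffer("running_var", torch.ones(num_features))

    def forward(self, X):
        # condition on the input shape: 2-D means fully connected, 4-D means convolutional
        if X.dim() == 2:
            dims, shape = (0,), (1, -1)
        else:
            dims, shape = (0, 2, 3), (1, -1, 1, 1)

        if self.training:
            mean = torch.mean(X, axis=dims)
            variance = torch.mean((X - mean.reshape(shape)) ** 2, axis=dims)
            # update the moving averages that will be used during inference
            self.running_mean = self.momentum * self.running_mean + (1 - self.momentum) * mean.detach()
            self.running_var = self.momentum * self.running_var + (1 - self.momentum) * variance.detach()
        else:
            mean, variance = self.running_mean, self.running_var

        X_hat = (X - mean.reshape(shape)) / torch.sqrt(variance.reshape(shape) + self.eps)
        return self.gamma.reshape(shape) * X_hat + self.beta.reshape(shape)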
