diff --git a/torchrec/modules/crossnet.py b/torchrec/modules/crossnet.py
index 43771be5d..5188455d1 100644
--- a/torchrec/modules/crossnet.py
+++ b/torchrec/modules/crossnet.py
@@ -196,7 +196,7 @@ class VectorCrossNet(torch.nn.Module):
 
     On each layer l, the tensor is transformed into
 
-    .. math::    x_{l+1} = x_0 * (W_l . x_l + b_l) + x_l
+    .. math::    x_{l+1} = x_0 * (W_l . x_l) + b_l + x_l
 
     where :math:`W_l` is either a vector, :math:`*` means element-wise multiplication;
     :math:`.` means dot operations.