Bug description
In `ch03/02_bonus_efficient-multihead-attention/mha-implementations.ipynb`, the einsum projections in
```python
def forward(self, x):
    Q = torch.einsum("bnd,di->bni", x, self.W_query)
    K = torch.einsum("bnd,di->bni", x, self.W_key)
    V = torch.einsum("bnd,di->bni", x, self.W_value)
```
should be revised to
```python
    Q = torch.einsum("bnd,od->bno", x, self.W_query)
    K = torch.einsum("bnd,od->bno", x, self.W_key)
    V = torch.einsum("bnd,od->bno", x, self.W_value)
```
because the weights are initialized as `self.W_query = nn.Parameter(torch.randn(d_out, d_in))`, i.e., with `d_out` as the first dimension and `d_in` as the second.
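For reference, the corrected subscripts compute the standard `x @ W.T` projection; here is a quick equivalence check (the shapes and variable names are illustrative, not from the notebook):

```python
import torch

x = torch.randn(2, 4, 8)   # (batch, num_tokens, d_in)
W = torch.randn(16, 8)     # (d_out, d_in), matching torch.randn(d_out, d_in)

Q = torch.einsum("bnd,od->bno", x, W)
assert torch.allclose(Q, x @ W.T)  # identical to the usual linear projection
```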
This bug is hidden at first because the module is instantiated with

```python
mha_einsum = MHAEinsum(
    d_in=embed_dim,
    d_out=embed_dim,
    ...)
```

where `d_in == d_out == embed_dim`. With square weight matrices the buggy subscripts still produce compatible shapes (the einsum silently computes `x @ W` instead of the intended `x @ W.T`), so no error is raised. If you change it to `d_out = 2 * embed_dim`, the forward pass fails with a shape-mismatch error.
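A minimal repro sketch, assuming the notebook's weight layout `torch.randn(d_out, d_in)` (variable names here are illustrative):

```python
import torch

batch, num_tokens, d_in, d_out = 2, 4, 8, 16  # d_out != d_in exposes the bug

x = torch.randn(batch, num_tokens, d_in)
W_query = torch.randn(d_out, d_in)  # same layout as the notebook's nn.Parameter

# Buggy subscripts: "di" expects the weight's first dim to equal d_in,
# so this fails whenever d_out != d_in.
try:
    torch.einsum("bnd,di->bni", x, W_query)
except RuntimeError as e:
    print("buggy einsum raises:", e)

# Fixed subscripts: contract over d_in (the weight's second dim), keep d_out.
Q = torch.einsum("bnd,od->bno", x, W_query)
print(Q.shape)  # torch.Size([2, 4, 16])
```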
What operating system are you using?
None
Where do you run your code?
None
Environment