
Multi-head attention with Einsum: Q = torch.einsum("bnd,di->bni", x, self.W_query) bug #857

@yujie-jia


Bug description

In ch03/02_bonus_efficient-multihead-attention/mha-implementations.ipynb, the MHAEinsum implementation computes the projections in

def forward(self, x):
    Q = torch.einsum("bnd,di->bni", x, self.W_query)
    K = torch.einsum("bnd,di->bni", x, self.W_key)
    V = torch.einsum("bnd,di->bni", x, self.W_value)

These lines should be revised to

    Q = torch.einsum("bnd,od->bno", x, self.W_query)
    K = torch.einsum("bnd,od->bno", x, self.W_key)
    V = torch.einsum("bnd,od->bno", x, self.W_value)

because the weights are initialized as self.W_query = nn.Parameter(torch.randn(d_out, d_in)), i.e. with d_out first and d_in second, so the second operand's subscripts must label the (d_out, d_in) axes and the contraction must run over d_in.
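
For reference, a minimal standalone sketch (illustrative shapes only, not the notebook's actual MHAEinsum class) showing that the corrected subscripts reproduce the usual x @ W.T projection for a weight stored as (d_out, d_in):

```python
import torch

x = torch.randn(2, 4, 8)      # (batch, num_tokens, d_in)
W_query = torch.randn(16, 8)  # (d_out, d_in), matching nn.Parameter(torch.randn(d_out, d_in))

# "od" labels the (d_out, d_in) axes of W_query; contracting over d gives x @ W_query.T
Q = torch.einsum("bnd,od->bno", x, W_query)

print(Q.shape)                           # torch.Size([2, 4, 16])
print(torch.allclose(Q, x @ W_query.T))  # True (up to floating-point tolerance)
```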

This bug is hidden at first because the notebook instantiates the module with equal dimensions,

mha_einsum = MHAEinsum(
    d_in=embed_dim,
    d_out=embed_dim,
    ...)

where d_in == d_out == embed_dim, so the mismatched subscripts still line up in size (the einsum runs, but it effectively multiplies by W rather than W.T). If you change it to d_out = 2 * embed_dim, the einsum raises a shape-mismatch error.
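
To make the hidden vs. visible behavior concrete, here is a quick standalone check (embed_dim and the batch/token sizes below are arbitrary, not the notebook's settings):

```python
import torch

embed_dim = 8
x = torch.randn(2, 4, embed_dim)                # (batch, num_tokens, d_in)

# d_out == d_in: the original subscripts run without error, which hides the bug,
# but they contract over the wrong axis (x @ W instead of x @ W.T)
W_same = torch.randn(embed_dim, embed_dim)      # (d_out, d_in) with d_out == d_in
buggy = torch.einsum("bnd,di->bni", x, W_same)  # no error raised
fixed = torch.einsum("bnd,od->bno", x, W_same)
print(torch.allclose(buggy, fixed))             # False in general

# d_out = 2 * embed_dim: the original subscripts now fail outright
W_wide = torch.randn(2 * embed_dim, embed_dim)  # (d_out, d_in) with d_out != d_in
ok = torch.einsum("bnd,od->bno", x, W_wide)     # works, shape (2, 4, 2 * embed_dim)
# torch.einsum("bnd,di->bni", x, W_wide)        # raises RuntimeError: subscript 'd' sizes don't match
```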

What operating system are you using?

None

Where do you run your code?

None

Environment



