Fix bug where r_order cannot be 0.
xin-w8023 committed Jul 20, 2020
1 parent 320c157 commit 9c628dd
Showing 4 changed files with 48 additions and 21 deletions.
3 changes: 2 additions & 1 deletion .gitignore
@@ -1 +1,2 @@
-.idea
+.idea
+__pycache__
30 changes: 30 additions & 0 deletions main.py
@@ -0,0 +1,30 @@
import torch
import torch.nn as nn

from module import FSMNKernel


class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.fsmn = FSMNKernel(220, 0, 0, 1, 1)  # both context orders are 0; r_order = 0 is the case this commit fixes
        self.fc = nn.Linear(220, 3)

    def forward(self, x):
        return self.fc(nn.functional.relu(self.fsmn(x)))


if __name__ == '__main__':
    B, T, D = 32, 300, 220
    xx = torch.randn(B, T, D)  # random input features
    yy = torch.randint(low=0, high=3, size=(32, 300)).view(-1)  # random class labels, flattened to (B * T,)
    net = Net()
    optim = torch.optim.SGD(net.parameters(), lr=1e-2, momentum=0.8)
    L = nn.CrossEntropyLoss()
    for _ in range(100):  # short sanity-check training loop
        y = net(xx).view(-1, 3)
        loss = L(y, yy)
        print(loss.item())
        optim.zero_grad()
        loss.backward()
        optim.step()
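
The positional arguments to `FSMNKernel` are never named in this diff; judging from how `module/fsmn_kernel.py` uses `l_order`, `r_order`, `l_stride` and `r_stride`, they appear to be `(dim, l_order, r_order, l_stride, r_stride)`, but treat that ordering as an assumption. A minimal sketch of a smoke test with that assumption spelled out:

# Hypothetical quick check, assuming FSMNKernel(dim, l_order, r_order, l_stride, r_stride).
import torch

from module import FSMNKernel

if __name__ == '__main__':
    B, T, D = 4, 50, 220
    x = torch.randn(B, T, D)
    kernel = FSMNKernel(D, 2, 0, 1, 1)  # r_order = 0: the case this commit fixes
    y = kernel(x)
    print(y.shape)  # expected (B, T, D); main.py's loss reshaping implies the time dimension is preserved
    y.mean().backward()  # the backward pass should now run with no right context
    for name, param in kernel.named_parameters():
        print(name, param.grad is not None)  # each parameter should have received a gradient

This mirrors the test block that the commit removes from `fsmn_kernel.py`, but with `r_order` set to 0 instead of 5.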
1 change: 1 addition & 0 deletions module/__init__.py
@@ -0,0 +1 @@
from .fsmn_kernel import FSMNKernel
35 changes: 15 additions & 20 deletions fsmn_kernel.py → module/fsmn_kernel.py
@@ -34,26 +34,21 @@ def forward(self, x):

         kernel_out = []
         for frame in range(self.l_order * self.l_stride, frames + self.l_order * self.l_stride):
-            cur_frame = torch.sum(
-                x[:, frame - self.l_order * self.l_stride:frame:self.l_stride] * self.filter[:self.l_order] +
-                x[:, frame:frame + 1] * self.filter[self.l_order:self.l_order + 1] +
-                x[:, frame + 1:frame + 1 + self.r_order * self.r_stride:self.r_stride] * self.filter[self.l_order + 1:],
-                dim=1,
-                keepdim=True
-            )
+            if self.r_order > 0:
+                cur_frame = torch.sum(
+                    x[:, frame - self.l_order * self.l_stride:frame:self.l_stride] * self.filter[:self.l_order] +
+                    x[:, frame:frame + 1] * self.filter[self.l_order:self.l_order + 1] +
+                    x[:, frame + 1:frame + 1 + self.r_order * self.r_stride:self.r_stride] * self.filter[self.l_order + 1:],
+                    dim=1,
+                    keepdim=True
+                )
+            else:  # no right context: the right-hand slice and filter tail are empty, so skip them
+                cur_frame = torch.sum(
+                    x[:, frame - self.l_order * self.l_stride:frame:self.l_stride] * self.filter[:self.l_order] +
+                    x[:, frame:frame + 1] * self.filter[self.l_order:self.l_order + 1],
+                    dim=1,
+                    keepdim=True
+                )
             kernel_out.append(cur_frame)
         kernel_out = torch.cat(kernel_out, dim=1)
         return kernel_out
-
-
-if __name__ == '__main__':
-    B, T, D = 32, 300, 440
-    xx = torch.randn(B, T, D, requires_grad=True)
-    kernel = FSMNKernel(D, 5, 5, 1, 2)
-    out = kernel(xx)
-    print(kernel)
-    loss = out.mean()
-    loss.backward()
-    for name, param in kernel.named_parameters():
-        print(name, param, param.grad)

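The fix itself is the `if self.r_order > 0` guard: when `r_order == 0`, both the right-context slice of `x` and the filter tail `self.filter[self.l_order + 1:]` are empty tensors, and folding them into the unconditional sum breaks the per-frame computation, so the `else` branch sums only the left-context and current-frame terms. A standalone sketch of that guarded sum, written outside the class for illustration (the helper name, shapes, and the assumption that the input is already left-padded are mine, not the repository's):

import torch


def fsmn_frame(x, filt, frame, l_order, r_order, l_stride=1, r_stride=1):
    """Guarded per-frame sum, mirroring the fixed loop body above (illustrative only)."""
    # x: (B, T_padded, D); filt: (l_order + r_order + 1, D)
    left = x[:, frame - l_order * l_stride:frame:l_stride] * filt[:l_order]
    center = x[:, frame:frame + 1] * filt[l_order:l_order + 1]
    if r_order > 0:
        right = x[:, frame + 1:frame + 1 + r_order * r_stride:r_stride] * filt[l_order + 1:]
        return torch.sum(left + center + right, dim=1, keepdim=True)
    # r_order == 0: the right slice and filt[l_order + 1:] would be empty, so skip them
    return torch.sum(left + center, dim=1, keepdim=True)


if __name__ == '__main__':
    B, T, D, l_order = 2, 10, 8, 3
    x = torch.randn(B, T + l_order, D)      # pretend l_order frames of left padding are already applied
    filt = torch.randn(l_order + 0 + 1, D)  # r_order = 0
    out = fsmn_frame(x, filt, frame=l_order, l_order=l_order, r_order=0)
    print(out.shape)                        # torch.Size([2, 1, 8])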