Added kernel residual param.
wangxin.colin committed Aug 17, 2021
1 parent 0a27f7c commit 3b44225
Showing 3 changed files with 128 additions and 115 deletions.
14 changes: 7 additions & 7 deletions README.md
@@ -2,11 +2,11 @@
FSMN implementation with PyTorch

## Added an FSMNKernelParallel version that uses grouped convolution to speed up the FSMN computation.
```python
########################################
# diff: sum(fsmnp_out - fsmn_out) = 0.0
########################################
# parallel time used: 0.26619601249694824
# for-loop time used: 6.988285303115845
########################################
```plain
################################################################################
maximum relative error: max(abs((fsmnp_out - fsmn_out) / fsmnp_out)) = 0.00000056
################################################################################
parallel fsmn kernel time used: 0.66744542
for-loop fsmn kernel time used: 4.57218981
################################################################################
```
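For context, the speedup comes from replacing the per-frame Python loop with a single depthwise (grouped) `Conv1d` over the time axis: with `groups=dims`, every feature dimension gets its own `l_order + r_order + 1` taps, which is exactly the weighted sum over past, current, and future frames that the for-loop kernel computes. Below is a minimal standalone sketch of that equivalence; the toy sizes and the helper `toy_loop_fsmn` are illustrative, not code from this repository.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def toy_loop_fsmn(x, w, l_order, r_order):
    """Naive FSMN kernel: per-dimension weighted sum over a context window.

    x: (B, T, D) input, w: (l_order + r_order + 1, D) per-dimension taps.
    """
    B, T, D = x.shape
    left = torch.zeros(B, l_order, D)
    right = torch.zeros(B, r_order, D)
    xp = torch.cat((left, x, right), dim=1)          # zero-pad the time axis
    out = []
    for t in range(l_order, l_order + T):
        window = xp[:, t - l_order:t + r_order + 1]  # (B, K, D) context window
        out.append(torch.sum(window * w, dim=1, keepdim=True))
    return torch.cat(out, dim=1)


B, T, D, lo, ro = 2, 16, 4, 3, 2
x = torch.randn(B, T, D)
w = torch.randn(lo + ro + 1, D)

# One depthwise filter per feature dimension, identical taps to the loop above.
conv = nn.Conv1d(D, D, kernel_size=lo + ro + 1, groups=D, bias=False)
conv.weight.data = w.t().unsqueeze(1)                # (K, D) -> (D, 1, K)

xp = F.pad(x.transpose(1, 2), (lo, ro))              # zero-pad: (B, D, lo + T + ro)
y_conv = conv(xp).transpose(1, 2)                    # back to (B, T, D)
y_loop = toy_loop_fsmn(x, w, lo, ro)
print(torch.max(torch.abs(y_conv - y_loop)))         # ~0 up to float error
```

Up to floating-point error the two paths agree, which is what the maximum-relative-error check above reports.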
32 changes: 18 additions & 14 deletions main.py
@@ -17,34 +17,38 @@ def forward(self, x):

def test():
import time
lo, ro = 2, 2
B, T, D = 10, 100, 30
num_iter = 1000
lo, ro = 10, 10
B, T, D = 10, 2000, 440
num_iter = 10
x = torch.arange(B*T*D).view(B, T, D).float()

fsmnp = FSMNKernelParallel(D, lo, ro, padding_mode='zero')
fsmnp.filter.weight.data = torch.arange(fsmnp.filter.weight.numel()).view(fsmnp.filter.weight.size()).float()
print('fsmnp filter:', fsmnp.filter.weight.data)
fsmnp.filter.weight.data /= torch.max(fsmnp.filter.weight.data)

s = time.time()
fsmnp_out = x
for _ in range(num_iter):
fsmnp_out = fsmnp(x)
fsmnp_out = fsmnp(fsmnp_out)
fsmnp_time = time.time() - s

fsmn = FSMNKernel(dims=D, l_order=lo, r_order=ro, l_stride=1, r_stride=1)
fsmn.filter.data = torch.arange(fsmnp.filter.weight.numel()).view(D, lo+ro+1).transpose(0, 1).float()
print('fsmn filter:', fsmn.filter.data)
fsmn.filter.data /= torch.max(fsmn.filter.data)

s = time.time()
fsmn_out = x
for _ in range(num_iter):
fsmn_out = fsmn(x)
fsmn_out = fsmn(fsmn_out)
fsmn_time = time.time() - s
print('#' * 80)
print(f'maximum relative error: max(abs((fsmnp_out - fsmn_out) / fsmnp_out)) ='
f' {torch.max(torch.abs((fsmnp_out - fsmn_out) / (fsmnp_out + 1e-8))):.8f}')

print('#' * 40)
print(f'diff: sum(fsmnp_out - fsmn_out) = {torch.sum(fsmnp_out - fsmn_out)}')

print('#' * 40)
print(f'parallel time used: {fsmnp_time}\n'
f'for-loop time used: {fsmn_time}')
print('#' * 40)
print('#' * 80)
print(f'parallel fsmn kernel time used: {fsmnp_time:.8f}\n'
f'for-loop fsmn kernel time used: {fsmn_time:.8f}')
print('#' * 80)


if __name__ == '__main__':
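The benchmark above initializes both kernels from the same ramp of values via `arange`/`view`/`transpose`; the underlying correspondence is that `FSMNKernel.filter` has shape `(K, D)` while the grouped `Conv1d` weight has shape `(D, 1, K)`, with `K = l_order + r_order + 1`, so the same tap `w[k, d]` lives at `conv_weight[d, 0, k]`. A minimal sketch of copying one layout into the other (arbitrary sizes; assumes the repository root is on the Python path so `module.fsmn_kernel` is importable):

```python
import torch
from module.fsmn_kernel import FSMNKernel, FSMNKernelParallel

D, lo, ro = 8, 2, 2
fsmn = FSMNKernel(dims=D, l_order=lo, r_order=ro)
fsmnp = FSMNKernelParallel(dims=D, l_order=lo, r_order=ro, padding_mode='zero')

# FSMNKernel.filter: (K, D); grouped Conv1d weight: (D, 1, K).
# Copying one into the other is a transpose plus an unsqueeze.
with torch.no_grad():
    fsmnp.filter.weight.copy_(fsmn.filter.t().unsqueeze(1))

x = torch.randn(3, 50, D)
print(torch.max(torch.abs(fsmnp(x) - fsmn(x))))  # ~0 up to float error
```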
197 changes: 103 additions & 94 deletions module/fsmn_kernel.py
@@ -5,104 +5,113 @@


class Pad(enum.Enum):
ZERO = 'zero'
EDGE = 'edge'
ZERO = 'zero'
EDGE = 'edge'


class FSMNKernel(nn.Module):
def __init__(self, dims, l_order, r_order, l_stride=1, r_stride=1):
super().__init__()
self.filter = nn.Parameter(torch.randn(l_order + r_order + 1, dims))
self.l_order = l_order
self.r_order = r_order
self.l_stride = l_stride
self.r_stride = r_stride
self.dims = dims

def extra_repr(self) -> str:
return f'l_order={self.l_order}, r_order={self.r_order}, ' \
f'l_stride={self.l_stride}, r_stride={self.r_stride}, ' \
f'dims={self.dims}'

def forward(self, x):
"""Apply FSMN-kernel to x.
:param x: BxTxD
:return: BxTxD
"""
batch, frames, _ = x.size()
# pad zeros, BxTxD -> Bx(l_order+T+r_order)xD
x = torch.cat(
(
torch.zeros((batch, self.l_order * self.l_stride, self.dims), device=x.device),
x,
torch.zeros((batch, self.r_order * self.r_stride, self.dims), device=x.device)
),
dim=1)

kernel_out = []
for frame in range(self.l_order * self.l_stride, frames + self.l_order * self.l_stride):
l_frame = torch.sum(
x[:, frame - self.l_order * self.l_stride:frame:self.l_stride] * self.filter[:self.l_order],
dim=1,
keepdim=True
)
c_frame = x[:, frame:frame + 1] * self.filter[self.l_order:self.l_order + 1]
cur_frame = l_frame + c_frame
if self.r_order > 0:
r_frame = torch.sum(
x[:, frame + 1:frame + 1 + self.r_order * self.r_stride:self.r_stride] * self.filter[
self.l_order + 1:],
dim=1,
keepdim=True
)
cur_frame = cur_frame + r_frame
kernel_out.append(cur_frame)
kernel_out = torch.cat(kernel_out, dim=1)
return kernel_out
def __init__(self, dims, l_order, r_order, l_stride=1, r_stride=1, kernel_res=False):
super().__init__()
self.filter = nn.Parameter(torch.randn(l_order + r_order + 1, dims))
self.l_order = l_order
self.r_order = r_order
self.l_stride = l_stride
self.r_stride = r_stride
self.dims = dims
self.kernel_res = kernel_res

def extra_repr(self) -> str:
return f'l_order={self.l_order}, r_order={self.r_order}, ' \
f'l_stride={self.l_stride}, r_stride={self.r_stride}, ' \
f'kernel_res={self.kernel_res}, dims={self.dims}'

def forward(self, inputs):
"""Apply FSMN-kernel to x.
:param x: BxTxD
:return: BxTxD
"""
batch, frames, _ = inputs.size()
# pad with zeros: BxTxD -> Bx(l_order*l_stride + T + r_order*r_stride)xD
x = torch.cat(
(
torch.zeros((batch, self.l_order * self.l_stride, self.dims), device=inputs.device),
inputs,
torch.zeros((batch, self.r_order * self.r_stride, self.dims), device=inputs.device)
),
dim=1)

kernel_out = []
for frame in range(self.l_order * self.l_stride, frames + self.l_order * self.l_stride):
l_frame = torch.sum(
x[:, frame - self.l_order * self.l_stride:frame:self.l_stride] * self.filter[:self.l_order],
dim=1,
keepdim=True
)
c_frame = x[:, frame:frame + 1] * self.filter[self.l_order:self.l_order + 1]
cur_frame = l_frame + c_frame
if self.r_order > 0:
r_frame = torch.sum(
x[:, frame + 1:frame + 1 + self.r_order * self.r_stride:self.r_stride] * self.filter[
self.l_order + 1:],
dim=1,
keepdim=True
)
cur_frame = cur_frame + r_frame
kernel_out.append(cur_frame)
kernel_out = torch.cat(kernel_out, dim=1)

if self.kernel_res:
kernel_out += inputs

return kernel_out


class FSMNKernelParallel(nn.Module):

def __init__(self, dims, l_order, r_order, l_stride=1, r_stride=1, padding_mode=Pad.ZERO):
super().__init__()
assert l_stride == r_stride == 1, f'Parallel version expected l_stride == r_stride == 1, ' \
f'but get ({l_stride}, {r_stride})'
self.filter = nn.Conv1d(in_channels=dims, out_channels=dims, kernel_size=l_order+r_order+1, stride=l_stride,
groups=dims, padding=0, bias=False)
self.l_order = l_order
self.r_order = r_order
self.l_stride = l_stride
self.r_stride = r_stride
self.dims = dims
self.pm = Pad(padding_mode)

def extra_repr(self) -> str:
return f'(l_order={self.l_order}, r_order={self.r_order}, ' \
f'l_stride={self.l_stride}, r_stride={self.r_stride}, ' \
f'dims={self.dims}, padding_mode={self.pm})'

def forward(self, x):
batch, time, dim = x.size()
if self.pm is Pad.ZERO:
x = torch.cat(
(
torch.zeros((batch, self.l_order * self.l_stride, self.dims), device=x.device),
x,
torch.zeros((batch, self.r_order * self.r_stride, self.dims), device=x.device)
),
dim=1
)
elif self.pm is Pad.EDGE:
x = torch.cat(
(
torch.ones((batch, self.l_order * self.l_stride, self.dims), device=x.device) * x.data[:, 0],
x,
torch.ones((batch, self.r_order * self.r_stride, self.dims), device=x.device) * x.data[:, -1]
),
dim=1
)
else:
raise ValueError(f'padding mode {self.pm} is not supported for now.')
x = x.transpose(1, 2).contiguous() # BxTxD -> BxDxT for conv accept channel as second dimension.
y = self.filter(x).transpose(1, 2).contiguous() # BxDxT -> BxTxD
return y
def __init__(self, dims, l_order, r_order, l_stride=1, r_stride=1, kernel_res=False, padding_mode=Pad.ZERO):
super().__init__()
assert l_stride == r_stride == 1, f'Parallel version expected l_stride == r_stride == 1, ' \
f'but got ({l_stride}, {r_stride})'
self.filter = nn.Conv1d(in_channels=dims, out_channels=dims, kernel_size=l_order+r_order+1, stride=l_stride,
groups=dims, padding=0, bias=False)
self.l_order = l_order
self.r_order = r_order
self.l_stride = l_stride
self.r_stride = r_stride
self.dims = dims
self.pm = Pad(padding_mode)
self.kernel_res = kernel_res

def extra_repr(self) -> str:
return f'(l_order={self.l_order}, r_order={self.r_order}, ' \
f'l_stride={self.l_stride}, r_stride={self.r_stride}, ' \
f'dims={self.dims}, kernel_res={self.kernel_res}, padding_mode={self.pm})'

def forward(self, inputs):
"""Apply the FSMN kernel via grouped convolution: BxTxD -> BxTxD."""
batch, time, dim = inputs.size()
if self.pm is Pad.ZERO:
x = torch.cat(
(
torch.zeros((batch, self.l_order * self.l_stride, self.dims), device=inputs.device),
inputs,
torch.zeros((batch, self.r_order * self.r_stride, self.dims), device=inputs.device)
),
dim=1
)
elif self.pm is Pad.EDGE:
x = torch.cat(
(
torch.ones((batch, self.l_order * self.l_stride, self.dims), device=inputs.device) * inputs.data[:, 0],
inputs,
torch.ones((batch, self.r_order * self.r_stride, self.dims), device=inputs.device) * inputs.data[:, -1]
),
dim=1
)
else:
raise ValueError(f'padding mode {self.pm} is not supported for now.')
x = x.transpose(1, 2).contiguous()  # BxTxD -> BxDxT, since Conv1d expects channels as the second dimension.
y = self.filter(x).transpose(1, 2).contiguous() # BxDxT -> BxTxD
if self.kernel_res:
y += inputs
return y
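As a usage note, the new `kernel_res` flag simply adds the block input back onto the kernel output (`kernel_out += inputs` and `y += inputs` above), so enabling it turns the kernel into `fsmn(x) + x`. A minimal sketch (arbitrary sizes; assumes the repository root is on the Python path so `module.fsmn_kernel` is importable):

```python
import torch
from module.fsmn_kernel import FSMNKernel, FSMNKernelParallel

D = 16
x = torch.randn(4, 32, D)

plain = FSMNKernel(dims=D, l_order=2, r_order=2)
resid = FSMNKernel(dims=D, l_order=2, r_order=2, kernel_res=True)

with torch.no_grad():
    resid.filter.copy_(plain.filter)                # same taps, only the residual differs
    print(torch.allclose(resid(x), plain(x) + x))   # True: kernel_res adds the input back

# The parallel kernel exposes the same switch.
par = FSMNKernelParallel(dims=D, l_order=2, r_order=2, kernel_res=True, padding_mode='zero')
print(par(x).shape)                                 # torch.Size([4, 32, 16])
```

Because the residual is added after the memory block, output and input keep the same BxTxD shape, so the flag can be toggled without changing the surrounding model.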
