pscan.py
import math

import torch

"""
An implementation of the parallel scan operation in PyTorch (Blelloch version).
This code follows the skeleton proposed by Francois Fleuret in his pscan. However, the key differences are:
- it has been written in an iterative way (rather than recursive)
- the backward pass has been rewritten

Please see docs/pscan.ipynb for a detailed explanation of what happens here.
"""

# TODO: avoid the .flip() calls by implementing a reverse pscan (with a flag)
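
# What the parallel scan computes, written out naively: a sequential reference
# sketch of the recurrence H[t] = A[t] * H[t-1] + X[t] with H[-1] = 0. The helper
# name `pscan_sequential` is illustrative only (it is not part of mamba.py's API);
# it exists to make the recurrence concrete and to support the sanity check at the
# bottom of the file.
def pscan_sequential(A_in, X_in):
    # A_in : (B, L, D, N)
    # X_in : (B, L, D, N)
    # returns H : (B, L, D, N)
    B, L, D, N = A_in.size()
    h = torch.zeros(B, D, N, device=X_in.device, dtype=X_in.dtype)
    hs = []
    for t in range(L):
        h = A_in[:, t] * h + X_in[:, t]
        hs.append(h)
    return torch.stack(hs, dim=1)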

class PScan(torch.autograd.Function):
    @staticmethod
    def pscan(A, X):
        # A : (B, D, L, N)
        # X : (B, D, L, N)

        # modifies X in place by doing a parallel scan.
        # more formally, X will be populated by these values :
        # H[t] = A[t] * H[t-1] + X[t] with H[-1] = 0 (i.e. H[0] = X[0])
        # which are computed in parallel (2*log2(T) sequential steps (ideally), instead of T sequential steps)

        B, D, L, _ = A.size()
        num_steps = int(math.log2(L))

        # up sweep or reduction step
        Aa = A
        Xa = X
        for k in range(num_steps):
            T = 2 * (Xa.size(2) // 2)

            Aa = Aa[:, :, :T].view(B, D, T//2, 2, -1)
            Xa = Xa[:, :, :T].view(B, D, T//2, 2, -1)

            Xa[:, :, :, 1].add_(Aa[:, :, :, 1].mul(Xa[:, :, :, 0]))
            Aa[:, :, :, 1].mul_(Aa[:, :, :, 0])

            Aa = Aa[:, :, :, 1]
            Xa = Xa[:, :, :, 1]

        # down sweep
        for k in range(num_steps-1, -1, -1):
            Aa = A[:, :, 2**k-1:L:2**k]
            Xa = X[:, :, 2**k-1:L:2**k]

            T = 2 * (Xa.size(2) // 2)

            if T < Xa.size(2):
                Xa[:, :, -1].add_(Aa[:, :, -1].mul(Xa[:, :, -2]))
                Aa[:, :, -1].mul_(Aa[:, :, -2])

            Aa = Aa[:, :, :T].view(B, D, T//2, 2, -1)
            Xa = Xa[:, :, :T].view(B, D, T//2, 2, -1)

            Xa[:, :, 1:, 0].add_(Aa[:, :, 1:, 0].mul(Xa[:, :, :-1, 1]))
            Aa[:, :, 1:, 0].mul_(Aa[:, :, :-1, 1])

    @staticmethod
    def forward(ctx, A_in, X_in):
        """
        Applies the parallel scan operation, as defined above. Returns a new tensor.

        Args:
            A_in : (B, L, D, N)
            X_in : (B, L, D, N)

        Returns:
            H : (B, L, D, N)
        """

        # clone tensors (in-place ops)
        A = A_in.clone() # (B, L, D, N)
        X = X_in.clone() # (B, L, D, N)

        # prepare tensors
        A = A.transpose(2, 1) # (B, D, L, N)
        X = X.transpose(2, 1) # (B, D, L, N)

        # parallel scan
        PScan.pscan(A, X)

        ctx.save_for_backward(A_in, X)

        return X.transpose(2, 1)

    @staticmethod
    def backward(ctx, grad_output_in):
        """
        Flows the gradient from the output to the input. Returns two new tensors.

        Args:
            ctx : A_in : (B, L, D, N), X : (B, D, L, N)
            grad_output_in : (B, L, D, N)

        Returns:
            gradA : (B, L, D, N), gradX : (B, L, D, N)
        """

        A_in, X = ctx.saved_tensors

        # clone tensors
        A = A_in.clone()
        # grad_output_in will be cloned with flip()

        # prepare tensors
        A = A.transpose(2, 1) # (B, D, L, N)
        A = torch.cat((A[:, :, :1], A[:, :, 1:].flip(2)), dim=2)
        grad_output_b = grad_output_in.transpose(2, 1)

        # reverse parallel scan
        grad_output_b = grad_output_b.flip(2)
        PScan.pscan(A, grad_output_b)
        grad_output_b = grad_output_b.flip(2)

        Q = torch.zeros_like(X)
        Q[:, :, 1:].add_(X[:, :, :-1] * grad_output_b[:, :, 1:])

        return Q.transpose(2, 1), grad_output_b.transpose(2, 1)

pscan = PScan.apply
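
# A minimal sanity-check sketch (illustrative, not part of the original module):
# it compares the parallel scan against the sequential reference defined above and
# compares the gradients of the same scalar loss. Shapes are arbitrary; L is kept a
# power of 2 to stay within the assumptions of this pscan version.
if __name__ == "__main__":
    torch.manual_seed(0)
    B, L, D, N = 2, 8, 3, 4

    A = torch.randn(B, L, D, N, dtype=torch.float64, requires_grad=True)
    X = torch.randn(B, L, D, N, dtype=torch.float64, requires_grad=True)

    # forward check: parallel scan vs. naive sequential recurrence
    H_par = pscan(A, X)
    H_seq = pscan_sequential(A, X)
    print("forward max abs diff:", (H_par - H_seq).abs().max().item())

    # backward check: gradients of sum(H) through both implementations
    H_par.sum().backward()
    gA_par, gX_par = A.grad.clone(), X.grad.clone()
    A.grad, X.grad = None, None

    pscan_sequential(A, X).sum().backward()
    print("gradA max abs diff:", (A.grad - gA_par).abs().max().item())
    print("gradX max abs diff:", (X.grad - gX_par).abs().max().item())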