
Commit 1534fee

Author: Andreas Georgiou
Commit message: v0.1
Parent: b76c6a6

File tree

11 files changed: +761 −0 lines changed


.gitignore

Lines changed: 1 addition & 0 deletions

__pycache__

MANIFEST

Lines changed: 10 additions & 0 deletions

# file GENERATED by distutils, do NOT edit
README
setup.cfg
setup.py
diffdist/__init__.py
diffdist/extra_collectives.py
diffdist/functional.py
diffdist/functions.py
diffdist/modules.py
diffdist/testing.py

README

Whitespace-only changes.

diffdist/__init__.py

Lines changed: 3 additions & 0 deletions

from diffdist import extra_collectives
from diffdist import functional
from diffdist import modules

diffdist/extra_collectives.py

Lines changed: 43 additions & 0 deletions

import torch.distributed as dist
from torch.distributed import ReduceOp


class AsyncOpList(object):
    def __init__(self, ops):
        self.ops = ops

    def wait(self):
        for op in self.ops:
            op.wait()

    def is_completed(self):
        for op in self.ops:
            if not op.is_completed():
                return False
        return True


def reduce_scatter(tensor,
                   tensor_list,
                   op=ReduceOp.SUM,
                   group=dist.group.WORLD,
                   async_op=False):
    # Emulates reduce_scatter with one asynchronous reduce per destination
    # rank: after the call, rank i holds the reduction of every rank's i-th chunk.
    rank = dist.get_rank(group)
    if tensor is None:
        tensor = tensor_list[rank]
    if tensor.dim() == 0:
        tensor = tensor.view(-1)
    tensor[:] = tensor_list[rank]
    ops = []
    for i in range(dist.get_world_size(group)):
        if i == rank:
            tmp = dist.reduce(tensor, rank, op, group, async_op=True)
        else:
            tmp = dist.reduce(tensor_list[i], i, op, group, async_op=True)
        ops.append(tmp)

    oplist = AsyncOpList(ops)
    if async_op:
        return oplist
    else:
        oplist.wait()
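
To make the semantics of reduce_scatter concrete, here is a minimal usage sketch. It is not part of this commit; demo_reduce_scatter is a hypothetical helper, and it assumes torch.distributed has already been initialised on every rank (for example with dist.init_process_group("gloo", rank=rank, world_size=world_size)).

import torch
import torch.distributed as dist

from diffdist.extra_collectives import reduce_scatter


def demo_reduce_scatter():
    rank = dist.get_rank()
    world_size = dist.get_world_size()
    # Each rank contributes one chunk per destination rank.
    tensor_list = [torch.full((4,), float(rank)) for _ in range(world_size)]
    out = torch.empty(4)
    # Blocking call: afterwards every rank holds the element-wise sum of the
    # chunks addressed to it, i.e. 0 + 1 + ... + (world_size - 1) in this toy case.
    reduce_scatter(out, tensor_list)
    print(rank, out)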

diffdist/functional.py

Lines changed: 55 additions & 0 deletions

import diffdist.modules as mods
import torch.distributed as dist


def consume_variable(tensor_to_consume, tensors_to_return, set_ones_grad=True):
    return mods.ConsumeVariable(set_ones_grad)(tensor_to_consume,
                                               *tensors_to_return)


def send(tensor, dst, group=dist.group.WORLD, tag=0):
    return mods.Send(dst, group, tag)(tensor)


def recv(tensor,
         src=None,
         group=dist.group.WORLD,
         tag=0,
         next_backprop=None,
         inplace=True):
    return mods.Recv(src, group, tag, next_backprop, inplace)(tensor)


def broadcast(tensor,
              src,
              group=dist.group.WORLD,
              next_backprop=None,
              inplace=True):
    return mods.Broadcast(src, group, next_backprop, inplace)(tensor)


def gather(tensor,
           gather_list=None,
           dst=None,
           group=dist.group.WORLD,
           next_backprop=None,
           inplace=True):
    return mods.Gather(dst, group, next_backprop, inplace)(tensor, gather_list)


def scatter(tensor,
            scatter_list=None,
            src=None,
            group=dist.group.WORLD,
            next_backprop=None,
            inplace=True):
    return mods.Scatter(src, group, next_backprop, inplace)(tensor,
                                                            scatter_list)


def all_gather(gather_list,
               tensor,
               group=dist.group.WORLD,
               next_backprop=None,
               inplace=True):
    return mods.AllGather(group, next_backprop, inplace)(gather_list, tensor)
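
As a rough illustration of how these wrappers differ from the raw torch.distributed calls, below is a hypothetical sketch of a differentiable all_gather. It is not part of this commit, it assumes an initialised process group, and it assumes the AllGather module in diffdist/modules.py (whose diff is not shown here) returns the gathered tensors produced by AllGatherFunc.

import torch
import torch.distributed as dist

import diffdist.functional as distF


def demo_all_gather():
    world_size = dist.get_world_size()
    x = torch.randn(2, 5, requires_grad=True)
    gather_list = [torch.zeros(2, 5) for _ in range(world_size)]
    # Unlike a plain dist.all_gather, the gathered copies stay connected to x:
    # the backward pass reduce-scatters their gradients back to every rank.
    gathered = distF.all_gather(gather_list, x)
    loss = torch.cat(list(gathered)).pow(2).sum()
    loss.backward()
    # x.grad now also includes contributions from the other ranks' losses.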

diffdist/functions.py

Lines changed: 196 additions & 0 deletions

from torch.autograd import Function
import diffdist.extra_collectives as dist_extra
import torch.distributed as dist
import torch


class ConsumeVariableFunc(Function):
    # Attaches tensor_to_consume to the graph of tensors_to_return and feeds
    # it a fake gradient of ones (or zeros) during backward.
    @staticmethod
    def forward(ctx, tensor_to_consume, set_ones_grad, *tensors_to_return):
        ctx.save_for_backward(tensor_to_consume)
        ctx.set_ones_grad = set_ones_grad
        return tensors_to_return

    @staticmethod
    def backward(ctx, *grad_outputs):
        tensor_to_consume, = ctx.saved_tensors
        if ctx.set_ones_grad:
            fake_grad = torch.ones_like(tensor_to_consume)
        else:
            fake_grad = torch.zeros_like(tensor_to_consume)

        return (fake_grad, None) + grad_outputs


class SendFunc(Function):
    @staticmethod
    def forward(ctx, tensor, dst, group=dist.group.WORLD, tag=0):
        ctx.save_for_backward(tensor)
        ctx.dst = dst
        ctx.group = group
        ctx.tag = tag
        dist.send(tensor, dst, group, tag)
        return tensor.new_tensor([])

    @staticmethod
    def backward(ctx, grad_output):
        tensor, = ctx.saved_tensors
        # TODO: Add ctx.needs_input_grad check
        # The gradient for the sent tensor is received back from the
        # destination rank.
        grad_tensor = torch.zeros_like(tensor)
        dist.recv(grad_tensor, ctx.dst, ctx.group, ctx.tag)

        return grad_tensor, None, None, None


class RecvFunc(Function):
    @staticmethod
    def forward(ctx,
                tensor,
                src=None,
                group=dist.group.WORLD,
                tag=0,
                inplace=True):
        if not inplace:
            tensor = torch.zeros_like(tensor).requires_grad_(False)
        ctx.src = src
        ctx.group = group
        ctx.tag = tag
        sender = dist.recv(tensor, src, group, tag)
        if src:
            assert sender == src
        else:
            ctx.src = sender
        sender = torch.tensor(sender)
        ctx.mark_non_differentiable(sender)
        return tensor, sender

    @staticmethod
    def backward(ctx, grad_tensor, grad_sender):
        # The gradient of the received tensor is sent back to the source rank.
        dist.send(grad_tensor, ctx.src, ctx.group, ctx.tag)
        return grad_tensor, None, None, None, None


class BroadcastFunc(Function):
    @staticmethod
    def forward(ctx, tensor, src, group=dist.group.WORLD, inplace=True):
        ctx.src = src
        ctx.group = group
        if dist.get_rank(group) == src:
            if not inplace:
                with torch.no_grad():
                    tensor = tensor.clone().requires_grad_(False)
        else:
            if not inplace:
                tensor = torch.zeros_like(tensor).requires_grad_(False)
        dist.broadcast(tensor, src, group)
        return tensor

    @staticmethod
    def backward(ctx, grad_output):
        # The adjoint of broadcast: gradients from all ranks are sum-reduced
        # onto src.
        dist.reduce(grad_output,
                    ctx.src,
                    op=dist.ReduceOp.SUM,
                    group=ctx.group)
        return grad_output, None, None, None


class AllReduceFunc(Function):
    @staticmethod
    def forward(ctx, i):
        raise NotImplementedError

    @staticmethod
    def backward(ctx, grad_output):
        raise NotImplementedError


class ReduceFunc(Function):
    @staticmethod
    def forward(ctx, i):
        raise NotImplementedError

    @staticmethod
    def backward(ctx, grad_output):
        raise NotImplementedError


class AllGatherFunc(Function):
    @staticmethod
    def forward(ctx, tensor, group, inplace, *gather_list):
        ctx.save_for_backward(tensor)
        ctx.group = group
        gather_list = list(gather_list)
        if not inplace:
            gather_list = [torch.zeros_like(g) for g in gather_list]
        dist.all_gather(gather_list, tensor, group)
        return tuple(gather_list)

    @staticmethod
    def backward(ctx, *grads):
        # The adjoint of all_gather: reduce_scatter the gradients of the
        # gathered copies back onto the local input.
        input, = ctx.saved_tensors
        grad_out = torch.zeros_like(input)
        dist_extra.reduce_scatter(grad_out, list(grads), group=ctx.group)
        return (grad_out, None, None) + grads


class GatherFunc(Function):
    @staticmethod
    def forward(ctx, input, dst, group, inplace, *gather_list):
        ctx.dst = dst
        ctx.group = group
        ctx.save_for_backward(input)
        if dist.get_rank(group) == dst:
            gather_list = list(gather_list)
            if not inplace:
                gather_list = [torch.zeros_like(g) for g in gather_list]
            dist.gather(input, gather_list=gather_list, dst=dst, group=group)
            return tuple(gather_list)
        else:
            dist.gather(input, [], dst=dst, group=group)
            return input.new_tensor([])

    @staticmethod
    def backward(ctx, *grads):
        # The adjoint of gather: dst scatters the gradients back to every rank.
        input, = ctx.saved_tensors
        grad_input = torch.zeros_like(input)
        if dist.get_rank(ctx.group) == ctx.dst:
            grad_outputs = list(grads)
            dist.scatter(grad_input,
                         grad_outputs,
                         src=ctx.dst,
                         group=ctx.group)
            return (grad_input, None, None, None) + grads
        else:
            dist.scatter(grad_input, [], src=ctx.dst, group=ctx.group)
            return grad_input, None, None, None, None


class ScatterFunc(Function):
    @staticmethod
    def forward(ctx,
                tensor,
                src,
                group=dist.group.WORLD,
                inplace=True,
                *scatter_list):
        ctx.src = src
        ctx.group = group
        if not inplace:
            tensor = torch.zeros_like(tensor)
        if dist.get_rank(group) == src:
            ctx.save_for_backward(*scatter_list)
            scatter_list = list(scatter_list)
            dist.scatter(tensor, scatter_list, src=src, group=group)
        else:
            dist.scatter(tensor, [], src=src, group=group)
        return tensor

    @staticmethod
    def backward(ctx, grad_tensor):
        # The adjoint of scatter: src gathers the gradients of the scattered
        # chunks.
        if dist.get_rank(ctx.group) == ctx.src:
            grad_outputs = [torch.zeros_like(g) for g in ctx.saved_tensors]
            dist.gather(grad_tensor, grad_outputs, ctx.src, group=ctx.group)
            return (grad_tensor, None, None, None) + tuple(grad_outputs)
        else:
            dist.gather(grad_tensor, [], ctx.src, group=ctx.group)
            return grad_tensor, None, None, None, None
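
These Function classes are meant to be driven through the wrappers in diffdist/modules.py and diffdist/functional.py; the modules.py diff is part of this commit but not shown in this excerpt. As an illustrative sketch only, and not a reproduction of the actual modules.py, a minimal wrapper around AllGatherFunc could look like this:

import torch.distributed as dist
import torch.nn as nn

from diffdist.functions import AllGatherFunc


class AllGatherSketch(nn.Module):
    """Illustrative wrapper; the real module lives in diffdist/modules.py."""

    def __init__(self, group=dist.group.WORLD, inplace=True):
        super().__init__()
        self.group = group
        self.inplace = inplace

    def forward(self, gather_list, tensor):
        # The gather buffers are unpacked into *args so that each gathered
        # tensor becomes a separate autograd output of AllGatherFunc.
        return list(AllGatherFunc.apply(tensor, self.group, self.inplace,
                                        *gather_list))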
