ring_allreduce.py
import argparse
import time

import torch
from torch import distributed as dist

DEVICE = "cpu"


def init_process(master_ip, rank, world_size):
    # Join the process group; the rank-0 machine acts as the rendezvous point on port 6585.
    dist.init_process_group(backend="gloo",
                            init_method="tcp://" + master_ip + ":6585",
                            rank=rank,
                            world_size=world_size)


def main(tensor_size):
    # Get world size and rank
    world_size = dist.get_world_size()
    rank = dist.get_rank()

    # Each rank exchanges one equally sized chunk per step, so the tensor must split evenly.
    assert tensor_size % world_size == 0, "tensor size must be divisible by the number of nodes"
    chunk_size = tensor_size // world_size

    # Create a random tensor and view it as world_size chunks
    t = list(torch.rand(tensor_size).split(chunk_size))

    # Create receive buffers, one per reduce-scatter step
    recv_buffers = list(torch.zeros(tensor_size).split(chunk_size))

    total_time = 0

    # Reduce-scatter phase: after world_size - 1 steps, rank r holds the fully
    # reduced chunk r. Even ranks send before receiving and odd ranks receive
    # before sending, so each blocking send is matched by a posted receive.
    for i in range(1, world_size):
        s = time.time()
        if (rank % 2) == 0:
            # Send a chunk to the previous machine in the ring
            dist.send(t[(rank + i) % world_size], dst=(rank + world_size - 1) % world_size)
            # Receive a chunk from the next machine in the ring
            dist.recv(recv_buffers[i - 1], src=(rank + 1) % world_size)
        else:
            # Receive a chunk from the next machine in the ring
            dist.recv(recv_buffers[i - 1], src=(rank + 1) % world_size)
            # Send a chunk to the previous machine in the ring
            dist.send(t[(rank + i) % world_size], dst=(rank + world_size - 1) % world_size)
        e = time.time()
        total_time += e - s
        # Accumulate the received partial sum; at the end of this loop,
        # t[rank] holds the fully reduced chunk.
        t[(rank + i + 1) % world_size] += recv_buffers[i - 1]

    # All-gather phase: circulate the reduced chunks around the ring so that
    # every rank ends up with the complete reduced tensor.
    for i in range(1, world_size):
        s = time.time()
        if (rank % 2) == 0:
            # Send a reduced chunk to the next machine in the ring
            dist.send(t[(rank + 1 - i + world_size) % world_size], dst=(rank + 1) % world_size)
            # Receive a reduced chunk from the previous machine in the ring
            dist.recv(t[(rank - i + world_size) % world_size], src=(rank + world_size - 1) % world_size)
        else:
            # Receive a reduced chunk from the previous machine in the ring
            dist.recv(t[(rank - i + world_size) % world_size], src=(rank + world_size - 1) % world_size)
            # Send a reduced chunk to the next machine in the ring
            dist.send(t[(rank + 1 - i + world_size) % world_size], dst=(rank + 1) % world_size)
        e = time.time()
        total_time += e - s

    print("Rank", rank, "finished allreduce in", total_time, "seconds")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--master-ip", "-m", required=True, type=str)
    parser.add_argument("--num-nodes", "-n", required=True, type=int)
    parser.add_argument("--rank", "-r", required=True, type=int)
    parser.add_argument("--tensor-size", "-t", required=True, type=int)
    args = parser.parse_args()

    init_process(master_ip=args.master_ip,
                 rank=args.rank,
                 world_size=args.num_nodes)
    main(tensor_size=args.tensor_size)
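
Usage sketch (not part of the file): one way to launch the script across two machines, assuming the rank-0 node is reachable at 10.0.0.1 (an illustrative address) and TCP port 6585 is open between the nodes:

    # On the rank-0 (master) node
    python ring_allreduce.py --master-ip 10.0.0.1 --num-nodes 2 --rank 0 --tensor-size 1024

    # On the rank-1 node
    python ring_allreduce.py --master-ip 10.0.0.1 --num-nodes 2 --rank 1 --tensor-size 1024

The tensor size (1024 here) must be divisible by the number of nodes, since each rank sends and receives one equally sized chunk per step.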