
Commit 6125fca

v0 param server (using collectives not object store)
ghstack-source-id: b30dce25ddaafb4e9cbb5886b85f5a0904494456
Pull Request resolved: #2865
1 parent 7df8317 commit 6125fca
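
As the commit title indicates, updated weights travel over the torch.distributed process group rather than through Ray's object store: trainer rank 0 gathers each sharded tensor and sends it point-to-point to the parameter-server rank, which receives it into a pre-allocated state dict (see param_server_weight_updater.py below). A minimal, self-contained sketch of that transfer pattern, not taken from this commit (gloo on CPU purely for illustration; the example script uses NCCL on GPUs, and names such as run/state_dict here are hypothetical):

import os

import torch
import torch.distributed as dist
import torch.multiprocessing as mp


def run(rank: int, world_size: int) -> None:
    # Rendezvous for the process group; the example script passes these via env_vars.
    os.environ["MASTER_ADDR"] = "127.0.0.1"
    os.environ["MASTER_PORT"] = "29500"
    dist.init_process_group("gloo", rank=rank, world_size=world_size)

    # Both sides allocate tensors of identical shapes/dtypes, keyed identically.
    state_dict = {"weight": torch.zeros(4, 4), "bias": torch.zeros(4)}

    if rank == 0:
        # "Trainer": produce fresh values and push them tensor by tensor.
        for v in state_dict.values():
            v.add_(1.0)
            dist.send(v, dst=1)
    else:
        # "Parameter server": receive in place into the pre-allocated buffers.
        for v in state_dict.values():
            dist.recv(v, src=0)
        assert all(t.eq(1.0).all() for t in state_dict.values())

    dist.destroy_process_group()


if __name__ == "__main__":
    mp.spawn(run, args=(2,), nprocs=2)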

4 files changed: +465 -6 lines changed

param_server_weight_updater.py

+263
@@ -0,0 +1,263 @@
import ray

from argparse import ArgumentParser
from functools import partial

import torch
from datasets import load_dataset
from tensordict import TensorDict
from torch.utils.data import DataLoader
from torchrl.collectors.weight_update import RayRemoteWeightUpdater
from transformers import AutoTokenizer, AutoModel
from vllm import LLM

from vllm.utils import get_ip, get_open_port

from torchrl.collectors.distributed import RayCollector
from torchrl.envs import LLMEnv
from torchrl.modules import from_vllm

from torchrl.collectors.vllm_weight_update import vLLMHFLocalWeightUpdater, vLLMRemoteWeightUpdaterBase, WorkerExtension

parser = ArgumentParser()
parser.add_argument("--dataset", type=str, default="gsm8k")
parser.add_argument("--batch_size", type=int, default=4)
parser.add_argument("--epochs", type=int, default=10)
parser.add_argument("--repeats", type=int, default=10)
parser.add_argument("--steps_per_batch", type=int, default=16)
parser.add_argument("--optim_batch_size", type=int, default=4)


def make_policy():
    inference_model = LLM(
        "facebook/opt-125m",
        enforce_eager=True,
        # change to worker_extension_cls when available in stable release
        worker_cls=WorkerExtension,
    )

    tokenizer = AutoTokenizer.from_pretrained("facebook/opt-125m")
    tokenizer.pad_token = tokenizer.eos_token
    tokenizer.padding_side = "left"

    policy = from_vllm(
        inference_model, tokenizer=tokenizer, from_text=False, generate=True, return_log_probs=True, generate_kwargs={"temperature": 0.0})
    return policy


def make_env(dataset, batch_size):
    dataset = load_dataset(dataset, "main")
    train_dataset = dataset["train"]
    tokenizer = AutoTokenizer.from_pretrained("facebook/opt-125m")
    tokenizer.pad_token = tokenizer.eos_token
    tokenizer.padding_side = "left"

    # Env
    dataloader = DataLoader(  # noqa: TOR401
        train_dataset, batch_size=batch_size, shuffle=True, collate_fn=collate_fn
    )
    env = LLMEnv.from_dataloader(
        dataloader=dataloader,
        tokenizer=tokenizer,
        str2str=True,
        batch_size=(args.batch_size * args.repeats,),
        repeats=args.repeats,
    )
    return env


def collate_fn(batch):
    batch = torch.stack([TensorDict.from_dict(_batch) for _batch in batch])
    batch.rename_key_("question", "text")
    return batch


@ray.remote(num_cpus=1, num_gpus=1)
class TrainerActor:
    def __init__(self, model, env_vars):
        import os
        import torch
        import torch.distributed
        from torch.distributed._composable.fsdp import fully_shard

        torch.cuda.set_device(torch.device('cuda', 0))

        for var in env_vars:
            os.environ[var] = str(env_vars[var])

        if not torch.distributed.is_initialized():
            torch.distributed.init_process_group(backend="nccl", device_id=torch.device('cuda:0'))
            print("initialized process group")

        world_size = torch.distributed.get_world_size()
        rank = torch.distributed.get_rank()
        print(world_size, rank)
        self.rank = int(os.environ["RANK"])
        self.world_size = int(os.environ["WORLD_SIZE"])

        # hold back one rank for the parameter server
        self.fsdp_group = torch.distributed.new_group(ranks=list(range(self.world_size - 1)))
        self.device_mesh = torch.distributed.device_mesh.DeviceMesh.from_group(self.fsdp_group, device_type="cuda")

        self.model = AutoModel.from_pretrained(model).cuda()

        fully_shard(self.model, mesh=self.device_mesh)

    def register_parameter_server(self, param_server):
        assert self.rank == 0
        self.param_server = param_server

    def send_weights_to_param_server(self):
        if self.rank == 0:
            ray.get(self.param_server.acquire_state_dict_lock.remote())
            self.param_server.receive_from_trainer.remote()
        for k, v in self.model.state_dict().items():
            replicated_v = v.full_tensor()
            if self.rank == 0:
                # dst is global rank, can switch to group_dst arg if not 2.5.1
                torch.distributed.send(replicated_v, dst=2)
        if self.rank == 0:
            ray.get(self.param_server.release_state_dict_lock.remote())

    def zero_(self):
        sd = self.model.state_dict()
        for k, v in sd.items():
            sd[k] = v.data.zero_()

    def train(self):
        import time
        for _ in range(1):
            # actually run train loop
            # ...
            self.zero_()
            torch.distributed.barrier(group=self.fsdp_group)
            self.send_weights_to_param_server()
            torch.distributed.barrier(group=self.fsdp_group)


@ray.remote(num_cpus=1, num_gpus=1)
class vLLMParameterServer(vLLMRemoteWeightUpdaterBase):
    def __init__(self, model, vllm_master_address, vllm_master_port, env_vars):
        super().__init__(model, vllm_master_address, vllm_master_port)
        import os
        import torch
        import torch.distributed

        torch.cuda.set_device(torch.device('cuda', 0))

        for var in env_vars:
            os.environ[var] = str(env_vars[var])

        if not torch.distributed.is_initialized():
            torch.distributed.init_process_group(backend="nccl", device_id=torch.device('cuda:0'))

        self.rank = int(os.environ["RANK"])
        self.world_size = int(os.environ["WORLD_SIZE"])
        assert self.rank == self.world_size - 1

        self.fsdp_group = torch.distributed.new_group(ranks=list(range(self.world_size - 1)))

    def receive_from_trainer(self):
        for k, v in self.state_dict.items():
            torch.distributed.recv(v, src=0)

    def _skip_update(self, worker_id: int) -> bool:
        pass

    def check_weights_changed(self):
        """
        Check if the weights are updated to 0.
        """
        weights_updated = True
        for name, p in self.state_dict.items():
            weights_updated = weights_updated and torch.allclose(
                p, torch.zeros_like(p))
        return weights_updated


def _create_trainer_group(
    worker_cls,
    param_server_cls,
    world_size: int,
    vllm_master_address,
    vllm_master_port,
    model,
):
    addr, port = get_ip(), get_open_port()
    trainer_workers = []
    fsdp_world_size = world_size - 1
    for i in range(fsdp_world_size):
        env_vars = {
            "RANK": str(i),
            "WORLD_SIZE": world_size,
            "MASTER_ADDR": str(addr),
            "MASTER_PORT": str(port),
        }
        worker = worker_cls.remote(model, env_vars)
        trainer_workers.append(worker)

    env_vars = {
        "RANK": str(world_size - 1),
        "WORLD_SIZE": world_size,
        "MASTER_ADDR": str(addr),
        "MASTER_PORT": str(port),
    }
    parameter_server = param_server_cls.remote(model, vllm_master_address, vllm_master_port, env_vars)
    trainer_workers[0].register_parameter_server.remote(parameter_server)
    return trainer_workers, parameter_server


if __name__ == "__main__":
    args = parser.parse_args()

    remote_configs = {
        "num_cpus": 1,
        "num_gpus": 1,
        "memory": 2 * 1024**3,
    }

    model = "facebook/opt-125m"

    ray.init(num_cpus=4, num_gpus=4)

    vllm_master_address, vllm_update_port = get_ip(), get_open_port()

    trainer_workers, parameter_server = _create_trainer_group(
        TrainerActor,
        vLLMParameterServer,
        3,
        vllm_master_address,
        vllm_update_port,
        model,
    )

    handles = []
    for trainer_worker in trainer_workers:
        handles.append(trainer_worker.train.remote())

    model_metadata = ray.get(parameter_server.get_model_metadata.remote())
    local_weight_updater = vLLMHFLocalWeightUpdater(vllm_master_address, vllm_update_port, model_metadata)

    make_env_parsed = partial(make_env, batch_size=args.batch_size, dataset=args.dataset)
    collector = RayCollector(
        [make_env_parsed],
        policy_factory=make_policy,
        frames_per_batch=40,
        total_frames=200,
        remote_configs=remote_configs,
        remote_weight_updater=parameter_server,
        collector_kwargs={
            "local_weight_updater": local_weight_updater,
        },
        update_after_each_batch=True,
    )
    print("done collector init")

    tokenizer = AutoTokenizer.from_pretrained("facebook/opt-125m")

    for i, data in enumerate(collector):
        print(tokenizer.decode(data["tokens"][0].squeeze()))
        print(tokenizer.decode(data["tokens_response"][0].squeeze()))
        if i == 1:
            break
    collector.shutdown()
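
The script above assumes a fixed rank layout: every rank except the last joins the FSDP training group, and the last rank is held back for the vLLM parameter server. With world_size=3 as in __main__, that is why rank 0 gathers each shard with .full_tensor() and sends it to dst=2. A tiny illustrative helper (not part of the commit; rank_layout is a hypothetical name) that spells the layout out:

def rank_layout(world_size: int) -> dict:
    # Ranks 0 .. world_size-2 hold FSDP shards of the trainer model.
    trainer_ranks = list(range(world_size - 1))
    # The final rank is held back from the FSDP group and runs the parameter server.
    param_server_rank = world_size - 1
    return {"trainers": trainer_ranks, "param_server": param_server_rank}


assert rank_layout(3) == {"trainers": [0, 1], "param_server": 2}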

torchrl/collectors/collectors.py

+20 -5
@@ -76,6 +76,15 @@ def cudagraph_mark_step_begin():
         """Placeholder for missing cudagraph_mark_step_begin method."""
         raise NotImplementedError("cudagraph_mark_step_begin not implemented.")
 
+try:
+    import ray
+    from ray.actor import ActorHandle
+
+    _has_ray = True
+except ImportError as err:
+    _has_ray = False
+    RAY_ERR = err
+
 
 _TIMEOUT = 1.0
 INSTANTIATE_TIMEOUT = 20
@@ -174,9 +183,12 @@ def remote_weight_updater(self) -> RemoteWeightUpdaterBase:
     @remote_weight_updater.setter
     def remote_weight_updater(self, value: RemoteWeightUpdaterBase | None):
         if value is not None:
-            value.register_collector(self)
-            if value.collector is not self:
-                raise RuntimeError("Failed to register collector.")
+            if _has_ray and isinstance(value, ray.actor.ActorHandle):
+                value.register_collector.remote(self)
+            else:
+                value.register_collector(self)
+                if value.collector is not self:
+                    raise RuntimeError("Failed to register collector.")
         self._remote_weight_updater = value
 
     def _get_policy_and_device(
@@ -308,7 +320,10 @@ def update_policy_weights_(
         if self.local_weight_updater is not None:
             self.local_weight_updater(policy_weights, **kwargs)
         if self.remote_weight_updater is not None:
-            self.remote_weight_updater(policy_weights, worker_ids=worker_ids, **kwargs)
+            if _has_ray and isinstance(self.remote_weight_updater, ray.actor.ActorHandle):
+                ray.get(self.remote_weight_updater.__call__.remote(policy_weights, worker_ids=worker_ids, **kwargs))
+            else:
+                self.remote_weight_updater(policy_weights, worker_ids=worker_ids, **kwargs)
         elif worker_ids is not None:
             raise TypeError("worker_ids was passed but remote_weight_updater was None.")
 
@@ -860,7 +875,7 @@ def __init__(
 
         self.local_weight_updater = local_weight_updater
         self.remote_weight_updater = remote_weight_updater
-
+
     @property
     def _traj_pool(self):
        pool = getattr(self, "_traj_pool_val", None)

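The collectors.py change above lets a collector's remote_weight_updater be either a local updater object or a Ray actor handle, in which case its methods have to be invoked through .remote() (with ray.get() when the caller needs the result). A rough, self-contained illustration of that dispatch, not the torchrl API (Updater and RemoteUpdater are stand-in classes):

import ray


class Updater:
    def register_collector(self, collector):
        self.collector = collector


@ray.remote
class RemoteUpdater(Updater):
    pass


def register(updater, collector):
    if isinstance(updater, ray.actor.ActorHandle):
        # Remote actor: fire-and-forget call; there is no local .collector to inspect.
        updater.register_collector.remote(collector)
    else:
        # Local object: call directly and verify the registration.
        updater.register_collector(collector)
        assert updater.collector is collector


if __name__ == "__main__":
    ray.init()
    register(Updater(), collector="local-collector")
    register(RemoteUpdater.remote(), collector="remote-collector")
    ray.shutdown()
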
torchrl/collectors/distributed/ray.py

+3 -1
@@ -646,6 +646,7 @@ def stop_remote_collectors(self):
             )  # This will interrupt any running tasks on the actor, causing them to fail immediately
 
     def iterator(self):
+        print(f"{self._sync=}")
         def proc(data):
             if self.split_trajs:
                 data = split_trajectories(data)
@@ -759,8 +760,9 @@ def _async_iterator(self) -> Iterator[TensorDictBase]:
                 yield out_td
 
             if self.update_after_each_batch or self.max_weight_update_interval > -1:
-                self.update_policy_weights_(worker_ids=collector_index + 1)
+                self.update_policy_weights_(worker_ids=collector_index)
 
+            print("done updating policy weights")
             # Schedule a new collection task
             future = collector.next.remote()
             pending_tasks[future] = collector_index
