
Commit d37f0bb

committed Mar 22, 2025
v0 param server (using collectives not object store)
ghstack-source-id: 74de8e0ef2fe059390e009332daeb688656a11fa
Pull Request resolved: #2865
1 parent 7df8317 commit d37f0bb

File tree

4 files changed: +462 -5 lines changed

param_server_weight_updater.py

+263

@@ -0,0 +1,263 @@
import ray

from argparse import ArgumentParser
from functools import partial

import torch
from datasets import load_dataset
from tensordict import TensorDict
from torch.utils.data import DataLoader
from torchrl.collectors.weight_update import RayRemoteWeightUpdater
from transformers import AutoTokenizer, AutoModel
from vllm import LLM

from vllm.utils import get_ip, get_open_port

from torchrl.collectors.distributed import RayCollector
from torchrl.envs import LLMEnv
from torchrl.modules import from_vllm

from torchrl.collectors.vllm_weight_update import vLLMHFLocalWeightUpdater, vLLMRemoteWeightUpdaterBase, WorkerExtension

parser = ArgumentParser()
parser.add_argument("--dataset", type=str, default="gsm8k")
parser.add_argument("--batch_size", type=int, default=4)
parser.add_argument("--epochs", type=int, default=10)
parser.add_argument("--repeats", type=int, default=10)
parser.add_argument("--steps_per_batch", type=int, default=16)
parser.add_argument("--optim_batch_size", type=int, default=4)


def make_policy():
    inference_model = LLM(
        "facebook/opt-125m",
        enforce_eager=True,
        # change to worker_extension_cls when available in stable release
        worker_cls=WorkerExtension,
    )

    tokenizer = AutoTokenizer.from_pretrained("facebook/opt-125m")
    tokenizer.pad_token = tokenizer.eos_token
    tokenizer.padding_side = "left"

    policy = from_vllm(
        inference_model, tokenizer=tokenizer, from_text=False, generate=True, return_log_probs=True, generate_kwargs={"temperature": 0.0})
    return policy


def make_env(dataset, batch_size):
    dataset = load_dataset(dataset, "main")
    train_dataset = dataset["train"]
    tokenizer = AutoTokenizer.from_pretrained("facebook/opt-125m")
    tokenizer.pad_token = tokenizer.eos_token
    tokenizer.padding_side = "left"

    # Env
    dataloader = DataLoader(  # noqa: TOR401
        train_dataset, batch_size=batch_size, shuffle=True, collate_fn=collate_fn
    )
    env = LLMEnv.from_dataloader(
        dataloader=dataloader,
        tokenizer=tokenizer,
        str2str=True,
        batch_size=(args.batch_size * args.repeats,),
        repeats=args.repeats, )
    return env


def collate_fn(batch):
    batch = torch.stack([TensorDict.from_dict(_batch) for _batch in batch])
    batch.rename_key_("question", "text")
    return batch


@ray.remote(num_cpus=1, num_gpus=1)
class TrainerActor:
    def __init__(self, model, env_vars):
        import os
        import torch
        import torch.distributed
        from torch.distributed._composable.fsdp import fully_shard

        torch.cuda.set_device(torch.device('cuda', 0))

        for var in env_vars:
            os.environ[var] = str(env_vars[var])

        if not torch.distributed.is_initialized():
            torch.distributed.init_process_group(backend="nccl", device_id=torch.device('cuda:0'))
            print("initialized process group")

        world_size = torch.distributed.get_world_size()
        rank = torch.distributed.get_rank()
        print(world_size, rank)
        self.rank = int(os.environ["RANK"])
        self.world_size = int(os.environ["WORLD_SIZE"])

        # hold back one rank for the parameter server
        self.fsdp_group = torch.distributed.new_group(ranks=list(range(self.world_size - 1)))
        self.device_mesh = torch.distributed.device_mesh.DeviceMesh.from_group(self.fsdp_group, device_type="cuda")

        self.model = AutoModel.from_pretrained(model).cuda()

        fully_shard(self.model, mesh=self.device_mesh)

    def register_parameter_server(self, param_server):
        assert self.rank == 0
        self.param_server = param_server

    def send_weights_to_param_server(self):
        if self.rank == 0:
            ray.get(self.param_server.acquire_state_dict_lock.remote())
            self.param_server.receive_from_trainer.remote()
        for k, v in self.model.state_dict().items():
            replicated_v = v.full_tensor()
            if self.rank == 0:
                # dst is global rank, can switch to group_dst arg if not 2.5.1
                torch.distributed.send(replicated_v, dst=2)
        if self.rank == 0:
            ray.get(self.param_server.release_state_dict_lock.remote())

    def zero_(self):
        sd = self.model.state_dict()
        for k, v in sd.items():
            sd[k] = v.data.zero_()

    def train(self):
        import time
        for _ in range(1):
            # actually run train loop
            # ...
            self.zero_()
            torch.distributed.barrier(group=self.fsdp_group)
            self.send_weights_to_param_server()
            torch.distributed.barrier(group=self.fsdp_group)


@ray.remote(num_cpus=1, num_gpus=1)
class vLLMParameterServer(vLLMRemoteWeightUpdaterBase):
    def __init__(self, model, vllm_master_address, vllm_master_port, env_vars):
        super().__init__(model, vllm_master_address, vllm_master_port)
        import os
        import torch
        import torch.distributed

        torch.cuda.set_device(torch.device('cuda', 0))

        for var in env_vars:
            os.environ[var] = str(env_vars[var])

        if not torch.distributed.is_initialized():
            torch.distributed.init_process_group(backend="nccl", device_id=torch.device('cuda:0'))

        self.rank = int(os.environ["RANK"])
        self.world_size = int(os.environ["WORLD_SIZE"])
        assert self.rank == self.world_size - 1

        self.fsdp_group = torch.distributed.new_group(ranks=list(range(self.world_size - 1)))

    def receive_from_trainer(self):
        for k, v in self.state_dict.items():
            torch.distributed.recv(v, src=0)

    def _skip_update(self, worker_id: int) -> bool:
        pass

    def check_weights_changed(self):
        """
        Check if the weights are updated to 0.
        """
        weights_updated = True
        for name, p in self.state_dict.items():
            weights_updated = weights_updated and torch.allclose(
                p, torch.zeros_like(p))
        return weights_updated


def _create_trainer_group(
    worker_cls,
    param_server_cls,
    world_size: int,
    vllm_master_address,
    vllm_master_port,
    model,
):
    addr, port = get_ip(), get_open_port()
    trainer_workers = []
    fsdp_world_size = world_size - 1
    for i in range(fsdp_world_size):
        env_vars = {
            "RANK": str(i),
            "WORLD_SIZE": world_size,
            "MASTER_ADDR": str(addr),
            "MASTER_PORT": str(port),
        }
        worker = worker_cls.remote(model, env_vars)
        trainer_workers.append(worker)

    env_vars = {
        "RANK": str(world_size - 1),
        "WORLD_SIZE": world_size,
        "MASTER_ADDR": str(addr),
        "MASTER_PORT": str(port),
    }
    parameter_server = param_server_cls.remote(model, vllm_master_address, vllm_master_port, env_vars)
    trainer_workers[0].register_parameter_server.remote(parameter_server)
    return trainer_workers, parameter_server


if __name__ == "__main__":
    args = parser.parse_args()

    remote_configs = {
        "num_cpus": 1,
        "num_gpus": 1,
        "memory": 2 * 1024**3,
    }

    model = "facebook/opt-125m"

    ray.init(num_cpus=4, num_gpus=4)

    vllm_master_address, vllm_update_port = get_ip(), get_open_port()

    trainer_workers, parameter_server = _create_trainer_group(
        TrainerActor,
        vLLMParameterServer,
        3,
        vllm_master_address,
        vllm_update_port,
        model,
    )

    handles = []
    for trainer_worker in trainer_workers:
        handles.append(trainer_worker.train.remote())

    model_metadata = ray.get(parameter_server.get_model_metadata.remote())
    local_weight_updater = vLLMHFLocalWeightUpdater(vllm_master_address, vllm_update_port, model_metadata)

    make_env_parsed = partial(make_env, batch_size=args.batch_size, dataset=args.dataset)
    collector = RayCollector(
        [make_env_parsed],
        policy_factory=make_policy,
        frames_per_batch=40,
        total_frames=200,
        remote_configs=remote_configs,
        remote_weight_updater=parameter_server,
        collector_kwargs={
            "local_weight_updater": local_weight_updater,
        },
        update_after_each_batch=True,
    )
    print("done collector init")

    tokenizer = AutoTokenizer.from_pretrained("facebook/opt-125m")

    for i, data in enumerate(collector):
        print(tokenizer.decode(data["tokens"][0].squeeze()))
        print(tokenizer.decode(data["tokens_response"][0].squeeze()))
        if i == 1:
            break
    collector.shutdown()
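
Note: the trainer-to-server handoff above moves weights over collectives rather than through Ray's object store: FSDP rank 0 materializes each full tensor with full_tensor() and sends it with torch.distributed.send, while the parameter server receives into its pre-allocated state dict. Below is a minimal, CPU-only sketch of that point-to-point pattern, not part of the commit: it uses two gloo processes instead of the NCCL/FSDP setup above, and the port and tensor are illustrative.

import torch
import torch.distributed as dist
import torch.multiprocessing as mp


def run(rank, world_size):
    # gloo so the sketch runs on CPU; the script above uses the nccl backend on GPUs
    dist.init_process_group(
        "gloo", init_method="tcp://127.0.0.1:29501", rank=rank, world_size=world_size
    )
    param = torch.zeros(4)
    if rank == 0:
        # "trainer" rank: fill the (already gathered) tensor and send it point-to-point
        param += 1.0
        dist.send(param, dst=1)
    else:
        # "parameter server" rank: receive in place into a pre-allocated buffer
        dist.recv(param, src=0)
        print("server received", param)
    dist.destroy_process_group()


if __name__ == "__main__":
    mp.spawn(run, args=(2,), nprocs=2)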

torchrl/collectors/collectors.py

+19 -4

@@ -76,6 +76,15 @@ def cudagraph_mark_step_begin():
         """Placeholder for missing cudagraph_mark_step_begin method."""
         raise NotImplementedError("cudagraph_mark_step_begin not implemented.")
 
+try:
+    import ray
+    from ray.actor import ActorHandle
+
+    _has_ray = True
+except ImportError as err:
+    _has_ray = False
+    RAY_ERR = err
+
 
 _TIMEOUT = 1.0
 INSTANTIATE_TIMEOUT = 20

@@ -174,9 +183,12 @@ def remote_weight_updater(self) -> RemoteWeightUpdaterBase:
     @remote_weight_updater.setter
     def remote_weight_updater(self, value: RemoteWeightUpdaterBase | None):
         if value is not None:
-            value.register_collector(self)
-            if value.collector is not self:
-                raise RuntimeError("Failed to register collector.")
+            if _has_ray and isinstance(value, ray.actor.ActorHandle):
+                value.register_collector.remote(self)
+            else:
+                value.register_collector(self)
+                if value.collector is not self:
+                    raise RuntimeError("Failed to register collector.")
         self._remote_weight_updater = value
 
     def _get_policy_and_device(

@@ -308,7 +320,10 @@ def update_policy_weights_(
         if self.local_weight_updater is not None:
             self.local_weight_updater(policy_weights, **kwargs)
         if self.remote_weight_updater is not None:
-            self.remote_weight_updater(policy_weights, worker_ids=worker_ids, **kwargs)
+            if _has_ray and isinstance(self.remote_weight_updater, ray.actor.ActorHandle):
+                ray.get(self.remote_weight_updater.__call__.remote(policy_weights, worker_ids=worker_ids, **kwargs))
+            else:
+                self.remote_weight_updater(policy_weights, worker_ids=worker_ids, **kwargs)
         elif worker_ids is not None:
             raise TypeError("worker_ids was passed but remote_weight_updater was None.")
 
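
Note: the new isinstance branches exist because remote_weight_updater may now be a Ray actor handle (e.g. the vLLMParameterServer above) rather than a local RemoteWeightUpdaterBase instance; methods on a handle must be invoked via .remote() and their results fetched with ray.get. A minimal sketch of that dispatch pattern, using a stand-in Updater class rather than the TorchRL classes:

import ray


class Updater:
    def register_collector(self, collector):
        # stand-in for RemoteWeightUpdaterBase.register_collector
        self.collector = collector
        return True


if __name__ == "__main__":
    ray.init(num_cpus=1)

    local_updater = Updater()
    actor_updater = ray.remote(Updater).remote()  # same class, instantiated as a Ray actor

    for updater in (local_updater, actor_updater):
        if isinstance(updater, ray.actor.ActorHandle):
            # actor handle: the call returns an ObjectRef and must be awaited
            ray.get(updater.register_collector.remote("collector"))
        else:
            # plain local object: call the method directly
            updater.register_collector("collector")

    ray.shutdown()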

torchrl/collectors/distributed/ray.py

+1 -1

@@ -759,7 +759,7 @@ def _async_iterator(self) -> Iterator[TensorDictBase]:
             yield out_td
 
             if self.update_after_each_batch or self.max_weight_update_interval > -1:
-                self.update_policy_weights_(worker_ids=collector_index + 1)
+                self.update_policy_weights_(worker_ids=collector_index)
 
             # Schedule a new collection task
             future = collector.next.remote()
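
Note: this is an off-by-one fix. Worker ids are 0-based indices into the collector's remote collectors (all_worker_ids in the weight updater below returns range(len(self.collector._remote_collectors))), so the batch that just arrived from collector_index should refresh that same worker, not collector_index + 1. A tiny illustration with made-up collector names:

remote_collectors = ["collector_a", "collector_b", "collector_c"]  # illustrative only
all_worker_ids = list(range(len(remote_collectors)))  # [0, 1, 2], 0-based

collector_index = 2  # batch came from the last remote collector
assert collector_index in all_worker_ids             # valid worker id
assert (collector_index + 1) not in all_worker_ids   # the old +1 indexing falls off the end
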
torchrl/collectors/vllm_weight_update.py (path inferred from the import in param_server_weight_updater.py)

+179

@@ -0,0 +1,179 @@
import torch
import threading

from torchrl.collectors.weight_update import RemoteWeightUpdaterBase
from torchrl.collectors.weight_update import LocalWeightUpdaterBase


VLLM_ERR = None
try:
    import vllm
    from vllm.worker.worker import Worker

    _has_vllm = True
except ImportError as err:
    _has_vllm = False
    VLLM_ERR = err

# These utilities are copied from vLLM's example code.
def stateless_init_process_group(
    master_address: str,
    master_port: int,
    rank: int,
    world_size: int,
    device: torch.device,
):
    """
    vLLM provides `StatelessProcessGroup` to create a process group
    without considering the global process group in torch.distributed.
    It is recommended to create `StatelessProcessGroup`, and then initialize
    the data-plane communication (NCCL) between external (train processes)
    and vLLM workers.
    """
    from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator
    from vllm.distributed.utils import StatelessProcessGroup

    pg = StatelessProcessGroup.create(
        host=master_address, port=master_port, rank=rank, world_size=world_size
    )
    pynccl = PyNcclCommunicator(pg, device=device)
    return pynccl


if _has_vllm:
    # I should use worker_extension_cls arg and not inherit from worker,
    # but that is only available on main and not vLLM 0.7.3
    class WorkerExtension(Worker):
        """
        The class for vLLM's worker to inherit from.
        By defining an extension class, the code can work no matter what is
        the underlying worker class. This way, the code can be compatible
        with both vLLM V0 and V1.
        NOTE: we define this class in a separate module, and the main module
        should pass the full qualified name as `worker_extension_cls` argument.
        """

        def init_weight_update_group(self, master_address, master_port,
                                     rank_offset, world_size):
            from vllm.distributed.parallel_state import get_world_group
            rank = get_world_group().rank + rank_offset
            self.model_update_group = stateless_init_process_group(
                master_address,
                master_port,
                rank,
                world_size,
                self.device,
            )

        def update_weight(self, name, dtype, shape):
            weight = torch.empty(shape, dtype=dtype, device="cuda")
            self.model_update_group.broadcast(weight,
                                              src=0,
                                              stream=torch.cuda.current_stream())

            self.model_runner.model.load_weights(weights=[(name, weight)])

            del weight

        def check_weights_changed(self):
            """
            Check if the weights are updated to 0.
            """
            weights_updated = True
            for name, p in self.model_runner.model.named_parameters():
                weights_updated = weights_updated and torch.allclose(
                    p, torch.zeros_like(p))
            return weights_updated
else:
    class WorkerExtension:
        pass


class vLLMHFLocalWeightUpdater(LocalWeightUpdaterBase):
    def __init__(self, master_address, master_port, model_metadata):
        self.master_address = master_address
        self.master_port = master_port
        self.model_metadata = model_metadata
        self.model_update_group = None

    def _get_server_weights(self):
        return None

    def _get_local_weights(self):
        # We don't implement this because we let vLLM's update_weights API handle everything
        return None

    def _maybe_map_weights(self, server_weights, local_weights):
        # vLLM update_weights function handles the mapping from huggingface
        # so we don't implement this
        return None

    def _update_local_weights(self, local_weights, mapped_weights):
        llm = self.collector.policy["generate"].module
        if self.model_update_group is None:
            # FIXME: hardcoded
            weight_sync_world_size = llm.llm_engine.parallel_config.tensor_parallel_size + 1
            llm.collective_rpc(
                "init_weight_update_group",
                args=(self.master_address, self.master_port, 1, weight_sync_world_size)
            )

        for k, (dtype, shape) in self.model_metadata.items():
            llm.collective_rpc(
                "update_weight",
                args=(k, dtype, shape)
            )

class vLLMRemoteWeightUpdaterBase(RemoteWeightUpdaterBase):
    def __init__(self, model, vllm_master_address, vllm_master_port):
        super().__init__()
        from transformers import AutoModel
        self.vllm_master_address = vllm_master_address
        self.vllm_master_port = vllm_master_port
        self.state_dict = AutoModel.from_pretrained(model).cuda().eval().state_dict()
        self.state_dict_lock = threading.Lock()
        self.vllm_comm_groups = dict()
        # versioning nyi
        self.version = 0

    def acquire_state_dict_lock(self):
        self.state_dict_lock.acquire()

    def release_state_dict_lock(self):
        self.state_dict_lock.release()

    def get_model_metadata(self):
        return {k: (v.dtype, v.shape) for k, v in self.state_dict.items()}

    def all_worker_ids(self):
        return [i for i in range(len(self.collector._remote_collectors))]

    def _get_server_weights(self):
        return self.state_dict

    def _maybe_map_weights(self, server_weights):
        return server_weights

    def _init_model_update_group(self, worker_id):
        # here again, I want to grab the tp size from the vLLM worker... :(
        # llm.llm_engine.parallel_config.tensor_parallel_size
        vllm_tp_size = 1
        weight_sync_world_size = vllm_tp_size + 1
        model_update_group = stateless_init_process_group(
            self.vllm_master_address,
            self.vllm_master_port,
            0,
            weight_sync_world_size,
            torch.device("cuda:0"),
        )
        self.vllm_comm_groups[worker_id] = model_update_group

    def _sync_weights_with_worker(
        self, worker_id: int, server_weights
    ):
        self.collector._remote_collectors[worker_id].update_policy_weights_.remote()
        if worker_id not in self.vllm_comm_groups:
            self._init_model_update_group(worker_id)
        with self.state_dict_lock:
            for i, k in enumerate(server_weights.keys()):
                self.vllm_comm_groups[worker_id].broadcast(server_weights[k], src=0, stream=torch.cuda.current_stream())
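
Note: the server-to-inference path is metadata-driven: get_model_metadata() publishes (dtype, shape) per parameter name, update_weight() on the vLLM worker pre-allocates an empty buffer from that metadata, and the server's broadcast from rank 0 fills it, so the weights themselves never pass through Ray. Below is a stripped-down, CPU-only sketch of that pre-allocate-then-broadcast handshake using plain torch.distributed (gloo backend; the parameter names, shapes, and port are made up for illustration), not the PyNcclCommunicator path used above.

import torch
import torch.distributed as dist
import torch.multiprocessing as mp

# metadata in the same form as get_model_metadata(): name -> (dtype, shape)
MODEL_METADATA = {
    "layer.weight": (torch.float32, (2, 3)),
    "layer.bias": (torch.float32, (3,)),
}


def run(rank, world_size):
    dist.init_process_group(
        "gloo", init_method="tcp://127.0.0.1:29502", rank=rank, world_size=world_size
    )
    for name, (dtype, shape) in MODEL_METADATA.items():
        if rank == 0:
            # "parameter server" side: broadcast the current value of this weight
            tensor = torch.ones(shape, dtype=dtype)
        else:
            # "inference worker" side: pre-allocate from metadata, then receive the broadcast
            tensor = torch.empty(shape, dtype=dtype)
        dist.broadcast(tensor, src=0)
        if rank != 0:
            print(f"rank {rank} received {name} with shape {tuple(tensor.shape)}")
    dist.destroy_process_group()


if __name__ == "__main__":
    mp.spawn(run, args=(2,), nprocs=2)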
