import ray

from argparse import ArgumentParser
from functools import partial

import torch
from datasets import load_dataset
from tensordict import TensorDict
from torch.utils.data import DataLoader
from torchrl.collectors.weight_update import RayRemoteWeightUpdater
from transformers import AutoTokenizer, AutoModel
from vllm import LLM

from vllm.utils import get_ip, get_open_port

from torchrl.collectors.distributed import RayCollector
from torchrl.envs import LLMEnv
from torchrl.modules import from_vllm

from torchrl.collectors.vllm_weight_update import vLLMHFLocalWeightUpdater, vLLMRemoteWeightUpdaterBase, WorkerExtension

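# Overview of the moving pieces in this example:
#   * TrainerActor (ranks 0..N-2): holds an FSDP-sharded copy of the model and
#     runs the (placeholder) training loop.
#   * vLLMParameterServer (last rank): receives the full weights from trainer
#     rank 0 and makes them available to the vLLM inference workers.
#   * RayCollector: drives LLMEnv + the vLLM policy and triggers a weight
#     update after each collected batch.
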
parser = ArgumentParser()
parser.add_argument("--dataset", type=str, default="gsm8k")
parser.add_argument("--batch_size", type=int, default=4)
parser.add_argument("--epochs", type=int, default=10)
parser.add_argument("--repeats", type=int, default=10)
parser.add_argument("--steps_per_batch", type=int, default=16)
parser.add_argument("--optim_batch_size", type=int, default=4)


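# make_policy builds the vLLM engine for "facebook/opt-125m" and wraps it as a
# TorchRL policy with `from_vllm`, using greedy decoding (temperature 0.0) and
# returning log-probs alongside the generated tokens.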
def make_policy():
    inference_model = LLM(
        "facebook/opt-125m",
        enforce_eager=True,
        # change to worker_extension_cls when available in stable release
        worker_cls=WorkerExtension,
    )

    tokenizer = AutoTokenizer.from_pretrained("facebook/opt-125m")
    tokenizer.pad_token = tokenizer.eos_token
    tokenizer.padding_side = "left"

    policy = from_vllm(
        inference_model,
        tokenizer=tokenizer,
        from_text=False,
        generate=True,
        return_log_probs=True,
        generate_kwargs={"temperature": 0.0},
    )
    return policy


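# make_env streams dataset questions through a DataLoader into an LLMEnv; each
# prompt is repeated `repeats` times so several rollouts are collected per
# question.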
def make_env(dataset, batch_size, repeats):
    dataset = load_dataset(dataset, "main")
    train_dataset = dataset["train"]
    tokenizer = AutoTokenizer.from_pretrained("facebook/opt-125m")
    tokenizer.pad_token = tokenizer.eos_token
    tokenizer.padding_side = "left"

    # Env
    dataloader = DataLoader(  # noqa: TOR401
        train_dataset, batch_size=batch_size, shuffle=True, collate_fn=collate_fn
    )
    env = LLMEnv.from_dataloader(
        dataloader=dataloader,
        tokenizer=tokenizer,
        str2str=True,
        batch_size=(batch_size * repeats,),
        repeats=repeats,
    )
    return env


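# collate_fn stacks the raw dataset rows into a TensorDict and renames the
# dataset's "question" column to the "text" key that LLMEnv reads prompts from.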
def collate_fn(batch):
    batch = torch.stack([TensorDict.from_dict(_batch) for _batch in batch])
    batch.rename_key_("question", "text")
    return batch

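# Each TrainerActor owns one GPU, joins the shared torch.distributed world, and
# shards the Hugging Face model with FSDP (`fully_shard`) across every rank
# except the last one, which is reserved for the parameter server.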
@ray.remote(num_cpus=1, num_gpus=1)
class TrainerActor:
    def __init__(self, model, env_vars):
        import os
        import torch
        import torch.distributed
        from torch.distributed._composable.fsdp import fully_shard

        torch.cuda.set_device(torch.device("cuda", 0))

        for var in env_vars:
            os.environ[var] = str(env_vars[var])

        if not torch.distributed.is_initialized():
            torch.distributed.init_process_group(backend="nccl", device_id=torch.device("cuda:0"))
            print("initialized process group")

        world_size = torch.distributed.get_world_size()
        rank = torch.distributed.get_rank()
        print(world_size, rank)
        self.rank = int(os.environ["RANK"])
        self.world_size = int(os.environ["WORLD_SIZE"])

        # hold back one rank for the parameter server
        self.fsdp_group = torch.distributed.new_group(ranks=list(range(self.world_size - 1)))
        self.device_mesh = torch.distributed.device_mesh.DeviceMesh.from_group(self.fsdp_group, device_type="cuda")

        self.model = AutoModel.from_pretrained(model).cuda()

        fully_shard(self.model, mesh=self.device_mesh)

    def register_parameter_server(self, param_server):
        assert self.rank == 0
        self.param_server = param_server

    def send_weights_to_param_server(self):
        if self.rank == 0:
            ray.get(self.param_server.acquire_state_dict_lock.remote())
            self.param_server.receive_from_trainer.remote()
        for k, v in self.model.state_dict().items():
            # full_tensor() is a collective over the FSDP group, so every rank
            # participates even though only rank 0 sends the result on
            replicated_v = v.full_tensor()
            if self.rank == 0:
                # dst is global rank, can switch to group_dst arg if not 2.5.1
                torch.distributed.send(replicated_v, dst=2)
        if self.rank == 0:
            ray.get(self.param_server.release_state_dict_lock.remote())

    def zero_(self):
        # stand-in for a real optimizer step: zero every parameter so the
        # weight sync is easy to verify on the parameter-server side
        sd = self.model.state_dict()
        for k, v in sd.items():
            sd[k] = v.data.zero_()

    def train(self):
        for _ in range(1):
            # actually run train loop
            # ...
            self.zero_()
            torch.distributed.barrier(group=self.fsdp_group)
            self.send_weights_to_param_server()
            torch.distributed.barrier(group=self.fsdp_group)


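# The parameter server occupies the last rank of the trainer world. It receives
# the full (unsharded) state dict from trainer rank 0 over torch.distributed;
# as a vLLMRemoteWeightUpdaterBase it is what the collector uses as its
# remote_weight_updater to push those weights on to the vLLM workers.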
@ray.remote(num_cpus=1, num_gpus=1)
class vLLMParameterServer(vLLMRemoteWeightUpdaterBase):
    def __init__(self, model, vllm_master_address, vllm_master_port, env_vars):
        super().__init__(model, vllm_master_address, vllm_master_port)
        import os
        import torch
        import torch.distributed

        torch.cuda.set_device(torch.device("cuda", 0))

        for var in env_vars:
            os.environ[var] = str(env_vars[var])

        if not torch.distributed.is_initialized():
            torch.distributed.init_process_group(backend="nccl", device_id=torch.device("cuda:0"))

        self.rank = int(os.environ["RANK"])
        self.world_size = int(os.environ["WORLD_SIZE"])
        assert self.rank == self.world_size - 1

        self.fsdp_group = torch.distributed.new_group(ranks=list(range(self.world_size - 1)))

    def receive_from_trainer(self):
        # matching recv for the send() issued by trainer rank 0
        for k, v in self.state_dict.items():
            torch.distributed.recv(v, src=0)

    def _skip_update(self, worker_id: int) -> bool:
        # never skip: every vLLM worker should receive the new weights
        return False

    def check_weights_changed(self):
        """Check whether the weights have been updated to zero."""
        weights_updated = True
        for name, p in self.state_dict.items():
            weights_updated = weights_updated and torch.allclose(
                p, torch.zeros_like(p))
        return weights_updated


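# _create_trainer_group launches world_size - 1 TrainerActor ranks plus one
# parameter-server rank, all sharing the same MASTER_ADDR/MASTER_PORT so they
# join a single torch.distributed world; trainer rank 0 gets a handle to the
# parameter server so it can push weights to it.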
def _create_trainer_group(
    worker_cls,
    param_server_cls,
    world_size: int,
    vllm_master_address,
    vllm_master_port,
    model,
):
    addr, port = get_ip(), get_open_port()
    trainer_workers = []
    fsdp_world_size = world_size - 1
    for i in range(fsdp_world_size):
        env_vars = {
            "RANK": str(i),
            "WORLD_SIZE": str(world_size),
            "MASTER_ADDR": str(addr),
            "MASTER_PORT": str(port),
        }
        worker = worker_cls.remote(model, env_vars)
        trainer_workers.append(worker)

    env_vars = {
        "RANK": str(world_size - 1),
        "WORLD_SIZE": str(world_size),
        "MASTER_ADDR": str(addr),
        "MASTER_PORT": str(port),
    }
    parameter_server = param_server_cls.remote(model, vllm_master_address, vllm_master_port, env_vars)
    trainer_workers[0].register_parameter_server.remote(parameter_server)
    return trainer_workers, parameter_server


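# Driver: with world_size=3 there are 2 FSDP trainer ranks plus 1 parameter
# server, and the collector adds one more GPU worker for vLLM inference, so the
# script assumes a 4-GPU machine (see ray.init below). Run it directly, e.g.
# `python <this_script>.py --dataset gsm8k --batch_size 4`.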
if __name__ == "__main__":
    args = parser.parse_args()

    remote_configs = {
        "num_cpus": 1,
        "num_gpus": 1,
        "memory": 2 * 1024**3,
    }

    model = "facebook/opt-125m"

    ray.init(num_cpus=4, num_gpus=4)

    vllm_master_address, vllm_update_port = get_ip(), get_open_port()

    trainer_workers, parameter_server = _create_trainer_group(
        TrainerActor,
        vLLMParameterServer,
        3,
        vllm_master_address,
        vllm_update_port,
        model,
    )

    handles = []
    for trainer_worker in trainer_workers:
        handles.append(trainer_worker.train.remote())

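    # The local weight updater lives next to the vLLM policy inside the remote
    # collector and applies the weights published by the parameter server; the
    # model metadata fetched here presumably describes the state-dict layout the
    # updater should expect.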
    model_metadata = ray.get(parameter_server.get_model_metadata.remote())
    local_weight_updater = vLLMHFLocalWeightUpdater(vllm_master_address, vllm_update_port, model_metadata)

    make_env_parsed = partial(make_env, dataset=args.dataset, batch_size=args.batch_size, repeats=args.repeats)
    collector = RayCollector(
        [make_env_parsed],
        policy_factory=make_policy,
        frames_per_batch=40,
        total_frames=200,
        remote_configs=remote_configs,
        remote_weight_updater=parameter_server,
        collector_kwargs={
            "local_weight_updater": local_weight_updater,
        },
        update_after_each_batch=True,
    )
    print("done collector init")

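    # Sanity check: decode the first prompt and generated response of each
    # collected batch, then stop after two batches.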
    tokenizer = AutoTokenizer.from_pretrained("facebook/opt-125m")

    for i, data in enumerate(collector):
        print(tokenizer.decode(data["tokens"][0].squeeze()))
        print(tokenizer.decode(data["tokens_response"][0].squeeze()))
        if i == 1:
            break
    collector.shutdown()