Skip to content

Commit db5557d

Browse files
authored
Call remote LLM Engine for model_setup and timer_summary. (#191)
1 parent 55fdfb3 commit db5557d

File tree

1 file changed

+14
-12
lines changed

1 file changed

+14
-12
lines changed

chatlearn/runtime/dist_actor.py

Lines changed: 14 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -246,28 +246,30 @@ def create_engine_actor(self, num_gpus, placement_group, group_index):
246246
self.vllm_engine = self._create_actor(self.model.__class__, num_gpus, placement_group, group_index)
247247
self.model.engine = self.vllm_engine
248248

249+
def call_vllm_engine_remote_funcs(self, func_name, *args, **kwargs):
250+
"""
251+
Call remote functions for vllm_engine.
252+
"""
253+
results = []
254+
res = self.call_actor_remote_func(self.vllm_engine, func_name, *args, **kwargs)
255+
results.append(res)
256+
return results
257+
249258
def add_remote_func(self):
250259
for func_name, _ in inspect.getmembers(self.master):
251260
# ray.actor.ActorMethod
252-
if func_name.startswith('_') or func_name in ["timer_summary", "peak_memory", "model_setup"]:
261+
if func_name.startswith('_') or func_name in ["peak_memory"]:
253262
continue
254-
dist_call = partial(self.call_remote_funcs, func_name)
263+
if func_name in ["timer_summary", "model_setup"]:
264+
dist_call = partial(self.call_vllm_engine_remote_funcs, func_name)
265+
else: # needed to check for other call_funs.
266+
dist_call = partial(self.call_remote_funcs, func_name)
255267
setattr(self, func_name, dist_call)
256268

257-
def model_setup(self):
258-
return [self.vllm_engine.model_setup.remote()]
259-
260269
@property
261270
def master(self):
262271
return self.vllm_engine
263272

264-
def timer_summary(self, e2e_cost=None):
265-
"""
266-
:meta private:
267-
"""
268-
if self.model._timers:
269-
return self.model._timers.log(e2e_cost=e2e_cost)
270-
271273
def peak_memory(self):
272274
return self.model.peak_memory()
273275

0 commit comments

Comments
 (0)