Skip to content

Commit

Permalink
Call remote LLM Engine for model_setup and timer_summary. (#191)
Browse files Browse the repository at this point in the history
  • Loading branch information
adoda authored Dec 25, 2024
1 parent 55fdfb3 commit db5557d
Showing 1 changed file with 14 additions and 12 deletions.
26 changes: 14 additions & 12 deletions chatlearn/runtime/dist_actor.py
Original file line number Diff line number Diff line change
Expand Up @@ -246,28 +246,30 @@ def create_engine_actor(self, num_gpus, placement_group, group_index):
self.vllm_engine = self._create_actor(self.model.__class__, num_gpus, placement_group, group_index)
self.model.engine = self.vllm_engine

def call_vllm_engine_remote_funcs(self, func_name, *args, **kwargs):
"""
Call remote functions for vllm_engine.
"""
results = []
res = self.call_actor_remote_func(self.vllm_engine, func_name, *args, **kwargs)
results.append(res)
return results

def add_remote_func(self):
for func_name, _ in inspect.getmembers(self.master):
# ray.actor.ActorMethod
if func_name.startswith('_') or func_name in ["timer_summary", "peak_memory", "model_setup"]:
if func_name.startswith('_') or func_name in ["peak_memory"]:
continue
dist_call = partial(self.call_remote_funcs, func_name)
if func_name in ["timer_summary", "model_setup"]:
dist_call = partial(self.call_vllm_engine_remote_funcs, func_name)
else: # needed to check for other call_funs.
dist_call = partial(self.call_remote_funcs, func_name)
setattr(self, func_name, dist_call)

def model_setup(self):
return [self.vllm_engine.model_setup.remote()]

@property
def master(self):
return self.vllm_engine

def timer_summary(self, e2e_cost=None):
"""
:meta private:
"""
if self.model._timers:
return self.model._timers.log(e2e_cost=e2e_cost)

def peak_memory(self):
return self.model.peak_memory()

Expand Down

0 comments on commit db5557d

Please sign in to comment.