Support multi episode with vllm_v2. (#187)
* Support multi episode.
adoda authored Dec 24, 2024
1 parent bd75b69 commit 55fdfb3
Showing 2 changed files with 5 additions and 5 deletions.
chatlearn/schedule/model_manager.py (3 changes: 1 addition & 2 deletions)
@@ -190,8 +190,7 @@ def sync_parameters(self, episode_offset=0, requires_grad=None, validate=False):
     def set_func_decorator(self, model):
         if is_decorated(model.name):
             return
-        # decorate async method here will raise cannot serialize coroutine object error
-        call_funcs = model.call_funcs if not isinstance(model, VLLMModuleV2) else []
+        call_funcs = model.call_funcs
 
         model_cls = model.__class__
         for func_name in call_funcs:
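The comment removed above records why VLLMModuleV2 was previously excluded: decorating its async methods raised a "cannot serialize coroutine object" error. With multi-episode support for vllm_v2, the module's call_funcs are now decorated like any other model's. As a rough illustration only (the actual decorator chatlearn applies lies outside this hunk), wrapping each listed method on the model class might look like this:

# Hypothetical sketch: the real wrapper chatlearn installs is not shown in this diff.
import functools

def set_func_decorator(model):
    call_funcs = model.call_funcs  # after this commit, VLLMModuleV2 is no longer skipped

    model_cls = model.__class__
    for func_name in call_funcs:
        original = getattr(model_cls, func_name)

        @functools.wraps(original)
        def wrapped(self, *args, __original=original, **kwargs):
            # placeholder for whatever the decorator adds (timing, I/O handling, ...)
            return __original(self, *args, **kwargs)

        # rebind the wrapped method on the class so all instances pick it up
        setattr(model_cls, func_name, wrapped)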
examples/megatron/models/vllm_policy_inference.py (7 changes: 4 additions & 3 deletions)
@@ -15,20 +15,21 @@
"""vllm policy inference"""

import copy
import os
import random

import torch
import torch.nn.functional as F

from chatlearn.models.vllm import is_vllm_v2
from examples.megatron.data.prompt_dataset import VLLMPromptPipeline
from .utils import get_loss_mask

if os.environ.get("ENABLE_VLLM_V2"):
# pylint: disable=ungrouped-imports
if is_vllm_v2():
from chatlearn import VLLMModuleV2 as VLLMModule
else:
from chatlearn import VLLMModule

# pylint: enable=ungrouped-imports


class VLLMPolicyInference(VLLMModule):
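The version switch no longer reads the ENABLE_VLLM_V2 environment variable inline; it calls is_vllm_v2() from chatlearn.models.vllm, so every call site shares one definition of "v2 is enabled". A minimal sketch of such a helper, assuming it still keys off the same environment variable the old inline check read (the real implementation is not part of this diff):

# Hypothetical sketch; the actual chatlearn.models.vllm.is_vllm_v2 is not shown in this commit.
import os

def is_vllm_v2() -> bool:
    """Return True when the vLLM v2 code path should be used."""
    # mirrors the truthiness test the old inline check performed
    return bool(os.environ.get("ENABLE_VLLM_V2"))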
