
Commit e04a1d9

committed
[Feature] vllm wrapper
ghstack-source-id: 967ca68
Pull Request resolved: #2830
1 parent c2081b2

14 files changed: +847 -73 lines
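
The commit adds a from_vllm wrapper (alongside the existing from_hf_transformers) that exposes a vLLM engine as a TensorDict module operating on LLMData. Below is a rough usage sketch inferred from the tests added in this commit (test/test_actors.py), not documentation; the model name simply mirrors the test:

from tensordict import NonTensorStack
from torchrl.data.llm import LLMData
from torchrl.modules import from_vllm
from vllm import LLM

# Wrap a vLLM engine so it reads and writes LLMData tensordicts.
model = LLM(model="facebook/opt-125m")
actor = from_vllm(model, from_text=True, generate=True, return_log_probs=True)

# Text in, generated text (plus per-token log-probs) out.
td = actor(LLMData(text=NonTensorStack("a text"), batch_size=1))
print(td.text_response, td.log_probs.shape)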

.github/unittest/linux_libs/scripts_rlhf/environment.yml → .github/unittest/linux_libs/scripts_llm/environment.yml

+2 -1
@@ -17,5 +17,6 @@ dependencies:
   - pyyaml
   - scipy
   - hydra-core
-  - transformers<4.42.0
+  - transformers
   - datasets
+  - vllm

.github/unittest/linux_libs/scripts_rlhf/install.sh → .github/unittest/linux_libs/scripts_llm/install.sh

+18 -17
@@ -26,23 +26,24 @@ fi
 # submodules
 git submodule sync && git submodule update --init --recursive

-printf "Installing PyTorch with cu128"
-if [[ "$TORCH_VERSION" == "nightly" ]]; then
-  if [ "${CU_VERSION:-}" == cpu ] ; then
-    pip3 install --pre torch "numpy<2.0.0" --index-url https://download.pytorch.org/whl/nightly/cpu -U
-  else
-    pip3 install --pre torch "numpy<2.0.0" --index-url https://download.pytorch.org/whl/nightly/cu128 -U
-  fi
-elif [[ "$TORCH_VERSION" == "stable" ]]; then
-  if [ "${CU_VERSION:-}" == cpu ] ; then
-    pip3 install torch "numpy<2.0.0" --index-url https://download.pytorch.org/whl/cpu
-  else
-    pip3 install torch "numpy<2.0.0" --index-url https://download.pytorch.org/whl/cu128
-  fi
-else
-  printf "Failed to install pytorch"
-  exit 1
-fi
+# We skip pytorch install due to vllm requirements
+#printf "Installing PyTorch with cu128"
+#if [[ "$TORCH_VERSION" == "nightly" ]]; then
+#  if [ "${CU_VERSION:-}" == cpu ] ; then
+#    pip3 install --pre torch "numpy<2.0.0" --index-url https://download.pytorch.org/whl/nightly/cpu -U
+#  else
+#    pip3 install --pre torch "numpy<2.0.0" --index-url https://download.pytorch.org/whl/nightly/cu128 -U
+#  fi
+#elif [[ "$TORCH_VERSION" == "stable" ]]; then
+#  if [ "${CU_VERSION:-}" == cpu ] ; then
+#    pip3 install torch "numpy<2.0.0" --index-url https://download.pytorch.org/whl/cpu
+#  else
+#    pip3 install torch "numpy<2.0.0" --index-url https://download.pytorch.org/whl/cu128
+#  fi
+#else
+#  printf "Failed to install pytorch"
+#  exit 1
+#fi

 # install tensordict
 if [[ "$RELEASE" == 0 ]]; then

.github/unittest/linux_libs/scripts_rlhf/run_test.sh → .github/unittest/linux_libs/scripts_llm/run_test.sh

+2
@@ -24,6 +24,8 @@ python -c "import transformers, datasets"

 python .github/unittest/helpers/coverage_run_parallel.py -m pytest test/test_rlhf.py --instafail -v --durations 200 --capture no --error-for-skips

+python .github/unittest/helpers/coverage_run_parallel.py -m pytest test/test_actors.py -k llm --instafail -v --durations 200 --capture no --error-for-skips --runslow
+
 python .github/unittest/helpers/coverage_run_parallel.py examples/rlhf/train_rlhf.py \
   sys.device=cuda:0 sys.ref_device=cuda:0 \
   model.name_or_path=gpt2 train.max_epochs=2 \

.github/workflows/test-linux-rlhf.yml → .github/workflows/test-linux-llm.yml

+5 -5
@@ -1,4 +1,4 @@
-name: RLHF Tests on Linux
+name: LLM Tests on Linux

 on:
   pull_request:
@@ -50,7 +50,7 @@ jobs:
         export TF_CPP_MIN_LOG_LEVEL=0
         export TD_GET_DEFAULTS_TO_NONE=1

-        bash .github/unittest/linux_libs/scripts_rlhf/setup_env.sh
-        bash .github/unittest/linux_libs/scripts_rlhf/install.sh
-        bash .github/unittest/linux_libs/scripts_rlhf/run_test.sh
-        bash .github/unittest/linux_libs/scripts_rlhf/post_process.sh
+        bash .github/unittest/linux_libs/scripts_llm/setup_env.sh
+        bash .github/unittest/linux_libs/scripts_llm/install.sh
+        bash .github/unittest/linux_libs/scripts_llm/run_test.sh
+        bash .github/unittest/linux_libs/scripts_llm/post_process.sh

test/test_actors.py

+232 -21
@@ -3,19 +3,27 @@
 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.
 import argparse
+import importlib.util
 import os

 import pytest
 import torch
-
 from tensordict import NonTensorStack, TensorDict
 from tensordict.nn import CompositeDistribution, TensorDictModule
 from tensordict.nn.distributions import NormalParamExtractor

 from torch import distributions as dist, nn
 from torchrl.data import Binary, Bounded, Categorical, Composite, MultiOneHot, OneHot
+from torchrl.data.llm import LLMData
 from torchrl.data.llm.dataset import _has_transformers
-from torchrl.modules import from_hf_transformers, MLP, SafeModule, TanhDelta, TanhNormal
+from torchrl.modules import (
+    from_hf_transformers,
+    from_vllm,
+    MLP,
+    SafeModule,
+    TanhDelta,
+    TanhNormal,
+)
 from torchrl.modules.tensordict_module.actors import (
     _process_action_space_spec,
     ActorValueOperator,
@@ -37,6 +45,8 @@
 from _utils_internal import get_default_devices
 from mocking_classes import NestedCountingEnv

+_has_vllm = importlib.util.find_spec("vllm") is not None
+

 @pytest.mark.parametrize(
     "log_prob_key",
@@ -908,52 +918,253 @@ def test_lmhead_actorvalueoperator(device):


 @pytest.mark.skipif(not _has_transformers, reason="missing transformers dependencies")
-class TestTransformerActor:
+@pytest.mark.skipif(not _has_vllm, reason="missing vllm dependencies")
+class TestLLMActor:
     @pytest.mark.parametrize(
-        "from_text, generate, tokens, attention_mask",
+        "from_text, generate, return_log_probs, tokens, attention_mask",
         [
-            (True, True, None, None),
-            (True, False, None, None),
+            (True, True, True, None, None),
+            (True, True, False, None, None),
+            (True, False, None, None, None),
+            (
+                False,
+                True,
+                True,
+                torch.randint(1024, (1, 10)),
+                torch.ones(1, 10, dtype=torch.int64),
+            ),
+            (False, True, True, torch.randint(1024, (1, 10)), None),
             (
                 False,
                 True,
+                False,
                 torch.randint(1024, (1, 10)),
                 torch.ones(1, 10, dtype=torch.int64),
             ),
-            (False, True, torch.randint(1024, (1, 10)), None),
+            (False, True, False, torch.randint(1024, (1, 10)), None),
         ],
     )
-    def test_from_hf_transformers(self, from_text, generate, tokens, attention_mask):
-        from torchrl.data.llm import LLMData
+    def test_from_hf_transformers(
+        self, from_text, generate, return_log_probs, tokens, attention_mask
+    ):
         from transformers import AutoTokenizer, GPT2Config, GPT2LMHeadModel

+        model_name = "distilbert-base-uncased"  # or "minilm" or "albert-tiny"
+        # Load the model and tokenizer
+        # model = AutoModel.from_pretrained(model_name)
+        # tokenizer = AutoTokenizer.from_pretrained(model_name)
+
         tokenizer = AutoTokenizer.from_pretrained("gpt2")
-        tokenizer.pad_token = tokenizer.eos_token
         model = GPT2LMHeadModel(GPT2Config())
+
+        tokenizer.pad_token = tokenizer.eos_token
         tokenizer.padding_side = "left"
+
         m = from_hf_transformers(
-            model, tokenizer=tokenizer, from_text=from_text, generate=generate
+            model,
+            tokenizer=tokenizer,
+            from_text=from_text,
+            generate=generate,
+            return_log_probs=return_log_probs,
+        )
+        self._run_check(
+            m,
+            tokens,
+            attention_mask,
+            generate,
+            return_log_probs,
+            from_text,
+            has_logits=True,
+        )
+
+    @pytest.mark.parametrize(
+        "from_text, generate, return_log_probs, tokens, attention_mask",
+        [
+            (True, True, True, None, None),
+            (True, True, False, None, None),
+            (True, False, None, None, None),
+            (
+                False,
+                True,
+                True,
+                torch.randint(1024, (1, 10)),
+                torch.ones(1, 10, dtype=torch.int64),
+            ),
+            (False, True, True, torch.randint(1024, (1, 10)), None),
+            (
+                False,
+                True,
+                False,
+                torch.randint(1024, (1, 10)),
+                torch.ones(1, 10, dtype=torch.int64),
+            ),
+            (False, True, False, torch.randint(1024, (1, 10)), None),
+        ],
+    )
+    def test_from_vllm(
+        self, from_text, generate, return_log_probs, tokens, attention_mask
+    ):
+        from vllm import LLM
+
+        model = LLM(model="facebook/opt-125m")
+        m = from_vllm(
+            model,
+            from_text=from_text,
+            generate=generate,
+            return_log_probs=return_log_probs,
+        )
+        self._run_check(
+            m,
+            tokens,
+            attention_mask,
+            generate,
+            return_log_probs,
+            from_text,
+            has_logits=False,
         )
+
+    def _make_data(
+        self,
+        m,
+        tokens,
+        attention_mask,
+        generate,
+        from_text,
+        has_logits,
+        text_response=None,
+        tokens_response=None,
+    ):
+        lp_kwargs = {}
         if from_text:
-            tdin = LLMData(text=NonTensorStack("a text"), batch_size=1)
+            if not generate:
+                text_response = (
+                    NonTensorStack(" and another text that follows")
+                    if text_response is None
+                    else text_response
+                )
+                if not isinstance(text_response, NonTensorStack):
+                    if isinstance(text_response, list):
+                        text_response = NonTensorStack(*text_response)
+                    else:
+                        text_response = NonTensorStack(text_response)
+                lp_kwargs.update({"text_response": text_response})
+            tdin = LLMData(text=NonTensorStack("a text"), **lp_kwargs, batch_size=1)
         else:
-            tdin = LLMData(tokens=tokens, attention_mask=attention_mask, batch_size=1)
+            if not generate:
+                if tokens_response is None:
+                    shape_response = tokens.shape
+                    shape_response = shape_response[:-1] + (shape_response[-1] * 2,)
+                    tokens_response = torch.randint(1024, shape_response)
+                lp_kwargs.update({"tokens_response": tokens_response})
+            tdin = LLMData(
+                tokens=tokens, attention_mask=attention_mask, **lp_kwargs, batch_size=1
+            )
+        return tdin
+
+    def _run_check(
+        self,
+        m,
+        tokens,
+        attention_mask,
+        generate,
+        return_log_probs,
+        from_text,
+        has_logits,
+    ):
+        tdin = self._make_data(
+            m, tokens, attention_mask, generate, from_text, has_logits
+        )
+        if from_text and generate:
+            assert tdin.text_response is None
+        elif from_text and not generate:
+            assert tdin.text_response is not None
+
         td = m(tdin)
         assert td is tdin
         assert isinstance(td, LLMData)
         if from_text and generate:
             assert td.text_response is not None
+        if generate and (attention_mask is not None or from_text):
+            assert td.attention_mask is not None, (generate, generate, from_text)
         else:
-            assert td.text_response is None
-        if attention_mask is not None or from_text:
-            assert td.attention_mask is not None
-        else:
-            assert td.attention_mask is None
+            assert td.attention_mask is None, (generate, from_text)
         if not generate:
-            assert td.text_response is None
-            assert td.tokens_response is None
+            # logprobs are computed on text response of tokens_response
+            assert td.text_response is not None or td.tokens_response is not None
             assert td.log_probs is not None
-            assert td.logits is not None
+            if has_logits:
+                assert td.logits is not None
+        if generate:
+            if return_log_probs:
+                assert td.log_probs is not None
+                assert td.log_probs.shape[-2] == td.tokens_response.shape[-1]
+            else:
+                assert td.log_probs is None
+
+        # Test the shapes
+        assert td.tokens_response is not None, (generate, has_logits, from_text)
+
+        # If from text and not generating, the tokens are not returned for now
+        if not (from_text and not generate):
+            assert td.tokens_response.shape[:-1] == td.tokens.shape[:-1]
+            # The convention is that the response only has new tokens
+            assert (
+                td.tokens_response[..., : td.tokens.shape[-1]]
+                != td.tokens[..., : td.tokens_response.shape[-1]]
+            ).any(), (generate, from_text)
+
+    @pytest.mark.parametrize(
+        "from_text, tokens, attention_mask",
+        [
+            (True, None, None),
+            (
+                False,
+                torch.randint(1024, (1, 10)),
+                torch.ones(1, 10, dtype=torch.int64),
+            ),
+            (False, torch.randint(1024, (1, 10)), None),
+        ],
+    )
+    def test_from_vllm_logprobs(self, from_text, tokens, attention_mask):
+        from vllm import LLM
+
+        model = LLM(model="facebook/opt-125m")
+        m_generate = from_vllm(
+            model, from_text=from_text, generate=True, return_log_probs=True
+        )
+        m_logprobs = from_vllm(model, from_text=from_text, generate=False)
+        self._check_lps(
+            m_generate, m_logprobs, tokens, attention_mask, from_text, has_logits=False
+        )
+
+    def _check_lps(
+        self,
+        model_generate,
+        model_logprobs,
+        tokens,
+        attention_mask,
+        from_text,
+        has_logits,
+    ):
+        # Checks that the log-probs gathered with generate=False equate those with generate=True
+        tdin_genetate = self._make_data(
+            model_generate, tokens, attention_mask, True, from_text, has_logits
+        )
+        td_generate = model_generate(tdin_genetate)
+        tdin_logprobs = self._make_data(
+            model_logprobs,
+            tokens,
+            attention_mask,
+            False,
+            from_text,
+            has_logits,
+            tokens_response=td_generate.tokens_response,
+            text_response=td_generate.text_response,
+        )
+        td_logprobs = model_logprobs(tdin_logprobs)
+        torch.testing.assert_close(
+            td_generate.log_probs, td_logprobs.log_probs, rtol=1e-2, atol=1e-2
+        )


 if __name__ == "__main__":
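
As a usage note, test_from_vllm_logprobs above verifies that the log-probs reported while generating match those re-computed by a second wrapper built with generate=False. A minimal sketch of that pattern, inferred from the diff (token path only; the model name mirrors the test):

import torch
from torchrl.data.llm import LLMData
from torchrl.modules import from_vllm
from vllm import LLM

model = LLM(model="facebook/opt-125m")
# One wrapper generates and returns log-probs, the other only scores a given response.
m_generate = from_vllm(model, from_text=False, generate=True, return_log_probs=True)
m_logprobs = from_vllm(model, from_text=False, generate=False)

tokens = torch.randint(1024, (1, 10))
td_gen = m_generate(LLMData(tokens=tokens, batch_size=1))

# Re-score the generated tokens: with generate=False the wrapper only computes log-probs.
td_lp = m_logprobs(
    LLMData(tokens=tokens, tokens_response=td_gen.tokens_response, batch_size=1)
)
torch.testing.assert_close(td_gen.log_probs, td_lp.log_probs, rtol=1e-2, atol=1e-2)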
