@@ -3,19 +3,27 @@
 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.
 import argparse
+import importlib.util
 import os
 
 import pytest
 import torch
-
 from tensordict import NonTensorStack, TensorDict
 from tensordict.nn import CompositeDistribution, TensorDictModule
 from tensordict.nn.distributions import NormalParamExtractor
 
 from torch import distributions as dist, nn
 from torchrl.data import Binary, Bounded, Categorical, Composite, MultiOneHot, OneHot
+from torchrl.data.llm import LLMData
 from torchrl.data.llm.dataset import _has_transformers
-from torchrl.modules import from_hf_transformers, MLP, SafeModule, TanhDelta, TanhNormal
+from torchrl.modules import (
+    from_hf_transformers,
+    from_vllm,
+    MLP,
+    SafeModule,
+    TanhDelta,
+    TanhNormal,
+)
 from torchrl.modules.tensordict_module.actors import (
     _process_action_space_spec,
     ActorValueOperator,
@@ -37,6 +45,8 @@
 from _utils_internal import get_default_devices
 from mocking_classes import NestedCountingEnv
 
+_has_vllm = importlib.util.find_spec("vllm") is not None
+
 
 @pytest.mark.parametrize(
     "log_prob_key",
@@ -908,52 +918,253 @@ def test_lmhead_actorvalueoperator(device):
 
 
 @pytest.mark.skipif(not _has_transformers, reason="missing transformers dependencies")
-class TestTransformerActor:
+@pytest.mark.skipif(not _has_vllm, reason="missing vllm dependencies")
+class TestLLMActor:
     @pytest.mark.parametrize(
-        "from_text, generate, tokens, attention_mask",
+        "from_text, generate, return_log_probs, tokens, attention_mask",
         [
-            (True, True, None, None),
-            (True, False, None, None),
+            (True, True, True, None, None),
+            (True, True, False, None, None),
+            (True, False, None, None, None),
+            (
+                False,
+                True,
+                True,
+                torch.randint(1024, (1, 10)),
+                torch.ones(1, 10, dtype=torch.int64),
+            ),
+            (False, True, True, torch.randint(1024, (1, 10)), None),
             (
                 False,
                 True,
+                False,
                 torch.randint(1024, (1, 10)),
                 torch.ones(1, 10, dtype=torch.int64),
             ),
-            (False, True, torch.randint(1024, (1, 10)), None),
+            (False, True, False, torch.randint(1024, (1, 10)), None),
         ],
     )
-    def test_from_hf_transformers(self, from_text, generate, tokens, attention_mask):
-        from torchrl.data.llm import LLMData
+    def test_from_hf_transformers(
+        self, from_text, generate, return_log_probs, tokens, attention_mask
+    ):
         from transformers import AutoTokenizer, GPT2Config, GPT2LMHeadModel
 
+        model_name = "distilbert-base-uncased"  # or "minilm" or "albert-tiny"
+        # Load the model and tokenizer
+        # model = AutoModel.from_pretrained(model_name)
+        # tokenizer = AutoTokenizer.from_pretrained(model_name)
+
         tokenizer = AutoTokenizer.from_pretrained("gpt2")
-        tokenizer.pad_token = tokenizer.eos_token
         model = GPT2LMHeadModel(GPT2Config())
+
+        tokenizer.pad_token = tokenizer.eos_token
         tokenizer.padding_side = "left"
+
         m = from_hf_transformers(
-            model, tokenizer=tokenizer, from_text=from_text, generate=generate
+            model,
+            tokenizer=tokenizer,
+            from_text=from_text,
+            generate=generate,
+            return_log_probs=return_log_probs,
+        )
+        self._run_check(
+            m,
+            tokens,
+            attention_mask,
+            generate,
+            return_log_probs,
+            from_text,
+            has_logits=True,
+        )
+
+    @pytest.mark.parametrize(
+        "from_text, generate, return_log_probs, tokens, attention_mask",
+        [
+            (True, True, True, None, None),
+            (True, True, False, None, None),
+            (True, False, None, None, None),
+            (
+                False,
+                True,
+                True,
+                torch.randint(1024, (1, 10)),
+                torch.ones(1, 10, dtype=torch.int64),
+            ),
+            (False, True, True, torch.randint(1024, (1, 10)), None),
+            (
+                False,
+                True,
+                False,
+                torch.randint(1024, (1, 10)),
+                torch.ones(1, 10, dtype=torch.int64),
+            ),
+            (False, True, False, torch.randint(1024, (1, 10)), None),
+        ],
+    )
+    def test_from_vllm(
+        self, from_text, generate, return_log_probs, tokens, attention_mask
+    ):
+        from vllm import LLM
+
+        model = LLM(model="facebook/opt-125m")
+        m = from_vllm(
+            model,
+            from_text=from_text,
+            generate=generate,
+            return_log_probs=return_log_probs,
+        )
+        self._run_check(
+            m,
+            tokens,
+            attention_mask,
+            generate,
+            return_log_probs,
+            from_text,
+            has_logits=False,
         )
+
+    def _make_data(
+        self,
+        m,
+        tokens,
+        attention_mask,
+        generate,
+        from_text,
+        has_logits,
+        text_response=None,
+        tokens_response=None,
+    ):
+        lp_kwargs = {}
         if from_text:
-            tdin = LLMData(text=NonTensorStack("a text"), batch_size=1)
+            if not generate:
+                text_response = (
+                    NonTensorStack(" and another text that follows")
+                    if text_response is None
+                    else text_response
+                )
+                if not isinstance(text_response, NonTensorStack):
+                    if isinstance(text_response, list):
+                        text_response = NonTensorStack(*text_response)
+                    else:
+                        text_response = NonTensorStack(text_response)
+                lp_kwargs.update({"text_response": text_response})
+            tdin = LLMData(text=NonTensorStack("a text"), **lp_kwargs, batch_size=1)
         else:
-            tdin = LLMData(tokens=tokens, attention_mask=attention_mask, batch_size=1)
+            if not generate:
+                if tokens_response is None:
+                    shape_response = tokens.shape
+                    shape_response = shape_response[:-1] + (shape_response[-1] * 2,)
+                    tokens_response = torch.randint(1024, shape_response)
+                lp_kwargs.update({"tokens_response": tokens_response})
+            tdin = LLMData(
+                tokens=tokens, attention_mask=attention_mask, **lp_kwargs, batch_size=1
+            )
+        return tdin
+
+    def _run_check(
+        self,
+        m,
+        tokens,
+        attention_mask,
+        generate,
+        return_log_probs,
+        from_text,
+        has_logits,
+    ):
+        tdin = self._make_data(
+            m, tokens, attention_mask, generate, from_text, has_logits
+        )
+        if from_text and generate:
+            assert tdin.text_response is None
+        elif from_text and not generate:
+            assert tdin.text_response is not None
+
         td = m(tdin)
         assert td is tdin
         assert isinstance(td, LLMData)
         if from_text and generate:
             assert td.text_response is not None
+        if generate and (attention_mask is not None or from_text):
+            assert td.attention_mask is not None, (generate, from_text)
         else:
-            assert td.text_response is None
-        if attention_mask is not None or from_text:
-            assert td.attention_mask is not None
-        else:
-            assert td.attention_mask is None
+            assert td.attention_mask is None, (generate, from_text)
         if not generate:
-            assert td.text_response is None
-            assert td.tokens_response is None
+            # log-probs are computed on the text_response or the tokens_response
+            assert td.text_response is not None or td.tokens_response is not None
             assert td.log_probs is not None
-            assert td.logits is not None
+            if has_logits:
+                assert td.logits is not None
+        if generate:
+            if return_log_probs:
+                assert td.log_probs is not None
+                assert td.log_probs.shape[-2] == td.tokens_response.shape[-1]
+            else:
+                assert td.log_probs is None
+
+        # Test the shapes
+        assert td.tokens_response is not None, (generate, has_logits, from_text)
+
+        # If from text and not generating, the tokens are not returned for now
+        if not (from_text and not generate):
+            assert td.tokens_response.shape[:-1] == td.tokens.shape[:-1]
+            # The convention is that the response only has new tokens
+            assert (
+                td.tokens_response[..., : td.tokens.shape[-1]]
+                != td.tokens[..., : td.tokens_response.shape[-1]]
+            ).any(), (generate, from_text)
+
+    @pytest.mark.parametrize(
+        "from_text, tokens, attention_mask",
+        [
+            (True, None, None),
+            (
+                False,
+                torch.randint(1024, (1, 10)),
+                torch.ones(1, 10, dtype=torch.int64),
+            ),
+            (False, torch.randint(1024, (1, 10)), None),
+        ],
+    )
+    def test_from_vllm_logprobs(self, from_text, tokens, attention_mask):
+        from vllm import LLM
+
+        model = LLM(model="facebook/opt-125m")
+        m_generate = from_vllm(
+            model, from_text=from_text, generate=True, return_log_probs=True
+        )
+        m_logprobs = from_vllm(model, from_text=from_text, generate=False)
+        self._check_lps(
+            m_generate, m_logprobs, tokens, attention_mask, from_text, has_logits=False
+        )
+
+    def _check_lps(
+        self,
+        model_generate,
+        model_logprobs,
+        tokens,
+        attention_mask,
+        from_text,
+        has_logits,
+    ):
+        # Check that the log-probs gathered with generate=False match those obtained with generate=True
+        tdin_generate = self._make_data(
+            model_generate, tokens, attention_mask, True, from_text, has_logits
+        )
+        td_generate = model_generate(tdin_generate)
+        tdin_logprobs = self._make_data(
+            model_logprobs,
+            tokens,
+            attention_mask,
+            False,
+            from_text,
+            has_logits,
+            tokens_response=td_generate.tokens_response,
+            text_response=td_generate.text_response,
+        )
+        td_logprobs = model_logprobs(tdin_logprobs)
+        torch.testing.assert_close(
+            td_generate.log_probs, td_logprobs.log_probs, rtol=1e-2, atol=1e-2
+        )
 
 
 if __name__ == "__main__":
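For reference, the wrapper usage these tests exercise follows one pattern: build the LLM, wrap it with from_vllm (or from_hf_transformers), pass an LLMData tensordict through the wrapper, and read the generated text, tokens, and log-probabilities back from that same tensordict. Below is a minimal sketch distilled from the tests above, assuming the torchrl and vllm versions targeted by this change.

    from tensordict import NonTensorStack
    from torchrl.data.llm import LLMData
    from torchrl.modules import from_vllm
    from vllm import LLM

    # Wrap a vLLM engine so it can be called on an LLMData tensordict,
    # mirroring what TestLLMActor.test_from_vllm does.
    model = LLM(model="facebook/opt-125m")
    actor = from_vllm(model, from_text=True, generate=True, return_log_probs=True)

    # The wrapper writes its outputs back into the same LLMData instance.
    data = LLMData(text=NonTensorStack("a text"), batch_size=1)
    out = actor(data)
    print(out.text_response, out.tokens_response, out.log_probs)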