@@ -14,6 +14,7 @@
 import re
 import string
 from collections import defaultdict
+from contextlib import nullcontext
 from functools import partial
 from sys import platform
 from typing import Any, Optional
@@ -33,6 +34,6 @@
     TensorDictBase,
 )
 from tensordict.nn import TensorDictModuleBase
-from tensordict.tensorclass import NonTensorStack, TensorClass
+from tensordict.tensorclass import NonTensorData, NonTensorStack, TensorClass
 from tensordict.utils import _unravel_key_to_tuple
 from torch import nn
@@ -4630,6 +4631,7 @@ def __next__(self):
             else:
                 return tensors
 
+    @pytest.mark.skipif(not _has_transformers, reason="test requires transformers")
     @pytest.mark.parametrize(
         "str2str,stack_method",
         [
@@ -4674,22 +4676,36 @@ def test_llm_env(self, str2str, batched, stack_method, device, batch_size):
         else:
             env.check_env_specs(break_when_any_done="both")
 
+    @pytest.mark.skipif(not _has_transformers, reason="test requires transformers")
+    @pytest.mark.parametrize("tokenizer", [True, False])
     @pytest.mark.parametrize(
-        "str2str,stack_method",
+        "str2str,no_stack,stack_method",
         [
-            [True, None],
-            [False, "as_padded_tensor"],
-            # TODO: a bit experimental, fails with check_env_specs
-            # [False, "as_nested_tensor"],
-            [False, None],
+            [True, True, None],
+            [True, False, None],
+            [False, False, "as_padded_tensor"],
+            [False, False, None],
         ],
     )
     @pytest.mark.parametrize("batched", [True, False])
     @pytest.mark.parametrize("device", [None, "cpu"])
     @pytest.mark.parametrize("batch_size", [0, 4])
     def test_llm_from_dataloader(
-        self, str2str, batched, stack_method, device, batch_size
+        self,
+        str2str,
+        batched,
+        stack_method,
+        device,
+        batch_size,
+        tokenizer,
+        no_stack,
     ):
+        from transformers import AutoTokenizer
+
+        if tokenizer:
+            tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
+        else:
+            tokenizer = None
         if str2str:
             kwargs = {
                 "dataloader": self.DummyDataLoader(batch_size=batch_size),
@@ -4712,7 +4728,8 @@ def test_llm_from_dataloader(
                 "str2str": str2str,
                 "device": device,
                 "has_attention": False,
-                "no_stack": False,
+                "no_stack": no_stack,
+                "tokenizer": tokenizer,
             }
         )
         env = LLMEnv.from_dataloader(**kwargs)
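
The next hunk wraps each action string in NonTensorData so the non-tensor leaves carry an explicit device, and batches them with NonTensorStack. A minimal standalone sketch of these two tensordict classes, imported at the top of this file (the example values here are illustrative, not from the PR):

    from tensordict.tensorclass import NonTensorData, NonTensorStack

    # A single non-tensor leaf: plain Python data plus device metadata.
    item = NonTensorData("<nothing>", device="cpu")
    assert str(item.device) == "cpu"

    # A batch of non-tensor leaves, stacked along a new leading dimension.
    stack = NonTensorStack(
        *[NonTensorData("<nothing>", device="cpu") for _ in range(3)]
    )
    assert stack.shape == (3,)
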
@@ -4725,12 +4742,17 @@
         if batch_size > 0:
 
             def policy(td):
-                if str2str:
+                if str2str and tokenizer is None:
                     if not td.shape:
-                        td[LLMEnv._DEFAULT_ACTION_STR_KEY] = "<nothing>"
+                        td[LLMEnv._DEFAULT_ACTION_STR_KEY] = NonTensorData(
+                            "<nothing>", device=device
+                        )
                     else:
                         td[LLMEnv._DEFAULT_ACTION_STR_KEY] = NonTensorStack(
-                            *["<nothing>" for _ in range(td.shape[0])]
+                            *[
+                                NonTensorData("<nothing>", device=device)
+                                for _ in range(td.shape[0])
+                            ]
                         )
                 else:
                     td[LLMEnv._DEFAULT_ACTION_TOKENS_KEY] = torch.ones(
@@ -4742,34 +4764,48 @@ def policy(td):
                 # Tell the env that we want 3 sub-envs
                 r = env.rollout(10, policy, tensordict=TensorDict(batch_size=[3]))
                 assert r.ndim == 2
-                if str2str:
+                if str2str and tokenizer is None:
                     assert isinstance(r[0, 0][LLMEnv._DEFAULT_STR_KEY], str)
                     assert isinstance(r[0, 1][LLMEnv._DEFAULT_STR_KEY], str)
-                    assert (
-                        r[0, 0][LLMEnv._DEFAULT_STR_KEY]
-                        == r[0, 1][LLMEnv._DEFAULT_STR_KEY][
-                            : -len(r[0, 0][LLMEnv._DEFAULT_ACTION_STR_KEY])
-                        ]
-                    )
-                    assert (
-                        r[0, 1][LLMEnv._DEFAULT_STR_KEY]
-                        == r[0, 2][LLMEnv._DEFAULT_STR_KEY][
-                            : -len(r[0, 1][LLMEnv._DEFAULT_ACTION_STR_KEY])
-                        ]
-                    )
-                    assert (
-                        r[-1, 0][LLMEnv._DEFAULT_STR_KEY]
-                        == r[-1, 1][LLMEnv._DEFAULT_STR_KEY][
-                            : -len(r[-1, 0][LLMEnv._DEFAULT_ACTION_STR_KEY])
-                        ]
-                    )
-                    assert (
-                        r[-1, 1][LLMEnv._DEFAULT_STR_KEY]
-                        == r[-1, 2][LLMEnv._DEFAULT_STR_KEY][
-                            : -len(r[-1, 1][LLMEnv._DEFAULT_ACTION_STR_KEY])
-                        ]
-                    )
-                else:
+                    should_fail = no_stack
+                    if should_fail:
+                        ctx = pytest.raises(AssertionError)
+                    else:
+                        ctx = nullcontext()
+                    with ctx:
+                        assert (
+                            r[0, 0][LLMEnv._DEFAULT_STR_KEY]
+                            == r[0, 1][LLMEnv._DEFAULT_STR_KEY][
+                                : -len(r[0, 0][LLMEnv._DEFAULT_ACTION_STR_KEY])
+                            ]
+                        ), (
+                            r[0, 0][LLMEnv._DEFAULT_STR_KEY],
+                            r[0, 0][LLMEnv._DEFAULT_ACTION_STR_KEY],
+                            r[0, 0]["next", LLMEnv._DEFAULT_STR_KEY],
+                            r[0, 1][LLMEnv._DEFAULT_STR_KEY],
+                        )
+                    with ctx:
+                        assert (
+                            r[0, 1][LLMEnv._DEFAULT_STR_KEY]
+                            == r[0, 2][LLMEnv._DEFAULT_STR_KEY][
+                                : -len(r[0, 1][LLMEnv._DEFAULT_ACTION_STR_KEY])
+                            ]
+                        )
+                    with ctx:
+                        assert (
+                            r[-1, 0][LLMEnv._DEFAULT_STR_KEY]
+                            == r[-1, 1][LLMEnv._DEFAULT_STR_KEY][
+                                : -len(r[-1, 0][LLMEnv._DEFAULT_ACTION_STR_KEY])
+                            ]
+                        )
+                    with ctx:
+                        assert (
+                            r[-1, 1][LLMEnv._DEFAULT_STR_KEY]
+                            == r[-1, 2][LLMEnv._DEFAULT_STR_KEY][
+                                : -len(r[-1, 1][LLMEnv._DEFAULT_ACTION_STR_KEY])
+                            ]
+                        )
+                elif tokenizer is None:
                     assert (
                         r[0, 0][LLMEnv._DEFAULT_TOKEN_KEY]
                         == r[0, 1][LLMEnv._DEFAULT_TOKEN_KEY][:-1]
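
The with-ctx blocks above implement a conditional expectation: when no_stack is expected to break the string-continuation invariant, the assertions run under pytest.raises(AssertionError); otherwise they run under contextlib.nullcontext(), a no-op. A minimal sketch of the pattern in isolation (the should_fail flag and test name are illustrative):

    from contextlib import nullcontext

    import pytest

    @pytest.mark.parametrize("should_fail", [True, False])
    def test_conditional_expectation(should_fail):
        # Pick the enclosing context up front: expect the assertion to trip
        # only in the failing configuration; nullcontext() is a no-op otherwise.
        ctx = pytest.raises(AssertionError) if should_fail else nullcontext()
        with ctx:
            assert not should_fail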