
Commit 023c965: [Doc] Fix doc pipeline (#2992)
1 parent: d5ba70a

8 files changed: +31 additions, -14 deletions


docs/source/reference/collectors.rst

Lines changed: 8 additions & 1 deletion
@@ -159,8 +159,15 @@ transformed, and applied, ensuring seamless integration with their existing infr
     VanillaWeightUpdater
     MultiProcessedWeightUpdater
     RayWeightUpdater
-    DistributedWeightUpdater
+
+.. currentmodule:: torchrl.collectors.distributed
+
+.. autosummary::
+    :toctree: generated/
+    :template: rl_template.rst
+
     RPCWeightUpdater
+    DistributedWeightUpdater
 
 Collectors and replay buffers interoperability
 ----------------------------------------------
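
Note: the doc pipeline broke because `RPCWeightUpdater` and `DistributedWeightUpdater` live in `torchrl.collectors.distributed`, not `torchrl.collectors`, so autosummary could not resolve them under the previous `currentmodule`. The hunk above switches modules before listing them. A quick sanity check that every listed name imports from its documented module (a sketch; assumes TorchRL at this commit is installed):

from torchrl.collectors import (
    MultiProcessedWeightUpdater,
    RayWeightUpdater,
    VanillaWeightUpdater,
)
from torchrl.collectors.distributed import DistributedWeightUpdater, RPCWeightUpdater

print("all autosummary targets importable")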

torchrl/collectors/__init__.py

Lines changed: 2 additions & 2 deletions
@@ -13,7 +13,7 @@
     SyncDataCollector,
 )
 from .weight_update import (
-    MultiProcessedWeightUpdate,
+    MultiProcessedWeightUpdater,
     RayWeightUpdater,
     VanillaWeightUpdater,
     WeightUpdaterBase,
@@ -24,7 +24,7 @@
     "WeightUpdaterBase",
     "VanillaWeightUpdater",
     "RayWeightUpdater",
-    "MultiProcessedWeightUpdate",
+    "MultiProcessedWeightUpdater",
     "aSyncDataCollector",
     "DataCollectorBase",
     "MultiaSyncDataCollector",

torchrl/collectors/collectors.py

Lines changed: 2 additions & 2 deletions
@@ -51,7 +51,7 @@
 )
 from torchrl.collectors.utils import split_trajectories
 from torchrl.collectors.weight_update import (
-    MultiProcessedWeightUpdate,
+    MultiProcessedWeightUpdater,
     VanillaWeightUpdater,
     WeightUpdaterBase,
 )
@@ -2010,7 +2010,7 @@ def __init__(
         self._policy_weights_dict[policy_device] = weights
         self._get_weights_fn = get_weights_fn
         if weight_updater is None:
-            weight_updater = MultiProcessedWeightUpdate(
+            weight_updater = MultiProcessedWeightUpdater(
                 get_server_weights=self._get_weights_fn,
                 policy_weights=self._policy_weights_dict,
            )
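
Note: the same rename is applied at the call site where the multiprocessed collector builds its default updater. A minimal sketch of constructing the updater directly, mirroring the keywords visible above (the weight values are stand-ins; the real call site passes a per-device dict of policy weights):

from tensordict import TensorDict
from torchrl.collectors import MultiProcessedWeightUpdater

server_weights = TensorDict({}, batch_size=[])  # stand-in for the server's policy weights
updater = MultiProcessedWeightUpdater(
    get_server_weights=lambda: server_weights,  # same keyword as the call site above
    policy_weights={"cpu": server_weights},  # per-device dict, as at the call site
)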

torchrl/collectors/distributed/__init__.py

Lines changed: 2 additions & 2 deletions
@@ -9,7 +9,7 @@
     DistributedWeightUpdater,
 )
 from .ray import RayCollector
-from .rpc import RPCDataCollector
+from .rpc import RPCDataCollector, RPCWeightUpdater
 from .sync import DistributedSyncDataCollector
 from .utils import submitit_delayed_launcher
 
@@ -19,7 +19,7 @@
     "DistributedWeightUpdater",
     "DistributedSyncDataCollector",
     "RPCDataCollector",
-    "RPCDataCollector",
+    "RPCWeightUpdater",
     "RayCollector",
     "submitit_delayed_launcher",
 ]
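
Note: `__all__` previously listed "RPCDataCollector" twice and never exported `RPCWeightUpdater`, which the new autosummary entry needs. A throwaway check for that class of mistake (a sketch, using only the entries visible in this hunk):

visible_entries = [
    "DistributedWeightUpdater",
    "DistributedSyncDataCollector",
    "RPCDataCollector",
    "RPCWeightUpdater",
    "RayCollector",
    "submitit_delayed_launcher",
]
assert len(visible_entries) == len(set(visible_entries)), "duplicate __all__ entry"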

torchrl/collectors/distributed/rpc.py

Lines changed: 2 additions & 2 deletions
@@ -412,7 +412,7 @@ def __init__(
         )
         self._init()
         if weight_updater is None:
-            weight_updater = RPCWeightUpdaterBase(
+            weight_updater = RPCWeightUpdater(
                 collector_infos=self.collector_infos,
                 collector_class=self.collector_class,
                 collector_rrefs=self.collector_rrefs,
@@ -810,7 +810,7 @@ def shutdown(self, timeout: float | None = None) -> None:
         self._shutdown = True
 
 
-class RPCWeightUpdaterBase(WeightUpdaterBase):
+class RPCWeightUpdater(WeightUpdaterBase):
     """A remote weight updater for synchronizing policy weights across remote workers using RPC.
 
     The `RPCWeightUpdater` class provides a mechanism for updating the weights of a policy
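
Note: the class was defined as `RPCWeightUpdaterBase` while its own docstring and the docs referred to `RPCWeightUpdater`; the rename makes the definition match both. A one-liner a docs build could use to catch this kind of drift (a sketch; assumes TorchRL at this commit):

import torchrl.collectors.distributed as dist

assert hasattr(dist, "RPCWeightUpdater"), "autosummary target is missing"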

torchrl/collectors/llm/weight_update/vllm.py

Lines changed: 12 additions & 1 deletion
@@ -4,6 +4,8 @@
 # LICENSE file in the root directory of this source tree.
 from __future__ import annotations
 
+import importlib.util
+
 import torch
 import torch.cuda
 import torch.distributed
@@ -13,7 +15,16 @@
 
 from torchrl.collectors import WeightUpdaterBase
 from torchrl.modules.llm.backends.vllm import stateless_init_process_group
-from vllm.utils import get_open_port
+
+_has_vllm = importlib.util.find_spec("vllm") is not None
+if _has_vllm:
+    from vllm.utils import get_open_port
+else:
+
+    def get_open_port():  # noqa: D103
+        raise ImportError(
+            "vllm is not installed. Please install it with `pip install vllm`."
+        )
 
 
 class vLLMUpdater(WeightUpdaterBase):
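
Note: importing `vllm.utils` at module level made this file, and anything that imports it (including the docs build), fail wherever vllm is absent. The guard defers the failure to the first call of `get_open_port`. The same pattern reduced to a self-contained sketch, with a hypothetical package name standing in for vllm:

import importlib.util

_has_somepkg = importlib.util.find_spec("somepkg") is not None  # "somepkg" is hypothetical
if _has_somepkg:
    from somepkg.utils import helper  # only imported when the package exists
else:

    def helper():
        # Raise at call time rather than import time, keeping the module importable.
        raise ImportError("somepkg is not installed. Install it with `pip install somepkg`.")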

torchrl/collectors/weight_update.py

Lines changed: 1 addition & 1 deletion
@@ -206,7 +206,7 @@ def _sync_weights_with_worker(
         self.policy_weights.update_(server_weights)
 
 
-class MultiProcessedWeightUpdate(WeightUpdaterBase):
+class MultiProcessedWeightUpdater(WeightUpdaterBase):
     """A remote weight updater for synchronizing policy weights across multiple processes or devices.
 
     The `MultiProcessedWeightUpdater` class provides a mechanism for updating the weights

torchrl/envs/llm/chat.py

Lines changed: 2 additions & 3 deletions
@@ -7,7 +7,6 @@
 from typing import Any, Callable, Literal
 
 import torch
-import transformers
 from tensordict import lazy_stack, TensorDict, TensorDictBase
 from torch.utils.data import DataLoader
 from torchrl.data import Composite, NonTensor
@@ -116,7 +115,7 @@ def __init__(
         batch_size: tuple | torch.Size | None = None,
         system_prompt: str | None = None,
         apply_template: bool | None = None,
-        tokenizer: transformers.AutoTokenizer | None = None,
+        tokenizer: transformers.AutoTokenizer | None = None,  # noqa: F821
         template_kwargs: dict[str, Any] | None = None,
         system_role: str = "system",
         user_role: str = "user",
@@ -309,7 +308,7 @@ def __init__(
         batch_size_dl: int = 1,
         seed: int | None = None,
         group_repeats: bool = False,
-        tokenizer: transformers.AutoTokenizer | None = None,
+        tokenizer: transformers.AutoTokenizer | None = None,  # noqa: F821
         device: torch.device | None = None,
         template_kwargs: dict[str, Any] | None = None,
         apply_template: bool | None = None,
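
Note: this works because annotations in this file are lazy (chat.py presumably has `from __future__ import annotations`, as vllm.py does above), so `transformers.AutoTokenizer | None` is stored as a string and never evaluated; the module-level `transformers` import can go, and `# noqa: F821` silences the linter's undefined-name warning. The pattern in miniature, runnable without transformers installed:

from __future__ import annotations


def load(tokenizer: transformers.AutoTokenizer | None = None):  # noqa: F821
    # Under PEP 563 the annotation is a plain string and is never evaluated,
    # so `transformers` need not be importable at runtime.
    return tokenizer


print(load.__annotations__["tokenizer"])  # "transformers.AutoTokenizer | None"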
