Commit f29494e

[Feature] pass policy-factory in mp data collectors
ghstack-source-id: 3af3a995c48e0eb6ce1736a587b565fa1ac758c4
Pull Request resolved: #2859
1 parent 886745d commit f29494e

File tree

6 files changed: +279 −9 lines changed
+122 (new file)
@@ -0,0 +1,122 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+"""
+Updating MPS weights in multiprocess/distributed data collectors
+================================================================
+
+Overview of the Script
+----------------------
+
+This script demonstrates a multiprocess weight update in TorchRL. It uses a custom
+`MPSRemoteWeightUpdater` class to update the weights of a policy network across multiple workers.
+
+Key Features
+------------
+
+- Multi-Worker Setup: The script creates two worker processes that collect data from a Gym environment
+  ("Pendulum-v1") using a policy network.
+- MPS (Metal Performance Shaders) Device: The policy network is placed on an MPS device.
+- Custom Weight Updater: The `MPSRemoteWeightUpdater` class is used to update the policy weights across workers. This
+  class is necessary because MPS tensors cannot be sent over a pipe due to serialization/pickling issues in PyTorch.
+
+Workaround for MPS Tensor Serialization Issue
+---------------------------------------------
+
+In PyTorch, MPS tensors cannot be serialized or pickled, which means they cannot be sent over a pipe or shared between
+processes. To work around this issue, the `MPSRemoteWeightUpdater` class sends the policy weights on the CPU device
+instead of the MPS device. The local workers then copy the weights from the CPU device to the MPS device.
+
+Script Flow
+-----------
+
+1. Initialize the environment, policy network, and collector.
+2. Update the policy weights using the `MPSRemoteWeightUpdater`.
+3. Collect data from the environment using the policy network.
+4. Zero out the policy weights after a few iterations.
+5. Verify that the updated policy weights are being used by checking the actions generated by the policy network.
+
+"""
+
+import tensordict
+import torch
+from tensordict import TensorDictBase
+from tensordict.nn import TensorDictModule
+from torch import nn
+from torchrl.collectors import MultiSyncDataCollector, RemoteWeightUpdaterBase
+
+from torchrl.envs.libs.gym import GymEnv
+
+
+class MPSRemoteWeightUpdater(RemoteWeightUpdaterBase):
+    def __init__(self, policy_weights, num_workers):
+        # Weights are on the mps device, which cannot be shared across processes
+        self.policy_weights = policy_weights.data
+        self.num_workers = num_workers
+
+    def _sync_weights_with_worker(
+        self, worker_id: int | torch.device, server_weights: TensorDictBase
+    ) -> TensorDictBase:
+        # Send weights on cpu - the local workers will do the cpu->mps copy
+        self.collector.pipes[worker_id].send((server_weights, "update"))
+        val, msg = self.collector.pipes[worker_id].recv()
+        assert msg == "updated"
+        return server_weights
+
+    def _get_server_weights(self) -> TensorDictBase:
+        print((self.policy_weights == 0).all())
+        return self.policy_weights.cpu()
+
+    def _maybe_map_weights(self, server_weights: TensorDictBase) -> TensorDictBase:
+        print((server_weights == 0).all())
+        return server_weights
+
+    def all_worker_ids(self) -> list[int] | list[torch.device]:
+        return list(range(self.num_workers))
+
+
+if __name__ == "__main__":
+    device = "mps"
+
+    def env_maker():
+        return GymEnv("Pendulum-v1", device="cpu")
+
+    def policy_factory(device=device):
+        return TensorDictModule(
+            nn.Linear(3, 1), in_keys=["observation"], out_keys=["action"]
+        ).to(device=device)
+
+    policy = policy_factory()
+    policy_weights = tensordict.from_module(policy)
+
+    collector = MultiSyncDataCollector(
+        create_env_fn=[env_maker, env_maker],
+        policy_factory=policy_factory,
+        total_frames=2000,
+        max_frames_per_traj=50,
+        frames_per_batch=200,
+        init_random_frames=-1,
+        reset_at_each_iter=False,
+        device=device,
+        storing_device="cpu",
+        remote_weights_updater=MPSRemoteWeightUpdater(policy_weights, 2),
+        # use_buffers=False,
+        # cat_results="stack",
+    )
+
+    collector.update_policy_weights_()
+    try:
+        for i, data in enumerate(collector):
+            if i == 2:
+                print(data)
+                assert (data["action"] != 0).any()
+                # zero the policy
+                policy_weights.data.zero_()
+                collector.update_policy_weights_()
+            elif i == 3:
+                assert (data["action"] == 0).all(), data["action"]
+                break
+    finally:
+        collector.shutdown()
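
The handoff described in the comments above is easiest to see in isolation. Below is a minimal, self-contained sketch of the same idea outside TorchRL; the `worker` function and the bare `Pipe` usage are illustrative, not the library's internals, while the ("update"/"updated") message pair mirrors the protocol in the diff. Server weights stay on CPU for transport, and the CPU-to-device copy happens on the worker side; the sketch falls back to CPU when MPS is unavailable.

import torch
from torch import multiprocessing as mp
from torch import nn


def worker(pipe, device):
    # The worker owns its policy; only CPU tensors ever cross the pipe.
    policy = nn.Linear(3, 1).to(device)
    weights, msg = pipe.recv()
    assert msg == "update"
    with torch.no_grad():
        for p, w in zip(policy.parameters(), weights):
            p.copy_(w.to(device))  # cpu -> device copy happens worker-side
    pipe.send("updated")


if __name__ == "__main__":
    device = "mps" if torch.backends.mps.is_available() else "cpu"
    parent_pipe, child_pipe = mp.Pipe()
    proc = mp.Process(target=worker, args=(child_pipe, device))
    proc.start()
    server_policy = nn.Linear(3, 1)  # server-side copy of the weights, kept on CPU
    parent_pipe.send(([p.detach().cpu() for p in server_policy.parameters()], "update"))
    assert parent_pipe.recv() == "updated"
    proc.join()

Keeping only CPU tensors on the pipe is what makes the exchange picklable regardless of the compute device, which is the whole point of the workaround.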

test/test_collector.py

+69 −1
@@ -39,7 +39,11 @@
     prod,
     seed_generator,
 )
-from torchrl.collectors import aSyncDataCollector, SyncDataCollector
+from torchrl.collectors import (
+    aSyncDataCollector,
+    RemoteWeightUpdaterBase,
+    SyncDataCollector,
+)
 from torchrl.collectors.collectors import (
     _Interruptor,
     MultiaSyncDataCollector,
@@ -146,6 +150,7 @@
 PYTHON_3_10 = sys.version_info.major == 3 and sys.version_info.minor == 10
 PYTHON_3_7 = sys.version_info.major == 3 and sys.version_info.minor == 7
 TORCH_VERSION = version.parse(version.parse(torch.__version__).base_version)
+_has_cuda = torch.cuda.is_available()


 class WrappablePolicy(nn.Module):
@@ -3476,6 +3481,69 @@ def __deepcopy_error__(*args, **kwargs):
     raise RuntimeError("deepcopy not allowed")


+class TestPolicyFactory:
+    class MPSRemoteWeightUpdater(RemoteWeightUpdaterBase):
+        def __init__(self, policy_weights, num_workers):
+            # Weights are on the mps device, which cannot be shared
+            self.policy_weights = policy_weights.data
+            self.num_workers = num_workers
+
+        def _sync_weights_with_worker(
+            self, worker_id: int | torch.device, server_weights: TensorDictBase
+        ) -> TensorDictBase:
+            # Send weights on cpu - the local workers will do the cpu->mps copy
+            self.collector.pipes[worker_id].send((server_weights, "update"))
+            val, msg = self.collector.pipes[worker_id].recv()
+            assert msg == "updated"
+            return server_weights
+
+        def _get_server_weights(self) -> TensorDictBase:
+            return self.policy_weights.cpu()
+
+        def _maybe_map_weights(self, server_weights: TensorDictBase) -> TensorDictBase:
+            return server_weights
+
+        def all_worker_ids(self) -> list[int] | list[torch.device]:
+            return list(range(self.num_workers))
+
+    @pytest.mark.skipif(not _has_cuda, reason="requires a cuda device other than CPU.")
+    def test_weight_update(self):
+        device = "cuda:0"
+        env_maker = lambda: GymEnv("Pendulum-v1", device="cpu")
+        policy_factory = lambda: TensorDictModule(
+            nn.Linear(3, 1), in_keys=["observation"], out_keys=["action"]
+        ).to(device)
+        policy = policy_factory()
+        policy_weights = TensorDict.from_module(policy)
+
+        collector = MultiSyncDataCollector(
+            create_env_fn=[env_maker, env_maker],
+            policy_factory=policy_factory,
+            total_frames=2000,
+            max_frames_per_traj=50,
+            frames_per_batch=200,
+            init_random_frames=-1,
+            reset_at_each_iter=False,
+            device=device,
+            storing_device="cpu",
+            remote_weights_updater=self.MPSRemoteWeightUpdater(policy_weights, 2),
+        )
+
+        collector.update_policy_weights_()
+        try:
+            for i, data in enumerate(collector):
+                if i == 2:
+                    assert (data["action"] != 0).any()
+                    # zero the policy
+                    policy_weights.data.zero_()
+                    collector.update_policy_weights_()
+                elif i == 3:
+                    assert (data["action"] == 0).all(), data["action"]
+                    break
+        finally:
+            collector.shutdown()
+
+
 if __name__ == "__main__":
     args, unknown = argparse.ArgumentParser().parse_known_args()
     pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown)
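
To exercise only the new test (assuming a CUDA machine, since the skipif guard above skips it otherwise, and the repository's usual test dependencies), the same pytest.main entry point can be narrowed with a keyword filter; the invocation below is a sketch, not part of the commit:

import pytest

# Select only test_weight_update; -x stops at the first failure.
pytest.main(["test/test_collector.py", "-k", "test_weight_update", "-x"])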

torchrl/collectors/collectors.py

+33 −5
@@ -152,8 +152,28 @@ class DataCollectorBase(IterableDataset, metaclass=abc.ABCMeta):
     trust_policy: bool
     compiled_policy: bool
     cudagraphed_policy: bool
-    local_weights_updater: LocalWeightUpdaterBase | None = None
-    remote_weights_updater: RemoteWeightUpdaterBase | None = None
+    _local_weights_updater: LocalWeightUpdaterBase | None = None
+    _remote_weights_updater: RemoteWeightUpdaterBase | None = None
+
+    @property
+    def local_weights_updater(self) -> LocalWeightUpdaterBase:
+        return self._local_weights_updater
+
+    @local_weights_updater.setter
+    def local_weights_updater(self, value: LocalWeightUpdaterBase | None):
+        if value is not None:
+            value.register_collector(self)
+        self._local_weights_updater = value
+
+    @property
+    def remote_weights_updater(self) -> RemoteWeightUpdaterBase:
+        return self._remote_weights_updater
+
+    @remote_weights_updater.setter
+    def remote_weights_updater(self, value: RemoteWeightUpdaterBase | None):
+        if value is not None:
+            value.register_collector(self)
+        self._remote_weights_updater = value

     def _get_policy_and_device(
         self,
@@ -1515,7 +1535,7 @@ def __repr__(self) -> str:
                 f"\nexploration={self.exploration_type})"
             )
             return string
-        except AttributeError:
+        except Exception:
            return f"{type(self).__name__}(not_init)"
@@ -1831,6 +1851,7 @@ def __init__(
         self.local_weights_updater = local_weights_updater

         self.policy = policy
+        self.policy_factory = policy_factory

         remainder = 0
         if total_frames is None or total_frames < 0:
@@ -2012,6 +2033,10 @@ def _run_processes(self) -> None:
                 env_fun = CloudpickleWrapper(env_fun)

             # Create a policy on the right device
+            policy_factory = self.policy_factory
+            if policy_factory is not None:
+                policy_factory = CloudpickleWrapper(policy_factory)
+
             policy_device = self.policy_device[i]
             storing_device = self.storing_device[i]
             env_device = self.env_device[i]
@@ -2020,13 +2045,14 @@
             # This makes sure that a given set of shared weights for a given device are
             # shared for all policies that rely on that device.
             policy = self.policy
-            policy_weights = self._policy_weights_dict[policy_device]
+            policy_weights = self._policy_weights_dict.get(policy_device)
             if policy is not None and policy_weights is not None:
                 cm = policy_weights.to_module(policy)
             else:
                 cm = contextlib.nullcontext()
             with cm:
                 kwargs = {
+                    "policy_factory": policy_factory,
                     "pipe_parent": pipe_parent,
                     "pipe_child": pipe_child,
                     "queue_out": queue_out,
@@ -3107,6 +3133,7 @@ def _main_async_collector(
     compile_policy: bool = False,
     cudagraph_policy: bool = False,
     no_cuda_sync: bool = False,
+    policy_factory: Callable | None = None,
 ) -> None:
     pipe_parent.close()
     # init variables that will be cleared when closing
@@ -3116,6 +3143,7 @@
         create_env_fn,
         create_env_kwargs=create_env_kwargs,
         policy=policy,
+        policy_factory=policy_factory,
         total_frames=-1,
         max_frames_per_traj=max_frames_per_traj,
         frames_per_batch=frames_per_batch,
@@ -3278,7 +3306,7 @@ def cast_tensor(x, MPS_ERROR=MPS_ERROR):
             continue

         elif msg == "update":
-            inner_collector.update_policy_weights_()
+            inner_collector.update_policy_weights_(policy_weights=data_in)
             pipe_child.send((j, "updated"))
             has_timed_out = False
             continue
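
The CloudpickleWrapper indirection above exists because policy factories are often lambdas or closures (as in the test), which the standard pickler rejects, while the policy itself may live on a device whose tensors cannot be pickled at all. The core pattern is sketched below with plain multiprocessing and illustrative names, not TorchRL internals: the callable crosses the process boundary and the policy is only ever built worker-side.

import multiprocessing as mp

import torch
from torch import nn


def policy_factory():
    # Device resolution happens in whichever process calls the factory.
    device = "mps" if torch.backends.mps.is_available() else "cpu"
    return nn.Linear(3, 1).to(device)


def worker(factory, out_queue):
    policy = factory()  # built locally; no policy tensors were pickled
    out_queue.put(next(policy.parameters()).device.type)


if __name__ == "__main__":
    queue = mp.Queue()
    proc = mp.Process(target=worker, args=(policy_factory, queue))
    proc.start()
    print("worker-side policy device:", queue.get())
    proc.join()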

torchrl/collectors/weight_update.py

+42 −1
@@ -5,8 +5,9 @@
 from __future__ import annotations

 import abc
+import weakref
 from abc import abstractmethod
-from typing import Callable, TypeVar
+from typing import Any, Callable, TypeVar

 import torch
 from tensordict import TensorDictBase
@@ -44,6 +45,25 @@ class LocalWeightUpdaterBase(metaclass=abc.ABCMeta):

     """

+    _collector_wr: Any = None
+
+    def register_collector(self, collector: DataCollectorBase):  # noqa
+        """Register a collector in the updater.
+
+        Once registered, the updater will not accept another collector.
+
+        Args:
+            collector (DataCollectorBase): The collector to register.
+
+        """
+        if self._collector_wr is not None:
+            raise RuntimeError("Cannot register collector twice.")
+        self._collector_wr = weakref.ref(collector)
+
+    @property
+    def collector(self) -> torchrl.collectors.DataCollectorBase:  # noqa
+        return self._collector_wr() if self._collector_wr is not None else None
+
     @abstractmethod
     def _get_server_weights(self) -> TensorDictBase:
         ...
@@ -104,12 +124,33 @@ class RemoteWeightUpdaterBase(metaclass=abc.ABCMeta):

     Methods:
         update_weights: Updates the weights on specified or all remote workers.
+        register_collector: Registers a collector. This should be called automatically by the collector
+            upon registration of the updater.

     .. seealso:: :class:`~torchrl.collectors.LocalWeightsUpdaterBase` and
         :meth:`~torchrl.collectors.DataCollectorBase.update_policy_weights_`.

     """

+    _collector_wr: Any = None
+
+    def register_collector(self, collector: DataCollectorBase):  # noqa
+        """Register a collector in the updater.
+
+        Once registered, the updater will not accept another collector.
+
+        Args:
+            collector (DataCollectorBase): The collector to register.
+
+        """
+        if self._collector_wr is not None:
+            raise RuntimeError("Cannot register collector twice.")
+        self._collector_wr = weakref.ref(collector)
+
+    @property
+    def collector(self) -> DataCollectorBase:
+        return self._collector_wr() if self._collector_wr is not None else None
+
     @abstractmethod
     def _sync_weights_with_worker(
         self, worker_id: int | torch.device, server_weights: TensorDictBase
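
Read together with the collector-side setters in collectors.py, the registration mechanism reduces to a small pattern. The sketch below uses illustrative names rather than the TorchRL classes; it shows the one-shot weakref wiring, why the updater never keeps its collector alive, and why registering the same updater twice raises.

import weakref


class Updater:
    _collector_wr = None

    def register_collector(self, collector):
        if self._collector_wr is not None:
            raise RuntimeError("Cannot register collector twice.")
        # Weak reference: the updater never keeps its collector alive.
        self._collector_wr = weakref.ref(collector)

    @property
    def collector(self):
        return self._collector_wr() if self._collector_wr is not None else None


class Collector:
    _remote_weights_updater = None

    @property
    def remote_weights_updater(self):
        return self._remote_weights_updater

    @remote_weights_updater.setter
    def remote_weights_updater(self, value):
        if value is not None:
            value.register_collector(self)  # back-reference set on assignment
        self._remote_weights_updater = value


collector, updater = Collector(), Updater()
collector.remote_weights_updater = updater
assert updater.collector is collector  # the weakref resolves to the owner
try:
    Collector().remote_weights_updater = updater
except RuntimeError as err:
    print(err)  # Cannot register collector twice.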
