A tutorial on pin_memory and non_blocking usage (#2983)
vmoens authored Jul 31, 2024
1 parent c3882db commit a66464b
Showing 11 changed files with 770 additions and 22 deletions.
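The centerpiece of the commit is a new tutorial on pinned (page-locked) host memory and asynchronous host-to-device copies; its diff is not expanded below. The pattern it teaches can be sketched in a few lines (a minimal sketch, assuming a CUDA device is available):

    import torch

    assert torch.cuda.is_available(), "sketch assumes a CUDA device"

    # A regular CPU tensor lives in pageable memory: copying it to the GPU
    # must first be staged through a pinned buffer, and non_blocking=True
    # gives no guarantee of asynchrony.
    x = torch.randn(1024, 1024)

    # Pinning the tensor lets the copy run asynchronously with respect
    # to the host when non_blocking=True is passed.
    x_pinned = x.pin_memory()
    x_cuda = x_pinned.to("cuda", non_blocking=True)

    # Synchronize before relying on the result of the async copy.
    torch.cuda.synchronize()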
5 changes: 3 additions & 2 deletions .ci/docker/requirements.txt
@@ -28,8 +28,9 @@ tensorboard
 jinja2==3.1.3
 pytorch-lightning
 torchx
-torchrl==0.3.0
-tensordict==0.3.0
+# TODO: use stable 0.5 when released
+-e git+https://github.com/pytorch/rl.git#egg=torchrl
+-e git+https://github.com/pytorch/tensordict.git#egg=tensordict
 ax-platform
 nbformat>==5.9.2
 datasets
Binary file added _static/img/pinmem/pinmem.png
Binary file added _static/img/pinmem/trace_streamed0_pinned0.png
Binary file added _static/img/pinmem/trace_streamed0_pinned1.png
Binary file added _static/img/pinmem/trace_streamed1_pinned0.png
Binary file added _static/img/pinmem/trace_streamed1_pinned1.png
34 changes: 18 additions & 16 deletions advanced_source/coding_ddpg.py
@@ -182,7 +182,7 @@
 # Later, we will see how the target parameters should be updated in TorchRL.
 #
 
-from tensordict.nn import TensorDictModule
+from tensordict.nn import TensorDictModule, TensorDictSequential
 
 
 def _init(
@@ -290,12 +290,11 @@ def _loss_actor(
 ) -> torch.Tensor:
     td_copy = tensordict.select(*self.actor_in_keys)
     # Get an action from the actor network: since we made it functional, we need to pass the params
-    td_copy = self.actor_network(td_copy, params=self.actor_network_params)
+    with self.actor_network_params.to_module(self.actor_network):
+        td_copy = self.actor_network(td_copy)
     # get the value associated with that action
-    td_copy = self.value_network(
-        td_copy,
-        params=self.value_network_params.detach(),
-    )
+    with self.value_network_params.detach().to_module(self.value_network):
+        td_copy = self.value_network(td_copy)
     return -td_copy.get("state_action_value")
 
 
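The edits above replace the functional `params=` keyword with tensordict's `to_module` context manager, which temporarily swaps a parameter TensorDict into a module and restores the original parameters on exit. A minimal sketch of the pattern, assuming a recent tensordict release where `from_module`/`to_module` are available:

    import torch
    from torch import nn
    from tensordict import TensorDict

    module = nn.Linear(3, 4)
    # extract the module's parameters and buffers as a TensorDict
    params = TensorDict.from_module(module)

    # run a functional call with `params` swapped in; the original
    # parameters are restored when the block exits
    with params.to_module(module):
        out = module(torch.randn(2, 3))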
@@ -317,7 +316,8 @@ def _loss_value(
     td_copy = tensordict.clone()
 
     # V(s, a)
-    self.value_network(td_copy, params=self.value_network_params)
+    with self.value_network_params.to_module(self.value_network):
+        self.value_network(td_copy)
     pred_val = td_copy.get("state_action_value").squeeze(-1)
 
     # we manually reconstruct the parameters of the actor-critic, where the first
@@ -332,9 +332,8 @@ def _loss_value(
         batch_size=self.target_actor_network_params.batch_size,
         device=self.target_actor_network_params.device,
     )
-    target_value = self.value_estimator.value_estimate(
-        tensordict, target_params=target_params
-    ).squeeze(-1)
+    with target_params.to_module(self.actor_critic):
+        target_value = self.value_estimator.value_estimate(tensordict).squeeze(-1)
 
     # Computes the value loss: L2, L1 or smooth L1 depending on `self.loss_function`
     loss_value = distance_loss(pred_val, target_value, loss_function=self.loss_function)
@@ -717,7 +716,7 @@ def get_env_stats():
     ActorCriticWrapper,
     DdpgMlpActor,
     DdpgMlpQNet,
-    OrnsteinUhlenbeckProcessWrapper,
+    OrnsteinUhlenbeckProcessModule,
     ProbabilisticActor,
     TanhDelta,
     ValueOperator,
@@ -776,15 +775,18 @@ def make_ddpg_actor(
 # Exploration
 # ~~~~~~~~~~~
 #
-# The policy is wrapped in a :class:`~torchrl.modules.OrnsteinUhlenbeckProcessWrapper`
+# The policy is passed into a :class:`~torchrl.modules.OrnsteinUhlenbeckProcessModule`
 # exploration module, as suggested in the original paper.
 # Let's define the number of frames before OU noise reaches its minimum value
 annealing_frames = 1_000_000
 
-actor_model_explore = OrnsteinUhlenbeckProcessWrapper(
+actor_model_explore = TensorDictSequential(
     actor,
-    annealing_num_steps=annealing_frames,
-).to(device)
+    OrnsteinUhlenbeckProcessModule(
+        spec=actor.spec.clone(),
+        annealing_num_steps=annealing_frames,
+    ).to(device),
+)
 if device == torch.device("cpu"):
     actor_model_explore.share_memory()
 
@@ -1168,7 +1170,7 @@ def ceil_div(x, y):
 )
 
 # update the exploration strategy
-actor_model_explore.step(current_frames)
+actor_model_explore[1].step(current_frames)
 
 collector.shutdown()
 del collector
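Because the exploration noise is now a module composed after the actor in a `TensorDictSequential` rather than a wrapper around it, the annealing step is reached by indexing into the sequence, as in the `actor_model_explore[1].step(current_frames)` call above. A minimal sketch with a toy actor (hypothetical shapes and spec; assumes a TorchRL version providing `OrnsteinUhlenbeckProcessModule` and `BoundedTensorSpec`):

    import torch
    from torch import nn
    from tensordict.nn import TensorDictModule, TensorDictSequential
    from torchrl.data import BoundedTensorSpec
    from torchrl.modules import OrnsteinUhlenbeckProcessModule

    # toy actor: reads "observation", writes "action"
    actor = TensorDictModule(
        nn.Linear(4, 2), in_keys=["observation"], out_keys=["action"]
    )
    action_spec = BoundedTensorSpec(-1.0, 1.0, shape=(2,))

    actor_explore = TensorDictSequential(
        actor,
        OrnsteinUhlenbeckProcessModule(
            spec=action_spec, annealing_num_steps=1_000
        ),
    )

    # the OU module sits at index 1, so annealing is stepped through it
    actor_explore[1].step(10)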
14 changes: 11 additions & 3 deletions en-wordlist.txt
@@ -1,3 +1,4 @@
+
 ACL
 ADI
 AOT
@@ -50,6 +51,7 @@ DDP
 DDPG
 DDQN
 DLRM
+DMA
 DNN
 DQN
 DataLoaders
@@ -68,6 +70,8 @@ Ecker
 ExportDB
 FC
 FGSM
+tensordict
+DataLoader's
 FLAVA
 FSDP
 FX
@@ -139,6 +143,7 @@ MKLDNN
 MLP
 MLPs
 MNIST
+MPS
 MUC
 MacBook
 MacOS
@@ -219,6 +224,7 @@ STR
 SVE
 SciPy
 Sequentials
+Sharding
 Sigmoid
 SoTA
 Sohn
@@ -254,6 +260,7 @@ VLDB
 VQA
 VS Code
 ViT
+Volterra
 WMT
 WSI
 WSIs
@@ -336,11 +343,11 @@ dataset’s
 deallocation
 decompositions
 decorrelated
+devicemesh
 deserialize
 deserialized
 desynchronization
 deterministically
-devicemesh
 dimensionality
 dir
 discontiguous
@@ -384,6 +391,7 @@ hessian
 hessians
 histoencoder
 histologically
+homonymous
 hotspot
 hvp
 hyperparameter
@@ -459,6 +467,7 @@ optimizer's
 optimizers
 otsu
 overfitting
+pageable
 parallelizable
 parallelization
 parametrization
@@ -522,7 +531,6 @@ runtime
 runtimes
 scalable
 sharded
-Sharding
 softmax
 sparsified
 sparsifier
@@ -609,4 +617,4 @@ warmstarting
 warmup
 webp
 wsi
-wsis
\ No newline at end of file
+wsis
9 changes: 9 additions & 0 deletions index.rst
@@ -3,6 +3,7 @@ Welcome to PyTorch Tutorials
 
 **What's new in PyTorch tutorials?**
 
+* `A guide on good usage of non_blocking and pin_memory() in PyTorch <https://pytorch.org/tutorials/intermediate/pinmem_nonblock.html>`__
 * `Introduction to Distributed Pipeline Parallelism <https://pytorch.org/tutorials/intermediate/pipelining_tutorial.html>`__
 * `Introduction to Libuv TCPStore Backend <https://pytorch.org/tutorials/intermediate/TCPStore_libuv_backend.html>`__
 * `Asynchronous Saving with Distributed Checkpoint (DCP) <https://pytorch.org/tutorials/recipes/distributed_async_checkpoint_recipe.html>`__
@@ -93,6 +94,13 @@ Welcome to PyTorch Tutorials
    :link: intermediate/tensorboard_tutorial.html
    :tags: Interpretability,Getting-Started,TensorBoard
 
+.. customcarditem::
+   :header: Good usage of `non_blocking` and `pin_memory()` in PyTorch
+   :card_description: A guide on best practices to copy data from CPU to GPU.
+   :image: _static/img/pinmem.png
+   :link: intermediate/pinmem_nonblock.html
+   :tags: Getting-Started
+
 .. Image/Video
 
 .. customcarditem::
@@ -969,6 +977,7 @@ Additional Resources
    beginner/pytorch_with_examples
    beginner/nn_tutorial
    intermediate/tensorboard_tutorial
+   intermediate/pinmem_nonblock
 
 .. toctree::
    :maxdepth: 2
2 changes: 1 addition & 1 deletion intermediate_source/dqn_with_rnn_tutorial.py
@@ -298,7 +298,7 @@
 # either by passing a string or an action-spec. This allows us to use
 # Categorical (sometimes called "sparse") encoding or the one-hot version of it.
 #
-qval = QValueModule(action_space=env.action_spec)
+qval = QValueModule(spec=env.action_spec)
 
 ######################################################################
 # .. note::
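This matches the distinction drawn in the comment above the call: in TorchRL, `action_space` expects a string such as "one_hot" or "categorical", while a TensorSpec goes through the dedicated `spec` argument. A minimal sketch of the two forms (hypothetical spec; assumes torchrl's `QValueModule` and `OneHotDiscreteTensorSpec`):

    from torchrl.data import OneHotDiscreteTensorSpec
    from torchrl.modules import QValueModule

    # equivalent ways of declaring the action encoding:
    qval_from_string = QValueModule(action_space="one_hot")
    qval_from_spec = QValueModule(spec=OneHotDiscreteTensorSpec(4))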