A tutorial on pin_memory and non_blocking usage #2983

Merged: 56 commits, Jul 31, 2024
Changes from 52 commits
0974b34
init
vmoens Jul 24, 2024
5f0b6ae
index.rst
vmoens Jul 24, 2024
72f8951
index.rst
vmoens Jul 24, 2024
4fe1b2d
spelling
vmoens Jul 24, 2024
5f57f50
black
vmoens Jul 24, 2024
67f4d1a
pegeable -> pageable
vmoens Jul 24, 2024
07456f9
update requirements
vmoens Jul 24, 2024
ffdcef9
update requirements
vmoens Jul 24, 2024
f6ba568
update requirements
vmoens Jul 24, 2024
85d07be
update requirements
vmoens Jul 24, 2024
9b4640a
amend
vmoens Jul 24, 2024
a688b90
amend
vmoens Jul 24, 2024
d706405
amend
vmoens Jul 24, 2024
5ae66ec
amend
vmoens Jul 25, 2024
dc86259
amend
vmoens Jul 25, 2024
4af8cae
amend
vmoens Jul 25, 2024
73ab3ba
amend
vmoens Jul 25, 2024
5cb0510
black
vmoens Jul 25, 2024
01e580d
amend
vmoens Jul 25, 2024
3ad73ad
amend
vmoens Jul 25, 2024
8b2ec64
amend
vmoens Jul 25, 2024
d980553
amend
vmoens Jul 25, 2024
ac5b5e4
amend
vmoens Jul 25, 2024
68241fe
amend
vmoens Jul 26, 2024
96e1582
amend
vmoens Jul 26, 2024
96e703a
amend
vmoens Jul 26, 2024
8a6f90d
amend
vmoens Jul 26, 2024
8aa882d
amend
vmoens Jul 26, 2024
f9471ef
Apply suggestions from code review
vmoens Jul 28, 2024
26bb669
Update intermediate_source/pinmem_nonblock.py
vmoens Jul 28, 2024
085ee0d
more explicit example
vmoens Jul 28, 2024
84373c8
pyspelling
vmoens Jul 29, 2024
171b350
adding stream exps
vmoens Jul 29, 2024
5138dea
fix filename
vmoens Jul 29, 2024
3869413
amend
vmoens Jul 29, 2024
b1488d5
amend
vmoens Jul 29, 2024
33236ec
amend
vmoens Jul 29, 2024
2fd193a
amend
vmoens Jul 29, 2024
392230a
amend
vmoens Jul 29, 2024
bff42d1
amend
vmoens Jul 29, 2024
c8f7e41
Merge remote-tracking branch 'origin/main' into pinmem-nonblock-tuto
vmoens Jul 29, 2024
e6b20b1
amend
vmoens Jul 29, 2024
d6318f7
amend
vmoens Jul 29, 2024
83abe5f
amend
vmoens Jul 29, 2024
4fe4bde
amend
vmoens Jul 29, 2024
ea53204
amend
vmoens Jul 29, 2024
0d6cba7
amend
vmoens Jul 29, 2024
69e98ea
amend
vmoens Jul 30, 2024
ed465bd
amend
vmoens Jul 30, 2024
d4169d4
Update intermediate_source/pinmem_nonblock.py
vmoens Jul 30, 2024
2f55eb8
Apply suggestions from code review
vmoens Jul 30, 2024
12d1b69
Update intermediate_source/pinmem_nonblock.py
vmoens Jul 30, 2024
8f4d6d7
edit index.rst
vmoens Jul 30, 2024
d3befe4
Merge remote-tracking branch 'origin/pinmem-nonblock-tuto' into pinme…
vmoens Jul 30, 2024
1dfe315
edit tensordict to() link
vmoens Jul 30, 2024
07f9932
address comments
vmoens Jul 30, 2024
5 changes: 3 additions & 2 deletions .ci/docker/requirements.txt
@@ -28,8 +28,9 @@ tensorboard
 jinja2==3.1.3
 pytorch-lightning
 torchx
-torchrl==0.3.0
-tensordict==0.3.0
+# TODO: use stable 0.5 when released
+-e git+https://github.com/pytorch/rl.git#egg=torchrl
+-e git+https://github.com/pytorch/tensordict.git#egg=tensordict
 ax-platform
 nbformat>==5.9.2
 datasets
Binary file added _static/img/pinmem/pinmem.png
Binary file added _static/img/pinmem/trace_streamed0_pinned0.png
Binary file added _static/img/pinmem/trace_streamed0_pinned1.png
Binary file added _static/img/pinmem/trace_streamed1_pinned0.png
Binary file added _static/img/pinmem/trace_streamed1_pinned1.png
34 changes: 18 additions & 16 deletions advanced_source/coding_ddpg.py
@@ -182,7 +182,7 @@
 # Later, we will see how the target parameters should be updated in TorchRL.
 #

-from tensordict.nn import TensorDictModule
+from tensordict.nn import TensorDictModule, TensorDictSequential


 def _init(
@@ -290,12 +290,11 @@ def _loss_actor(
 ) -> torch.Tensor:
     td_copy = tensordict.select(*self.actor_in_keys)
     # Get an action from the actor network: since we made it functional, we need to pass the params
-    td_copy = self.actor_network(td_copy, params=self.actor_network_params)
+    with self.actor_network_params.to_module(self.actor_network):
+        td_copy = self.actor_network(td_copy)
     # get the value associated with that action
-    td_copy = self.value_network(
-        td_copy,
-        params=self.value_network_params.detach(),
-    )
+    with self.value_network_params.detach().to_module(self.value_network):
+        td_copy = self.value_network(td_copy)
     return -td_copy.get("state_action_value")

@@ -317,7 +316,8 @@ def _loss_value(
     td_copy = tensordict.clone()

     # V(s, a)
-    self.value_network(td_copy, params=self.value_network_params)
+    with self.value_network_params.to_module(self.value_network):
+        self.value_network(td_copy)
     pred_val = td_copy.get("state_action_value").squeeze(-1)

     # we manually reconstruct the parameters of the actor-critic, where the first
@@ -332,9 +332,8 @@ def _loss_value(
         batch_size=self.target_actor_network_params.batch_size,
         device=self.target_actor_network_params.device,
     )
-    target_value = self.value_estimator.value_estimate(
-        tensordict, target_params=target_params
-    ).squeeze(-1)
+    with target_params.to_module(self.actor_critic):
+        target_value = self.value_estimator.value_estimate(tensordict).squeeze(-1)

     # Computes the value loss: L2, L1 or smooth L1 depending on `self.loss_function`
     loss_value = distance_loss(pred_val, target_value, loss_function=self.loss_function)
@@ -717,7 +716,7 @@ def get_env_stats():
     ActorCriticWrapper,
     DdpgMlpActor,
     DdpgMlpQNet,
-    OrnsteinUhlenbeckProcessWrapper,
+    OrnsteinUhlenbeckProcessModule,
     ProbabilisticActor,
     TanhDelta,
     ValueOperator,
@@ -776,15 +775,18 @@ def make_ddpg_actor(
 # Exploration
 # ~~~~~~~~~~~
 #
-# The policy is wrapped in a :class:`~torchrl.modules.OrnsteinUhlenbeckProcessWrapper`
+# The policy is passed into a :class:`~torchrl.modules.OrnsteinUhlenbeckProcessModule`
 # exploration module, as suggested in the original paper.
 # Let's define the number of frames before OU noise reaches its minimum value
 annealing_frames = 1_000_000

-actor_model_explore = OrnsteinUhlenbeckProcessWrapper(
+actor_model_explore = TensorDictSequential(
     actor,
-    annealing_num_steps=annealing_frames,
-).to(device)
+    OrnsteinUhlenbeckProcessModule(
+        spec=actor.spec.clone(),
+        annealing_num_steps=annealing_frames,
+    ).to(device),
+)
 if device == torch.device("cpu"):
     actor_model_explore.share_memory()

@@ -1168,7 +1170,7 @@ def ceil_div(x, y):
             )

     # update the exploration strategy
-    actor_model_explore.step(current_frames)
+    actor_model_explore[1].step(current_frames)

 collector.shutdown()
 del collector
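
The exploration hunks above refer to "the number of frames before OU noise reaches its minimum value" and call `step(current_frames)` on the exploration module. The idea can be sketched in plain Python as follows. The class `AnnealedOUNoise` and all of its parameter names and default values are hypothetical and chosen for illustration; this is not TorchRL's implementation.

```python
import math
import random


class AnnealedOUNoise:
    """Toy scalar Ornstein-Uhlenbeck noise whose scale is annealed linearly.

    Hypothetical helper for illustration only; not TorchRL's implementation.
    """

    def __init__(self, annealing_num_steps, eps_init=1.0, eps_end=0.1,
                 theta=0.15, mu=0.0, sigma=0.2, dt=0.01, seed=0):
        self.annealing_num_steps = annealing_num_steps
        self.eps_init = eps_init
        self.eps_end = eps_end
        self.eps = eps_init
        self.theta, self.mu, self.sigma, self.dt = theta, mu, sigma, dt
        self.state = 0.0
        self.rng = random.Random(seed)

    def sample(self):
        """Advance the OU process by one step and return the scaled noise."""
        drift = self.theta * (self.mu - self.state) * self.dt
        diffusion = self.sigma * math.sqrt(self.dt) * self.rng.gauss(0.0, 1.0)
        self.state += drift + diffusion
        return self.eps * self.state

    def step(self, frames=1):
        """Anneal eps linearly after `frames` environment frames were collected."""
        decrement = frames * (self.eps_init - self.eps_end) / self.annealing_num_steps
        self.eps = max(self.eps_end, self.eps - decrement)


noise = AnnealedOUNoise(annealing_num_steps=1_000)
noisy_action = 0.5 + noise.sample()  # perturb a (hypothetical) policy action

for _ in range(10):  # 10 batches of 50 frames: halfway through annealing
    noise.step(frames=50)
halfway_eps = noise.eps

for _ in range(20):  # well past the annealing horizon
    noise.step(frames=50)
final_eps = noise.eps
```

In this sketch the noise scale falls linearly from `eps_init` and is clamped at `eps_end` once `annealing_num_steps` frames have been collected. That mirrors why the training loop in the diff passes the number of newly collected frames to `step`.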
14 changes: 11 additions & 3 deletions en-wordlist.txt
@@ -1,3 +1,4 @@
+
 ACL
 ADI
 AOT
@@ -50,6 +51,7 @@ DDP
 DDPG
 DDQN
 DLRM
+DMA
 DNN
 DQN
 DataLoaders
@@ -68,6 +70,8 @@ Ecker
 ExportDB
 FC
 FGSM
+tensordict
+DataLoader's
 FLAVA
 FSDP
 FX
@@ -139,6 +143,7 @@ MKLDNN
 MLP
 MLPs
 MNIST
+MPS
 MUC
 MacBook
 MacOS
@@ -219,6 +224,7 @@ STR
 SVE
 SciPy
 Sequentials
+Sharding
 Sigmoid
 SoTA
 Sohn
@@ -254,6 +260,7 @@ VLDB
 VQA
 VS Code
 ViT
+Volterra
 WMT
 WSI
 WSIs
@@ -336,11 +343,11 @@ dataset’s
 deallocation
 decompositions
 decorrelated
-devicemesh
 deserialize
 deserialized
 desynchronization
 deterministically
+devicemesh
 dimensionality
 dir
 discontiguous
@@ -384,6 +391,7 @@ hessian
 hessians
 histoencoder
 histologically
+homonymous
 hotspot
 hvp
 hyperparameter
@@ -459,6 +467,7 @@ optimizer's
 optimizers
 otsu
 overfitting
+pageable
 parallelizable
 parallelization
 parametrization
@@ -522,7 +531,6 @@ runtime
 runtimes
 scalable
 sharded
-Sharding
 softmax
 sparsified
 sparsifier
@@ -609,4 +617,4 @@ warmstarting
 warmup
 webp
 wsi
-wsis
+wsis
8 changes: 8 additions & 0 deletions index.rst
@@ -93,6 +93,13 @@ Welcome to PyTorch Tutorials
    :link: intermediate/tensorboard_tutorial.html
    :tags: Interpretability,Getting-Started,TensorBoard

+.. customcarditem::
+   :header: Good usage of `non_blocking` and `pin_memory()` in PyTorch
+   :card_description: A guide on best practices to copy data from CPU to GPU.
+   :image: _static/img/pinmem.png
+   :link: intermediate/pinmem_nonblock.html
+   :tags: Getting-Started
+
 .. Image/Video

 .. customcarditem::
@@ -942,6 +949,7 @@ Additional Resources
    beginner/basics/autogradqs_tutorial
    beginner/basics/optimization_tutorial
    beginner/basics/saveloadrun_tutorial
+   intermediate/pinmem_nonblock
    advanced/custom_ops_landing_page

 .. toctree::
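
For orientation, the two features named in the new `index.rst` card can be sketched as below. This is a minimal illustration under stated assumptions, not a recommendation: the helper `copy_to_device` and the batch shapes are hypothetical, and whether pinning or `non_blocking=True` actually helps depends on hardware and workload, which is what the tutorial added by this PR measures.

```python
import torch


def copy_to_device(batch, device):
    """Copy CPU tensors to `device`, requesting asynchronous copies on CUDA.

    Hypothetical helper for illustration only.
    """
    use_cuda = device.type == "cuda"
    moved = [t.to(device, non_blocking=use_cuda) for t in batch]
    if use_cuda:
        # Wait for the queued copies before the host relies on them.
        torch.cuda.synchronize(device)
    return moved


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
batch = [torch.randn(256, 256) for _ in range(4)]

if device.type == "cuda":
    # Page-locked (pinned) host memory is what allows a host-to-device copy
    # to run asynchronously with respect to the host.
    batch = [t.pin_memory() for t in batch]

moved = copy_to_device(batch, device)
```

On a machine without CUDA the sketch degrades to ordinary blocking copies, so it runs unchanged on CPU-only setups.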
2 changes: 1 addition & 1 deletion intermediate_source/dqn_with_rnn_tutorial.py
@@ -298,7 +298,7 @@
 # either by passing a string or an action-spec. This allows us to use
 # Categorical (sometimes called "sparse") encoding or the one-hot version of it.
 #
-qval = QValueModule(action_space=env.action_spec)
+qval = QValueModule(spec=env.action_spec)

 ######################################################################
 # .. note::