Skip to content

Commit 60442e6

Browse files
TroyGardenfacebook-github-bot
authored andcommitted
all tensors in ModelInput should be on pinned memory for non-blocking device-to-host data transfer (#2985)
Summary: Pull Request resolved: #2985 # context * `KeyedJaggedTensor` has the method of `pin_memory` so there's no need to do the pin_memory manually. * The `pin_memory()` call for input KJTs are important for training. NOTE: It's recommended in the prod training scenario that `TrainModelInput` should be created on pinned memory for a fast transfer to gpu. For more on [pin_memory](https://pytorch.org/tutorials/intermediate/pinmem_nonblock.html#pin-memory). * ModelInput example ``` if pin_memory: float_features = float_features.pin_memory() label = label.pin_memory() idlist_features: Optional[KeyedJaggedTensor] = ( None if idlist_features is None else idlist_features.pin_memory() ) idscore_features: Optional[KeyedJaggedTensor] = ( None if idscore_features is None else idscore_features.pin_memory() ) return ModelInput( float_features=float_features, idlist_features=idlist_features, idscore_features=idscore_features, label=label, ) ``` WARNING: All the tensors in `TrainModelInput` should be pinned in memory, not just the KJTs. Otherwise you'll find that cpu execution is still blocked by `_to_copy` even most of the (host-to-device) data transfer is non-blocking. {F1978313151} {F1978313156} Reviewed By: tao-jia Differential Revision: D74434209 fbshipit-source-id: c7ad466b8d278044b2e2b9dd8f89489545f3060a
1 parent a95e7e1 commit 60442e6

File tree

1 file changed

+8
-15
lines changed

1 file changed

+8
-15
lines changed

torchrec/distributed/test_utils/test_input.py

Lines changed: 8 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -287,7 +287,6 @@ def generate(
287287
offsets_dtype=offsets_dtype,
288288
lengths_dtype=lengths_dtype,
289289
all_zeros=all_zeros,
290-
pin_memory=pin_memory,
291290
)
292291
if tables is not None and len(tables) > 0
293292
else None
@@ -306,7 +305,6 @@ def generate(
306305
offsets_dtype=offsets_dtype,
307306
lengths_dtype=lengths_dtype,
308307
all_zeros=all_zeros,
309-
pin_memory=pin_memory,
310308
)
311309
if weighted_tables is not None and len(weighted_tables) > 0
312310
else None
@@ -317,8 +315,16 @@ def generate(
317315
else torch.rand((batch_size,), device=device)
318316
)
319317
if pin_memory:
318+
# all tensors in `ModelInput` should be on pinned memory otherwise
319+
# the `_to_copy` (host-to-device) data transfer still blocks cpu execution
320320
float_features = float_features.pin_memory()
321321
label = label.pin_memory()
322+
idlist_features: Optional[KeyedJaggedTensor] = (
323+
None if idlist_features is None else idlist_features.pin_memory()
324+
)
325+
idscore_features: Optional[KeyedJaggedTensor] = (
326+
None if idscore_features is None else idscore_features.pin_memory()
327+
)
322328
return ModelInput(
323329
float_features=float_features,
324330
idlist_features=idlist_features,
@@ -417,18 +423,12 @@ def _assemble_kjt(
417423
device: Optional[torch.device] = None,
418424
use_offsets: bool = False,
419425
offsets_dtype: torch.dtype = torch.int64,
420-
pin_memory: bool = False,
421426
) -> KeyedJaggedTensor:
422427
"""
423428
Assembles a KeyedJaggedTensor (KJT) from the provided per-feature lengths and indices.
424429
425430
This method is used to generate corresponding local_batches and global_batch KJTs.
426431
It concatenates the lengths and indices for each feature to form a complete KJT.
427-
428-
The `pin_memory()` call for all KJT tensors are important for training benchmark, and
429-
also valid argument for the prod training scenario: TrainModelInput should be created
430-
on pinned memory for a fast transfer to gpu. For more on pin_memory:
431-
https://pytorch.org/tutorials/intermediate/pinmem_nonblock.html#pin-memory
432432
"""
433433

434434
lengths = torch.cat(lengths_per_feature)
@@ -440,11 +440,6 @@ def _assemble_kjt(
440440
[torch.tensor([0], device=device), lengths.cumsum(0)]
441441
).to(offsets_dtype)
442442
lengths = None
443-
if pin_memory:
444-
indices = indices.pin_memory()
445-
lengths = lengths.pin_memory() if lengths is not None else None
446-
weights = weights.pin_memory() if weights is not None else None
447-
offsets = offsets.pin_memory() if offsets is not None else None
448443
return KeyedJaggedTensor(features, indices, weights, lengths, offsets)
449444

450445
@staticmethod
@@ -463,7 +458,6 @@ def create_standard_kjt(
463458
offsets_dtype: torch.dtype = torch.int64,
464459
lengths_dtype: torch.dtype = torch.int64,
465460
all_zeros: bool = False,
466-
pin_memory: bool = False,
467461
) -> KeyedJaggedTensor:
468462
features, lengths_per_feature, indices_per_feature = (
469463
ModelInput._create_features_lengths_indices(
@@ -486,7 +480,6 @@ def create_standard_kjt(
486480
device=device,
487481
use_offsets=use_offsets,
488482
offsets_dtype=offsets_dtype,
489-
pin_memory=pin_memory,
490483
)
491484

492485
@staticmethod

0 commit comments

Comments
 (0)