Account for cache load factor in memory estimate #3035

Closed · wants to merge 2 commits
1 change: 1 addition & 0 deletions torchrec/distributed/planner/constants.py
@@ -30,6 +30,7 @@
# with other devices such as the FE NIC.
HBM_TO_DDR_MEM_BW: float = 32 * 1024 * 1024 * 1024 / 1000 # bytes/ms
UVM_CACHING_RATIO: float = 0.2
KV_CACHING_RATIO: float = 0.2
BATCH_SIZE: int = 512

BATCHED_COPY_PERF_FACTOR: float = 2.455 # empirical studies
20 changes: 18 additions & 2 deletions torchrec/distributed/planner/shard_estimators.py
@@ -9,6 +9,7 @@

import logging
import math
from math import ceil
from typing import cast, Dict, List, Optional, Tuple, Type

import torch
@@ -22,6 +23,7 @@
FULL_BLOCK_EMB_DIM,
HALF_BLOCK_PENALTY,
kernel_bw_lookup,
KV_CACHING_RATIO,
QUARTER_BLOCK_PENALTY,
UVM_CACHING_RATIO,
WEIGHTED_KERNEL_MULTIPLIER,
@@ -1021,6 +1023,11 @@ def estimate(
if constraints and constraints.key_value_params
else None
)
kv_cache_load_factor: float = (
sharder.fused_params.get("cache_load_factor", KV_CACHING_RATIO)
if sharder.fused_params
else KV_CACHING_RATIO
)

# hardcoded as 8 bytes
# input indices can be of int32, but in TBE they get converted to int64 anyway
@@ -1065,6 +1072,7 @@
is_inference=self._is_inference,
multipass_prefetch_max_pass=mpp_conf.num_passes if mpp_conf else None,
key_value_params=key_value_params,
kv_cache_load_factor=kv_cache_load_factor,
)
for shard, storage in zip(sharding_option.shards, shard_storages):
shard.storage = storage
@@ -1134,6 +1142,7 @@ def calculate_shard_storages(
is_inference: bool = False,
multipass_prefetch_max_pass: Optional[int] = None,
key_value_params: Optional[KeyValueParams] = None,
kv_cache_load_factor: float = KV_CACHING_RATIO,
) -> List[Storage]:
"""
Calculates estimated storage sizes for each sharded tensor, comprised of input,
@@ -1191,7 +1200,6 @@
EmbeddingComputeKernel.SSD_VIRTUAL_TABLE.value,
EmbeddingComputeKernel.DRAM_VIRTUAL_TABLE.value,
}:
# TODO(wangj): for ssd/dram kv, most likely we use absolute L1 cache size instead of caching ratio, as denominator is huge
hbm_storage = round(ddr_storage * caching_ratio)
table_cached = True

@@ -1225,7 +1233,15 @@
)

hbm_specific_sizes = [
(key_value_params.max_l1_cache_size or 0) * 1024 * 1024
min(
(key_value_params.max_l1_cache_size or 0) * 1024 * 1024,
ceil(
tensor.shape[0] # num_embeddings
* kv_cache_load_factor
* tensor.element_size() # bytes per element
* tensor.shape[1], # embedding_dim (number of columns per row)
),
)
for _ in hbm_specific_sizes
]
ddr_specific_sizes = [
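For context, here is a minimal, self-contained sketch of the estimate introduced above (the helper name and the fp32 element size in the example are illustrative assumptions, not TorchRec API): the per-shard HBM reservation for an SSD/DRAM virtual table is now the smaller of the configured max L1 cache size and the expected cached footprint, num_embeddings * cache_load_factor * element_size * embedding_dim.

import math

def kv_shard_hbm_bytes(
    num_embeddings: int,  # rows held by this shard
    embedding_dim: int,  # columns per row
    element_size: int,  # bytes per element, e.g. 4 for fp32
    cache_load_factor: float,  # expected fraction of rows resident in the L1 cache
    max_l1_cache_size_mb: int,  # configured cap, in MB
) -> int:
    # Illustrative per-shard HBM estimate for a key-value (virtual table) shard.
    cached_footprint = math.ceil(
        num_embeddings * cache_load_factor * element_size * embedding_dim
    )
    return min(max_l1_cache_size_mb * 1024 * 1024, cached_footprint)

# A 5,000-row shard with dim 64 in fp32 at the default load factor of 0.2 needs
# ~256 KB, far below a 128 MB L1 cap, so the table-derived value wins.
print(kv_shard_hbm_bytes(5_000, 64, 4, 0.2, 128))  # 256000

Previously the estimate used the max L1 cache size alone, which can over-reserve HBM for small tables; taking the min with the load-factor-scaled footprint tightens the estimate.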
182 changes: 182 additions & 0 deletions torchrec/distributed/planner/tests/test_planners.py
@@ -634,3 +634,185 @@ def test_planner_with_virtual_table(self) -> None:
self.assertTrue(
any("Min HBM: 0.256 GB on ranks [0, 1]" in line for line in stats)
)

constraints = {
**{
f"table_{i}": ParameterConstraints(
sharding_types=["row_wise"],
compute_kernels=["dram_virtual_table"],
key_value_params=KeyValueParams(
l2_cache_size=64, max_l1_cache_size=128
),
)
for i in range(table_count // 2)
},
**{
f"table_{i}": ParameterConstraints(
cache_params=CacheParams(algorithm=CacheAlgorithm.LRU),
)
for i in range(table_count // 2, table_count)
},
}

topology = Topology(
world_size=2,
hbm_cap=1024 * 1024 * 1024 * 2,
ddr_cap=1024 * 1024 * 1024 * 256,
compute_device="cuda",
)

planner = EmbeddingShardingPlanner(
topology=topology,
proposer=EmbeddingOffloadScaleupProposer(),
constraints=constraints,
)
sharding_plan = planner.plan(
module=model, sharders=[EmbeddingCollectionSharder()] # pyre-ignore
)

expected_ranks = [[0, 1], [0, 1], [0, 1], [0, 1]]
ranks = [
cast(List[int], param_shard.ranks)
for param_shard in cast(
EmbeddingModuleShardingPlan, sharding_plan.plan["sparse.ec"]
).values()
]
compute_kernels = {
param_shard.compute_kernel
for param_shard in cast(
EmbeddingModuleShardingPlan, sharding_plan.plan["sparse.ec"]
).values()
}
self.assertEqual(sorted(expected_ranks), sorted(ranks))
self.assertSetEqual(
{
EmbeddingComputeKernel.DRAM_VIRTUAL_TABLE.value,
EmbeddingComputeKernel.FUSED_UVM_CACHING.value,
},
compute_kernels,
)

tables = [
EmbeddingConfig(
num_embeddings=10000,
embedding_dim=64,
name="table_" + str(i),
feature_names=["feature_" + str(i)],
use_virtual_table=True,
total_num_buckets=10,
)
for i in range(table_count // 2)
] + [
EmbeddingConfig(
num_embeddings=100_000,
embedding_dim=64,
name="table_" + str(i),
feature_names=["feature_" + str(i)],
)
for i in range(table_count // 2, table_count)
]

model = TestSparseNN(tables=tables, sparse_device=torch.device("meta"))

planner = EmbeddingShardingPlanner(
topology=topology,
proposer=EmbeddingOffloadScaleupProposer(),
constraints=constraints,
)

# Max L1 cache size (128 MB) exceeds embedding table size * default cache load factor, so the smaller, table-derived estimate is used.

sharding_plan = planner.plan(
module=model, sharders=[EmbeddingCollectionSharder()] # pyre-ignore
)
for table_index in range(4):
shards = sharding_plan.plan["sparse.ec"][
f"table_{table_index}"
].sharding_spec.shards
self.assertEqual(len(shards), 2)
self.assertEqual(shards[0].shard_offsets, [0, 0])
self.assertEqual(
shards[0].shard_sizes,
[5000 if table_index < 2 else 50_000, 64],
)
self.assertEqual(
shards[1].shard_offsets,
[5000 if table_index < 2 else 50_000, 0],
)
self.assertEqual(
shards[1].shard_sizes,
[5000 if table_index < 2 else 50_000, 64],
)
stats: List[str] = cast(EmbeddingStats, planner._stats[0])._stats_table
# Max L1 cache size (128 MB) > size of embedding table * cache load factor, so the smaller, table-derived value is used.
# L2 cache size is 64 GB per shard per table
self.assertTrue(
any(
"dram_virtual_table: HBM: 0.002 GB, DDR: 256.0 GB" in line
for line in stats
)
)
self.assertTrue(
any(
"fused_uvm_caching: HBM: 0.011 GB, DDR: 0.048 GB" in line
for line in stats
)
)
self.assertTrue(
any("Max HBM: 0.007 GB on ranks [0, 1]" in line for line in stats)
)
self.assertTrue(
any("Min HBM: 0.007 GB on ranks [0, 1]" in line for line in stats)
)

# Override cache load factor
planner = EmbeddingShardingPlanner(
topology=topology,
proposer=EmbeddingOffloadScaleupProposer(),
constraints=constraints,
)
sharding_plan = planner.plan(
module=model,
sharders=[ # pyre-ignore
EmbeddingCollectionSharder(fused_params={"cache_load_factor": 0.5})
],
)
for table_index in range(4):
shards = sharding_plan.plan["sparse.ec"][
f"table_{table_index}"
].sharding_spec.shards
self.assertEqual(len(shards), 2)
self.assertEqual(shards[0].shard_offsets, [0, 0])
self.assertEqual(
shards[0].shard_sizes,
[5000 if table_index < 2 else 50_000, 64],
)
self.assertEqual(
shards[1].shard_offsets,
[5000 if table_index < 2 else 50_000, 0],
)
self.assertEqual(
shards[1].shard_sizes,
[5000 if table_index < 2 else 50_000, 64],
)
stats: List[str] = cast(EmbeddingStats, planner._stats[0])._stats_table
# Max L1 cache size (128 MB) > size of embedding table * cache load factor, so the smaller, table-derived value is used.
# L2 cache size is 64 GB per shard per table
self.assertTrue(
any(
"dram_virtual_table: HBM: 0.005 GB, DDR: 256.0 GB" in line
for line in stats
)
)
self.assertTrue(
any(
"fused_uvm_caching: HBM: 0.027 GB, DDR: 0.048 GB" in line
for line in stats
)
)
self.assertTrue(
any("Max HBM: 0.016 GB on ranks [0, 1]" in line for line in stats)
)
self.assertTrue(
any("Min HBM: 0.016 GB on ranks [0, 1]" in line for line in stats)
)
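For orientation on the figures asserted above, here is a rough back-of-the-envelope (an illustrative sketch under stated assumptions: 5,000-row virtual-table shards, fp32 elements, and reported totals that additionally include other per-shard storage such as input/output buffers): overriding cache_load_factor through the sharder's fused_params scales the cached-row estimate proportionally, which is why both the dram_virtual_table and fused_uvm_caching HBM numbers grow between the two sets of assertions.

# Illustrative arithmetic only, not planner code.
rows_per_shard, dim, elem_bytes = 5_000, 64, 4  # assumed fp32 virtual-table shards
for clf in (0.2, 0.5):
    cached_bytes = rows_per_shard * clf * elem_bytes * dim
    print(f"cache_load_factor={clf}: {cached_bytes:,.0f} bytes per shard")
# -> 256,000 bytes per shard at 0.2 and 640,000 bytes at 0.5; the stats above
# report larger totals because they also count other per-shard storage.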
2 changes: 1 addition & 1 deletion torchrec/optim/clipping.py
@@ -124,7 +124,7 @@ def step(self, closure: Any = None) -> None:
torch.nn.utils.clip_grad_norm_(
replicate_params,
self._max_gradient,
norm_type=self._norm_type,
norm_type=float(self._norm_type),
)
else:
self.clip_grad_norm_()