
Commit 4e825b1

q10 authored and facebook-github-bot committed
Update the rowwise adagrad optimizer to leverage optimizer state offloading, v4, frontend (#4249)
Summary: X-link: facebookresearch/FBGEMM#1328 - Follow-up to D75329024: plumb the flag for optimizer state offloading through to the TBE frontend. Differential Revision: D75336208
1 parent ee920a6 commit 4e825b1
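
This change exposes enable_optimizer_offloading on the SSD TBE frontend and threads it through to the generated rowwise Adagrad lookup. A minimal usage sketch, assuming the SSDTableBatchedEmbeddingBags constructor accepts the flag; the import path and all other constructor arguments shown here are illustrative, not taken from this diff:

```python
# Hedged sketch: arguments other than enable_optimizer_offloading are
# illustrative and may not match the real SSDTableBatchedEmbeddingBags API.
from fbgemm_gpu.tbe.ssd import SSDTableBatchedEmbeddingBags

tbe = SSDTableBatchedEmbeddingBags(
    embedding_specs=[(10_000, 128)],        # (rows, dim) per table
    cache_sets=1024,                        # illustrative cache sizing
    ssd_storage_directory="/tmp/ssd_tbe",   # illustrative path
    # New frontend flag from this stack: store the rowwise Adagrad optimizer
    # state at the end of each row in the SSD cache (backend added in D75329024).
    enable_optimizer_offloading=True,
)
```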

File tree

5 files changed: +42 −20 lines

fbgemm_gpu/codegen/genscript/generate_backward_split.py

Lines changed: 31 additions & 7 deletions
@@ -10,6 +10,7 @@

 import itertools
 import sys
+from copy import deepcopy
 from typing import List

 try:
@@ -164,6 +165,10 @@ def generate_backward_split_gpu(**kwargs: Any) -> None:
         if not kwargs.get("dense"):
             # Generate CUDA autograd

+            # Extract the aux_args and ssd_aux_args for later use
+            aux_args = kwargs["aux_args"]
+            ssd_aux_args = kwargs["ssd_aux_args"]
+
             for ssd in [True, False] if kwargs.get("has_ssd_support") else [False]:
                 template_filepath = (
                     "training/backward/embedding_backward_split_host_template.cpp"
@@ -195,6 +200,10 @@ def generate_backward_split_gpu(**kwargs: Any) -> None:
                 )

             if kwargs.get("has_cpu_support") or kwargs.get("has_gpu_support"):
+                # Since the template file only uses aux_args, reset the key
+                # based on whether we are generating for the SSD variant or not
+                kwargs["aux_args"] = ssd_aux_args if ssd else aux_args
+
                 # Generates Python invoker for CUDA + CPU, and PT2
                 template = CodeTemplate.load(
                     "training/python/split_embedding_codegen_lookup_invoker.template"
@@ -433,28 +442,44 @@ def generate() -> None:
             "mixed_D",  # 6
         ],
     }
-    # ssd-specific argument
+
+    # SSD-specific arguments
     ssd_aux_bool = [
+        # When set to true, the per-row optimizer state will be offloaded to
+        # the end of each row in the SSD cache.
         "enable_optimizer_offloading",  # 7
     ]
+
     assert (
         list(aux_args.keys()) == aux_names
     ), f"{aux_names} must match {aux_args.keys()}"

+    ssd_aux_args = deepcopy(aux_args)
+    ssd_aux_args["aux_bool"].extend(ssd_aux_bool)
+
     all_optimizers = []
     ssd_optimizers = []

     for optimizer in optimizers:
         optim = optimizer["optimizer"]
+
         if (
             optimizer["has_cpu_support"] or optimizer["has_gpu_support"]
         ) and optim != "dense":
             all_optimizers.append(optim)
             if optimizer["has_ssd_support"]:
                 ssd_optimizers.append(optim)
+
         BackwardSplitGenerator.generate_backward_split(
-            ssd_tensors=ssd_tensors, aux_args=aux_args, **optimizer
+            ssd_tensors=ssd_tensors,
+            # Both aux_args and ssd_aux_args will be passed in, since
+            # generate_backward_split will generate both SSD and non-SSD
+            # variants
+            aux_args=aux_args,
+            ssd_aux_args=ssd_aux_args,
+            **optimizer,
         )
+
     BackwardSplitGenerator.generate_rocm_backward_split()

     # Generate common device kernels for backwards
@@ -465,11 +490,10 @@ def generate() -> None:
     BackwardSplitGenerator.generate_backward_indices()

     # Generate headers for backwards
-    BackwardSplitGenerator.generate_backward_header(aux_args, aux_names)
-    aux_args["aux_bool"].extend(ssd_aux_bool)
-    BackwardSplitGenerator.generate_backward_header(
-        aux_args, aux_names, is_ssd=True
-    )
+    for is_ssd in [True, False]:
+        BackwardSplitGenerator.generate_backward_header(
+            (ssd_aux_args if is_ssd else aux_args), aux_names, is_ssd=is_ssd
+        )

     BackwardSplitGenerator.generate_python_sources(all_optimizers, ssd_optimizers)
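
A standalone sketch of the copy-then-extend pattern introduced above: deep-copying aux_args keeps the non-SSD argument table untouched while the SSD variant gains the extra flag (the previous code mutated aux_args in place before generating the SSD header). The dictionary contents here are illustrative, not the full codegen argument table:

```python
from copy import deepcopy

# Illustrative subset of the codegen argument table; the real aux_args has more keys.
aux_args = {
    "aux_bool": ["use_uniq_cache_locations_bwd", "use_homogeneous_placements", "mixed_D"],
}
ssd_aux_bool = ["enable_optimizer_offloading"]

# Deep copy so the non-SSD table is not mutated when the SSD-only flag is added.
ssd_aux_args = deepcopy(aux_args)
ssd_aux_args["aux_bool"].extend(ssd_aux_bool)

assert "enable_optimizer_offloading" not in aux_args["aux_bool"]
assert ssd_aux_args["aux_bool"][-1] == "enable_optimizer_offloading"
```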

fbgemm_gpu/codegen/training/python/lookup_args.template

Lines changed: 1 addition & 0 deletions
@@ -48,6 +48,7 @@ class CommonArgs(NamedTuple):
     use_homogeneous_placements: bool
     {%- if ssd %}
     ssd_tensors: Dict[str, torch.Tensor]
+    enable_optimizer_offloading: bool
     {%- endif %}
     learning_rate_tensor: torch.Tensor
     info_B_num_bits: int
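
For the SSD variant (ssd=True), the template above renders to a NamedTuple carrying the new field. A sketch of the rendered result, with fields not shown in the diff hunk omitted:

```python
# Sketch of the rendered CommonArgs for ssd=True; not the generated file verbatim.
from typing import Dict, NamedTuple

import torch


class CommonArgs(NamedTuple):
    use_homogeneous_placements: bool
    ssd_tensors: Dict[str, torch.Tensor]
    enable_optimizer_offloading: bool  # new field, emitted only when ssd=True
    learning_rate_tensor: torch.Tensor
    info_B_num_bits: int
```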

fbgemm_gpu/codegen/training/python/split_embedding_codegen_lookup_invoker.template

Lines changed: 3 additions & 5 deletions
@@ -92,6 +92,7 @@ torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops_cpu")
             "Please check the frontend and backend version. "
         )
     {{ arg_type }}.append(dict_{{ arg_type }}["{{ var }}"])
+
 {%- endfor %}
{%- endmacro %}

@@ -203,12 +204,9 @@ def invoke(
         "use_uniq_cache_locations_bwd": common_args.use_uniq_cache_locations_bwd,
         "use_homogeneous_placements": common_args.use_homogeneous_placements,
         "apply_global_weight_decay": apply_global_weight_decay,
-        {%- if not ssd %}
-        "mixed_D": mixed_D
-        {%- else %}
         "mixed_D": mixed_D,
-        # TODO: Update this when frontend is ready to land
-        "enable_optimizer_offloading": False
+        {%- if ssd %}
+        "enable_optimizer_offloading": common_args.enable_optimizer_offloading,
         {%- endif %}
     }
     dict_optim_int: Dict[str, int] = {}
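
The net effect in the rendered SSD invoker is that the flag is no longer hard-coded to False but read from common_args. A standalone sketch of that behavior; the class and function names below are hypothetical stand-ins, not the generated code:

```python
from typing import Dict, NamedTuple


class _CommonArgsSketch(NamedTuple):
    # Hypothetical stand-in for the generated CommonArgs; real one has many more fields.
    use_homogeneous_placements: bool
    enable_optimizer_offloading: bool


def build_aux_bool(common_args: _CommonArgsSketch, mixed_D: bool) -> Dict[str, bool]:
    # Mirrors the dict built by the rendered invoker: the offloading flag now
    # comes from common_args instead of a hard-coded False.
    return {
        "use_homogeneous_placements": common_args.use_homogeneous_placements,
        "mixed_D": mixed_D,
        "enable_optimizer_offloading": common_args.enable_optimizer_offloading,
    }


print(build_aux_bool(_CommonArgsSketch(True, enable_optimizer_offloading=True), mixed_D=False))
```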

fbgemm_gpu/fbgemm_gpu/tbe/ssd/training.py

Lines changed: 1 addition & 0 deletions
@@ -1828,6 +1828,7 @@ def forward(
                 "post_bwd_evicted_indices": post_bwd_evicted_indices_cpu,
                 "actions_count": actions_count_cpu,
             },
+            enable_optimizer_offloading=self.enable_optimizer_offloading,
             # pyre-fixme[6]: Expected `lookup_args_ssd.VBEMetadata` but got `lookup_args.VBEMetadata`
             vbe_metadata=vbe_metadata,
             learning_rate_tensor=self.learning_rate_tensor,
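
End to end, the flag set on the module at construction time is stored as an attribute and forwarded into CommonArgs on every forward call, where the generated invoker (previous file) reads it. A toy sketch of that flow, not the actual module:

```python
class _SSDModuleSketch:
    # Toy stand-in for SSDTableBatchedEmbeddingBags, showing only the flag flow.
    def __init__(self, enable_optimizer_offloading: bool = False) -> None:
        self.enable_optimizer_offloading = enable_optimizer_offloading

    def common_args_kwargs(self) -> dict:
        # Mirrors the new keyword argument added to forward() above.
        return {"enable_optimizer_offloading": self.enable_optimizer_offloading}


assert _SSDModuleSketch(True).common_args_kwargs()["enable_optimizer_offloading"]
```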

fbgemm_gpu/test/tbe/ssd/ssd_split_tbe_training_test.py

Lines changed: 6 additions & 8 deletions
@@ -1595,9 +1595,7 @@ def test_kv_db_forward(
     @given(
         **default_st,
         num_buckets=st.integers(min_value=10, max_value=15),
-        opt_offloading=st.just(
-            False
-        ),  # make it st.booleans when Benson's opt offloading diff is landed
+        enable_optimizer_offloading=st.booleans(),
         backend_type=st.sampled_from([BackendType.SSD, BackendType.DRAM]),
     )
     @settings(verbosity=Verbosity.verbose, max_examples=MAX_EXAMPLES, deadline=None)
@@ -1617,7 +1615,7 @@ def test_kv_emb_state_dict(
         trigger_bounds_check: bool,
         mixed_B: bool,
         num_buckets: int,
-        opt_offloading: bool,
+        enable_optimizer_offloading: bool,
         backend_type: BackendType,
     ) -> None:
         # Constants
@@ -1653,7 +1651,7 @@ def test_kv_emb_state_dict(
             output_dtype=output_dtype,
             share_table=share_table,
             num_buckets=num_buckets,
-            enable_optimizer_offloading=opt_offloading,
+            enable_optimizer_offloading=enable_optimizer_offloading,
             backend_type=backend_type,
         )

@@ -1791,8 +1789,6 @@ def test_kv_emb_state_dict(
         self.assertLess(table_index, len(emb_state_dict_list))
         assert len(split_optimizer_states[table_index]) == num_ids
         opt = split_optimizer_states[table_index]
-        if opt_offloading:
-            opt = opt[bucket_asc_ids_list[table_index].view(-1)]
         new_ref_weight = torch.addcdiv(
             emb_r_w.float(),
             value=-lr,
@@ -1822,6 +1818,7 @@ def test_kv_emb_state_dict(
     @given(
         **default_st,
         num_buckets=st.integers(min_value=10, max_value=15),
+        enable_optimizer_offloading=st.booleans(),
     )
     @settings(verbosity=Verbosity.verbose, max_examples=MAX_EXAMPLES, deadline=None)
     def test_kv_opt_state_w_offloading(
@@ -1840,6 +1837,7 @@ def test_kv_opt_state_w_offloading(
         trigger_bounds_check: bool,
         mixed_B: bool,
         num_buckets: int,
+        enable_optimizer_offloading: bool,
     ) -> None:
         # Constants
         lr = 0.5
@@ -1875,7 +1873,7 @@ def test_kv_opt_state_w_offloading(
             output_dtype=output_dtype,
             share_table=share_table,
             num_buckets=num_buckets,
-            enable_optimizer_offloading=False,
+            enable_optimizer_offloading=enable_optimizer_offloading,
         )

         # Generate inputs
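
The tests now draw the flag from st.booleans() so both offloading modes are exercised by hypothesis instead of being pinned to False. A standalone toy example of the same parameterization pattern, not the FBGEMM test itself:

```python
from hypothesis import given, settings, strategies as st


@given(enable_optimizer_offloading=st.booleans())
@settings(deadline=None, max_examples=10)
def test_offloading_flag_is_plumbed_through(enable_optimizer_offloading: bool) -> None:
    # Toy stand-in: build a config the way the real test builds the TBE module,
    # then check the flag round-trips unchanged.
    config = {"enable_optimizer_offloading": enable_optimizer_offloading}
    assert config["enable_optimizer_offloading"] is enable_optimizer_offloading


if __name__ == "__main__":
    test_offloading_flag_is_plumbed_through()
```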
