16 | 16 | from torch import nn
17 | 17 | from torch.nn.modules.module import _IncompatibleKeys
18 | 18 | from torch.nn.parallel import DistributedDataParallel
| 19 | +from torchrec.distributed.embedding import (
| 20 | +    EmbeddingCollectionContext,
| 21 | +    EmbeddingCollectionSharder,
| 22 | +    ShardedEmbeddingCollection,
| 23 | +)
19 | 24 |
20 | 25 | from torchrec.distributed.embedding_types import (
21 | 26 |     BaseEmbeddingSharder,
36 | 41 |     ShardingType,
37 | 42 | )
38 | 43 | from torchrec.distributed.utils import filter_state_dict
39 | | -from torchrec.modules.itep_embedding_modules import ITEPEmbeddingBagCollection
| 44 | +from torchrec.modules.itep_embedding_modules import (
| 45 | +    ITEPEmbeddingBagCollection,
| 46 | +    ITEPEmbeddingCollection,
| 47 | +)
40 | 48 | from torchrec.modules.itep_modules import GenericITEPModule, RowwiseShardedITEPModule
41 | | -from torchrec.sparse.jagged_tensor import KeyedJaggedTensor, KeyedTensor
| 49 | +from torchrec.sparse.jagged_tensor import JaggedTensor, KeyedJaggedTensor, KeyedTensor
42 | 50 |
43 | 51 |
44 | 52 | @dataclass
@@ -314,3 +322,248 @@ def module_type(self) -> Type[ITEPEmbeddingBagCollection]:
314 | 322 |     def sharding_types(self, compute_device_type: str) -> List[str]:
315 | 323 |         types = list(SHARDING_TYPE_TO_GROUP.keys())
316 | 324 |         return types
| 325 | +
| 326 | +
| 327 | +class ITEPEmbeddingCollectionContext(EmbeddingCollectionContext):
| 328 | +
| 329 | +    def __init__(self) -> None:
| 330 | +        super().__init__()
| 331 | +        self.is_reindexed: bool = False
| 332 | +        self.table_name_to_unpruned_hash_sizes: Dict[str, int] = {}
| 333 | +
| 334 | +
| 335 | +class ShardedITEPEmbeddingCollection(
| 336 | +    ShardedEmbeddingModule[
| 337 | +        KJTList,
| 338 | +        List[torch.Tensor],
| 339 | +        Dict[str, JaggedTensor],
| 340 | +        ITEPEmbeddingCollectionContext,
| 341 | +    ]
| 342 | +):
| 343 | +    def __init__(
| 344 | +        self,
| 345 | +        module: ITEPEmbeddingCollection,
| 346 | +        table_name_to_parameter_sharding: Dict[str, ParameterSharding],
| 347 | +        ebc_sharder: EmbeddingCollectionSharder,
| 348 | +        env: ShardingEnv,
| 349 | +        device: torch.device,
| 350 | +    ) -> None:
| 351 | +        super().__init__()
| 352 | +
| 353 | +        self._device = device
| 354 | +        self._env = env
| 355 | +        self.table_name_to_unpruned_hash_sizes: Dict[str, int] = (
| 356 | +            module._itep_module.table_name_to_unpruned_hash_sizes
| 357 | +        )
| 358 | +
| 359 | +        # Iteration counter for the ITEP module. Pinned on CPU because it is used for condition checking and checkpointing.
| 360 | +        self.register_buffer(
| 361 | +            "_iter", torch.tensor(0, dtype=torch.int64, device=torch.device("cpu"))
| 362 | +        )
| 363 | +
| 364 | +        self._embedding_collection: ShardedEmbeddingCollection = ebc_sharder.shard(
| 365 | +            module._embedding_collection,
| 366 | +            table_name_to_parameter_sharding,
| 367 | +            env=env,
| 368 | +            device=device,
| 369 | +        )
| 370 | +
| 371 | +        self.table_name_to_sharding_type: Dict[str, str] = {}
| 372 | +        for table_name in table_name_to_parameter_sharding.keys():
| 373 | +            self.table_name_to_sharding_type[table_name] = (
| 374 | +                table_name_to_parameter_sharding[table_name].sharding_type
| 375 | +            )
| 376 | +
| 377 | +        # Group lookups and table_name_to_unpruned_hash_sizes by sharding type and pass them to separate ITEP modules.
| 378 | +        (grouped_lookups, grouped_table_unpruned_size_map) = (
| 379 | +            self._group_lookups_and_table_unpruned_size_map(
| 380 | +                module._itep_module.table_name_to_unpruned_hash_sizes,
| 381 | +            )
| 382 | +        )
| 383 | +
| 384 | +        # Instantiate ITEP modules for the sharded case, re-using metadata from the non-sharded case.
| 385 | +        self._itep_module: GenericITEPModule = GenericITEPModule(
| 386 | +            table_name_to_unpruned_hash_sizes=grouped_table_unpruned_size_map[
| 387 | +                ShardingTypeGroup.CW_GROUP
| 388 | +            ],
| 389 | +            lookups=grouped_lookups[ShardingTypeGroup.CW_GROUP],
| 390 | +            pruning_interval=module._itep_module.pruning_interval,
| 391 | +            enable_pruning=module._itep_module.enable_pruning,
| 392 | +        )
| 393 | +        self._rowwise_itep_module: RowwiseShardedITEPModule = RowwiseShardedITEPModule(
| 394 | +            table_name_to_unpruned_hash_sizes=grouped_table_unpruned_size_map[
| 395 | +                ShardingTypeGroup.RW_GROUP
| 396 | +            ],
| 397 | +            lookups=grouped_lookups[ShardingTypeGroup.RW_GROUP],
| 398 | +            pruning_interval=module._itep_module.pruning_interval,
| 399 | +            table_name_to_sharding_type=self.table_name_to_sharding_type,
| 400 | +            enable_pruning=module._itep_module.enable_pruning,
| 401 | +        )
| 402 | +
| 403 | +    # pyre-ignore
| 404 | +    def input_dist(
| 405 | +        self,
| 406 | +        ctx: ITEPEmbeddingCollectionContext,
| 407 | +        features: KeyedJaggedTensor,
| 408 | +        force_insert: bool = False,
| 409 | +    ) -> Awaitable[Awaitable[KJTList]]:
| 410 | +
| 411 | +        ctx.table_name_to_unpruned_hash_sizes = self.table_name_to_unpruned_hash_sizes
| 412 | +        return self._embedding_collection.input_dist(ctx, features)
| 413 | +
| 414 | +    def compute(
| 415 | +        self,
| 416 | +        ctx: ITEPEmbeddingCollectionContext,
| 417 | +        dist_input: KJTList,
| 418 | +    ) -> List[torch.Tensor]:
| 419 | +        for i, (sharding, features) in enumerate(
| 420 | +            zip(
| 421 | +                self._embedding_collection._sharding_type_to_sharding.keys(),
| 422 | +                dist_input,
| 423 | +            )
| 424 | +        ):
| 425 | +            if SHARDING_TYPE_TO_GROUP[sharding] == ShardingTypeGroup.CW_GROUP:
| 426 | +                remapped_kjt = self._itep_module(features, self._iter.item())
| 427 | +            else:
| 428 | +                remapped_kjt = self._rowwise_itep_module(features, self._iter.item())
| 429 | +            dist_input[i] = remapped_kjt
| 430 | +        self._iter += 1
| 431 | +        return self._embedding_collection.compute(ctx, dist_input)
| 432 | +
| 433 | +    def output_dist(
| 434 | +        self,
| 435 | +        ctx: ITEPEmbeddingCollectionContext,
| 436 | +        output: List[torch.Tensor],
| 437 | +    ) -> LazyAwaitable[Dict[str, JaggedTensor]]:
| 438 | +
| 439 | +        ec_awaitable = self._embedding_collection.output_dist(ctx, output)
| 440 | +        return ec_awaitable
| 441 | +
| 442 | +    def compute_and_output_dist(
| 443 | +        self, ctx: ITEPEmbeddingCollectionContext, input: KJTList
| 444 | +    ) -> LazyAwaitable[Dict[str, JaggedTensor]]:
| 445 | +        # Apply the ITEP remapping (the forward() of the generic and row-wise ITEP
| 446 | +        # modules) to the input KJTs before the embedding collection's compute_and_output_dist().
| 447 | +        for i, (sharding, features) in enumerate(
| 448 | +            zip(
| 449 | +                self._embedding_collection._sharding_type_to_sharding.keys(),
| 450 | +                input,
| 451 | +            )
| 452 | +        ):
| 453 | +            if SHARDING_TYPE_TO_GROUP[sharding] == ShardingTypeGroup.CW_GROUP:
| 454 | +                remapped_kjt = self._itep_module(features, self._iter.item())
| 455 | +            else:
| 456 | +                remapped_kjt = self._rowwise_itep_module(features, self._iter.item())
| 457 | +            input[i] = remapped_kjt
| 458 | +        self._iter += 1
| 459 | +        ec_awaitable = self._embedding_collection.compute_and_output_dist(ctx, input)
| 460 | +        return ec_awaitable
| 461 | +
| 462 | +    def create_context(self) -> ITEPEmbeddingCollectionContext:
| 463 | +        return ITEPEmbeddingCollectionContext()
| 464 | +
| 465 | +    # pyre-fixme[14]: `load_state_dict` overrides method defined in `Module`
| 466 | +    #  inconsistently.
| 467 | +    def load_state_dict(
| 468 | +        self,
| 469 | +        state_dict: "OrderedDict[str, torch.Tensor]",
| 470 | +        strict: bool = True,
| 471 | +    ) -> _IncompatibleKeys:
| 472 | +        missing_keys = []
| 473 | +        unexpected_keys = []
| 474 | +        self._iter = state_dict["_iter"]
| 475 | +        for name, child_module in self._modules.items():
| 476 | +            if child_module is not None:
| 477 | +                missing, unexpected = child_module.load_state_dict(
| 478 | +                    filter_state_dict(state_dict, name),
| 479 | +                    strict,
| 480 | +                )
| 481 | +                missing_keys.extend(missing)
| 482 | +                unexpected_keys.extend(unexpected)
| 483 | +        return _IncompatibleKeys(
| 484 | +            missing_keys=missing_keys, unexpected_keys=unexpected_keys
| 485 | +        )
| 486 | +
| 487 | +    def _group_lookups_and_table_unpruned_size_map(
| 488 | +        self, table_name_to_unpruned_hash_sizes: Dict[str, int]
| 489 | +    ) -> Tuple[
| 490 | +        Dict[ShardingTypeGroup, List[nn.Module]],
| 491 | +        Dict[ShardingTypeGroup, Dict[str, int]],
| 492 | +    ]:
| 493 | +        """
| 494 | +        Groups EC lookups and table_name_to_unpruned_hash_sizes by sharding type.
| 495 | +        CW and TW are grouped into CW_GROUP; RW and TWRW are grouped into RW_GROUP.
| 496 | +
| 497 | +        Returns a tuple of (grouped_lookups, grouped_table_unpruned_size_map).
| 498 | +        """
| 499 | +        grouped_lookups: Dict[ShardingTypeGroup, List[nn.Module]] = defaultdict(list)
| 500 | +        grouped_table_unpruned_size_map: Dict[ShardingTypeGroup, Dict[str, int]] = (
| 501 | +            defaultdict(dict)
| 502 | +        )
| 503 | +        for sharding_type, lookup in zip(
| 504 | +            self._embedding_collection._sharding_types,
| 505 | +            self._embedding_collection._lookups,
| 506 | +        ):
| 507 | +            sharding_group = SHARDING_TYPE_TO_GROUP[sharding_type]
| 508 | +            # group lookups
| 509 | +            grouped_lookups[sharding_group].append(lookup)
| 510 | +            # group table_name_to_unpruned_hash_sizes
| 511 | +            while isinstance(lookup, DistributedDataParallel):
| 512 | +                lookup = lookup.module
| 513 | +            for emb_config in lookup.grouped_configs:
| 514 | +                for table in emb_config.embedding_tables:
| 515 | +                    if table.name in table_name_to_unpruned_hash_sizes.keys():
| 516 | +                        grouped_table_unpruned_size_map[sharding_group][table.name] = (
| 517 | +                            table_name_to_unpruned_hash_sizes[table.name]
| 518 | +                        )
| 519 | +
| 520 | +        return grouped_lookups, grouped_table_unpruned_size_map
| 521 | +
| 522 | +
| 523 | +class ITEPEmbeddingCollectionSharder(BaseEmbeddingSharder[ITEPEmbeddingCollection]):
| 524 | +    def __init__(
| 525 | +        self,
| 526 | +        ebc_sharder: Optional[EmbeddingCollectionSharder] = None,
| 527 | +        qcomm_codecs_registry: Optional[Dict[str, QuantizedCommCodecs]] = None,
| 528 | +    ) -> None:
| 529 | +        super().__init__(qcomm_codecs_registry=qcomm_codecs_registry)
| 530 | +        self._ebc_sharder: EmbeddingCollectionSharder = (
| 531 | +            ebc_sharder
| 532 | +            or EmbeddingCollectionSharder(
| 533 | +                qcomm_codecs_registry=self.qcomm_codecs_registry
| 534 | +            )
| 535 | +        )
| 536 | +
| 537 | +    def shard(
| 538 | +        self,
| 539 | +        module: ITEPEmbeddingCollection,
| 540 | +        params: Dict[str, ParameterSharding],
| 541 | +        env: ShardingEnv,
| 542 | +        device: Optional[torch.device] = None,
| 543 | +        module_fqn: Optional[str] = None,
| 544 | +    ) -> ShardedITEPEmbeddingCollection:
| 545 | +
| 546 | +        # Enforce GPU for ITEPEmbeddingCollection.
| 547 | +        if device is None:
| 548 | +            device = torch.device("cuda")
| 549 | +
| 550 | +        return ShardedITEPEmbeddingCollection(
| 551 | +            module,
| 552 | +            params,
| 553 | +            ebc_sharder=self._ebc_sharder,
| 554 | +            env=env,
| 555 | +            device=device,
| 556 | +        )
| 557 | +
| 558 | +    def shardable_parameters(
| 559 | +        self, module: ITEPEmbeddingCollection
| 560 | +    ) -> Dict[str, torch.nn.Parameter]:
| 561 | +        return self._ebc_sharder.shardable_parameters(module._embedding_collection)
| 562 | +
| 563 | +    @property
| 564 | +    def module_type(self) -> Type[ITEPEmbeddingCollection]:
| 565 | +        return ITEPEmbeddingCollection
| 566 | +
| 567 | +    def sharding_types(self, compute_device_type: str) -> List[str]:
| 568 | +        types = list(SHARDING_TYPE_TO_GROUP.keys())
| 569 | +        return types
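
For reference, a minimal usage sketch for the ITEPEmbeddingCollectionSharder added in this diff. The import path, process-group setup, and the `model` object below are illustrative assumptions, not part of this change:

    # Hypothetical wiring sketch: pass the new sharder to DistributedModelParallel so
    # that ITEPEmbeddingCollection submodules are sharded via ShardedITEPEmbeddingCollection.
    import torch
    import torch.distributed as dist

    from torchrec.distributed.model_parallel import DistributedModelParallel
    from torchrec.distributed.itep_embeddingbag import (  # assumed module path for this file
        ITEPEmbeddingCollectionSharder,
    )

    dist.init_process_group(backend="nccl")  # assumes a GPU/NCCL environment

    sharded_model = DistributedModelParallel(
        module=model,  # assumed: an nn.Module that contains an ITEPEmbeddingCollection
        device=torch.device("cuda"),
        sharders=[ITEPEmbeddingCollectionSharder()],
    )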