Skip to content

Commit 67ebc8c

Browse files
Jasper Shanfacebook-github-bot
Jasper Shan
authored andcommitted
Refactoring ITEP / PTP Pruning Scuba Logger [1/N] (#2986)
Summary: Pull Request resolved: #2986 Refactoring ITEPLogger -> PruningLogger as we are making the logger more generic Torchrec diff Reviewed By: AKhazane Differential Revision: D75108237 fbshipit-source-id: 12e67f08081c664f8bc6d27340ec8a808591c893
1 parent 77f082d commit 67ebc8c

File tree

3 files changed

+8
-16
lines changed

3 files changed

+8
-16
lines changed

torchrec/distributed/itep_embeddingbag.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -134,7 +134,6 @@ def __init__(
134134
pruning_interval=module._itep_module.pruning_interval,
135135
enable_pruning=module._itep_module.enable_pruning,
136136
pg=env.process_group,
137-
itep_logger=module._itep_module.itep_logger,
138137
)
139138

140139
def prefetch(

torchrec/modules/itep_modules.py

Lines changed: 5 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
from torchrec.distributed.embedding_types import ShardedEmbeddingTable, ShardingType
2222
from torchrec.distributed.types import Shard, ShardedTensor, ShardedTensorMetadata
2323
from torchrec.modules.embedding_modules import reorder_inverse_indices
24-
from torchrec.modules.itep_logger import ITEPLogger, ITEPLoggerDefault
24+
from torchrec.modules.pruning_logger import PruningLogger, PruningLoggerDefault
2525

2626
from torchrec.sparse.jagged_tensor import _pin_and_move, _to_offsets, KeyedJaggedTensor
2727

@@ -73,7 +73,7 @@ def __init__(
7373
pruning_interval: int = 1001, # Default pruning interval 1001 iterations
7474
pg: Optional[dist.ProcessGroup] = None,
7575
table_name_to_sharding_type: Optional[Dict[str, str]] = None,
76-
itep_logger: Optional[ITEPLogger] = None,
76+
scuba_logger: Optional[PruningLogger] = None,
7777
) -> None:
7878
super(GenericITEPModule, self).__init__()
7979

@@ -90,10 +90,10 @@ def __init__(
9090
)
9191
self.table_name_to_sharding_type = table_name_to_sharding_type
9292

93-
self.itep_logger: ITEPLogger = (
94-
itep_logger if itep_logger is not None else ITEPLoggerDefault()
93+
self.scuba_logger: PruningLogger = (
94+
scuba_logger if scuba_logger is not None else PruningLoggerDefault()
9595
)
96-
self.itep_logger.log_run_info()
96+
self.scuba_logger.log_run_info()
9797

9898
# Map each feature to a physical address_lookup/row_util buffer
9999
self.feature_table_map: Dict[str, int] = {}
@@ -168,13 +168,6 @@ def print_itep_eviction_stats(
168168
# in dummy mode, we don't have the feature_table_map or reversed_feature_table_map
169169
pass
170170

171-
self.itep_logger.log_table_eviction_info(
172-
iteration=None,
173-
rank=None,
174-
table_to_sizes_mapping=table_to_sizes_mapping,
175-
eviction_tables=logged_eviction_mapping,
176-
)
177-
178171
# Print the sorted mapping
179172
logger.info(f"ITEP: table name to eviction ratio {sorted_mapping}")
180173

torchrec/modules/itep_logger.py renamed to torchrec/modules/pruning_logger.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
logger: logging.Logger = logging.getLogger(__name__)
1313

1414

15-
class ITEPLogger(ABC):
15+
class PruningLogger(ABC):
1616
@abstractmethod
1717
def log_table_eviction_info(
1818
self,
@@ -30,7 +30,7 @@ def log_run_info(
3030
pass
3131

3232

33-
class ITEPLoggerDefault(ITEPLogger):
33+
class PruningLoggerDefault(PruningLogger):
3434
"""
3535
noop logger as a default
3636
"""
@@ -39,7 +39,7 @@ def __init__(
3939
self,
4040
) -> None:
4141
"""
42-
Initialize ITEPLoggerScuba.
42+
Initialize PruningScubaLogger.
4343
"""
4444
pass
4545

0 commit comments

Comments
 (0)