diff --git a/torchrec/distributed/itep_embeddingbag.py b/torchrec/distributed/itep_embeddingbag.py
index 193a5c93b3..7250077a4b 100644
--- a/torchrec/distributed/itep_embeddingbag.py
+++ b/torchrec/distributed/itep_embeddingbag.py
@@ -134,7 +134,6 @@ def __init__(
             pruning_interval=module._itep_module.pruning_interval,
             enable_pruning=module._itep_module.enable_pruning,
             pg=env.process_group,
-            itep_logger=module._itep_module.itep_logger,
         )
 
     def prefetch(
diff --git a/torchrec/modules/itep_modules.py b/torchrec/modules/itep_modules.py
index b743346274..a7f87a0093 100644
--- a/torchrec/modules/itep_modules.py
+++ b/torchrec/modules/itep_modules.py
@@ -21,7 +21,7 @@
 from torchrec.distributed.embedding_types import ShardedEmbeddingTable, ShardingType
 from torchrec.distributed.types import Shard, ShardedTensor, ShardedTensorMetadata
 from torchrec.modules.embedding_modules import reorder_inverse_indices
-from torchrec.modules.itep_logger import ITEPLogger, ITEPLoggerDefault
+from torchrec.modules.pruning_logger import PruningLogger, PruningLoggerDefault
 from torchrec.sparse.jagged_tensor import _pin_and_move, _to_offsets, KeyedJaggedTensor
 
 
@@ -73,7 +73,7 @@ def __init__(
         pruning_interval: int = 1001,  # Default pruning interval 1001 iterations
         pg: Optional[dist.ProcessGroup] = None,
         table_name_to_sharding_type: Optional[Dict[str, str]] = None,
-        itep_logger: Optional[ITEPLogger] = None,
+        scuba_logger: Optional[PruningLogger] = None,
     ) -> None:
         super(GenericITEPModule, self).__init__()
 
@@ -90,10 +90,10 @@ def __init__(
         )
         self.table_name_to_sharding_type = table_name_to_sharding_type
 
-        self.itep_logger: ITEPLogger = (
-            itep_logger if itep_logger is not None else ITEPLoggerDefault()
+        self.scuba_logger: PruningLogger = (
+            scuba_logger if scuba_logger is not None else PruningLoggerDefault()
         )
-        self.itep_logger.log_run_info()
+        self.scuba_logger.log_run_info()
 
         # Map each feature to a physical address_lookup/row_util buffer
         self.feature_table_map: Dict[str, int] = {}
@@ -168,13 +168,6 @@ def print_itep_eviction_stats(
             # in dummy mode, we don't have the feature_table_map or reversed_feature_table_map
             pass
 
-        self.itep_logger.log_table_eviction_info(
-            iteration=None,
-            rank=None,
-            table_to_sizes_mapping=table_to_sizes_mapping,
-            eviction_tables=logged_eviction_mapping,
-        )
-
         # Print the sorted mapping
         logger.info(f"ITEP: table name to eviction ratio {sorted_mapping}")
diff --git a/torchrec/modules/itep_logger.py b/torchrec/modules/pruning_logger.py
similarity index 92%
rename from torchrec/modules/itep_logger.py
rename to torchrec/modules/pruning_logger.py
index fa729488a9..90d6a6b10c 100644
--- a/torchrec/modules/itep_logger.py
+++ b/torchrec/modules/pruning_logger.py
@@ -12,7 +12,7 @@
 logger: logging.Logger = logging.getLogger(__name__)
 
 
-class ITEPLogger(ABC):
+class PruningLogger(ABC):
     @abstractmethod
     def log_table_eviction_info(
         self,
@@ -30,7 +30,7 @@ def log_run_info(
         pass
 
 
-class ITEPLoggerDefault(ITEPLogger):
+class PruningLoggerDefault(PruningLogger):
     """
     noop logger as a default
     """
@@ -39,7 +39,7 @@ def __init__(
         self,
     ) -> None:
         """
-        Initialize ITEPLoggerScuba.
+        Initialize PruningScubaLogger.
         """
         pass
 
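For context, a minimal sketch of how the renamed interface might be consumed after this change. The `StdoutPruningLogger` subclass below is hypothetical and not part of the diff; the parameter types on `log_table_eviction_info` are inferred from the call site removed in `print_itep_eviction_stats`, so they may not match the abstract signature exactly:

```python
# Hypothetical usage sketch, not part of this diff. Assumes the renamed
# module path introduced above; the log_table_eviction_info parameter
# types are inferred from the removed call site and may differ slightly.
from typing import Mapping, Optional, Tuple

from torchrec.modules.pruning_logger import PruningLogger, PruningLoggerDefault


class StdoutPruningLogger(PruningLogger):
    """Hypothetical subclass that prints instead of logging to Scuba."""

    def log_table_eviction_info(
        self,
        iteration: Optional[int],
        rank: Optional[int],
        table_to_sizes_mapping: Mapping[str, Tuple[int, int]],
        eviction_tables: Mapping[str, float],
    ) -> None:
        print(f"iteration={iteration} rank={rank} evictions={eviction_tables}")

    def log_run_info(self) -> None:
        print("pruning run started")


# GenericITEPModule's keyword argument is now `scuba_logger`; leaving it
# as None falls back to the no-op PruningLoggerDefault:
#   GenericITEPModule(..., scuba_logger=StdoutPruningLogger())
```

Note that after this diff only `log_run_info` is still invoked from `GenericITEPModule`: the `log_table_eviction_info` call in `print_itep_eviction_stats` is removed, and the sharded wrapper in `itep_embeddingbag.py` no longer forwards a logger instance.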