diff --git a/torchrec/metrics/metrics_config.py b/torchrec/metrics/metrics_config.py index c8325aa80..a971019ce 100644 --- a/torchrec/metrics/metrics_config.py +++ b/torchrec/metrics/metrics_config.py @@ -87,7 +87,6 @@ class RecComputeMode(Enum): FUSED_TASKS_COMPUTATION = 1 UNFUSED_TASKS_COMPUTATION = 2 - FUSED_TASKS_AND_STATES_COMPUTATION = 3 _DEFAULT_WINDOW_SIZE = 10_000_000 diff --git a/torchrec/metrics/precision_session.py b/torchrec/metrics/precision_session.py index daa4864fc..bc44fce45 100644 --- a/torchrec/metrics/precision_session.py +++ b/torchrec/metrics/precision_session.py @@ -196,10 +196,7 @@ def __init__( process_group: Optional[dist.ProcessGroup] = None, **kwargs: Any, ) -> None: - if compute_mode in [ - RecComputeMode.FUSED_TASKS_COMPUTATION, - RecComputeMode.FUSED_TASKS_AND_STATES_COMPUTATION, - ]: + if compute_mode == RecComputeMode.FUSED_TASKS_COMPUTATION: raise RecMetricException( "Fused computation is not supported for precision session-level metrics" ) diff --git a/torchrec/metrics/rec_metric.py b/torchrec/metrics/rec_metric.py index 0e9e55887..fd6a51b1a 100644 --- a/torchrec/metrics/rec_metric.py +++ b/torchrec/metrics/rec_metric.py @@ -134,19 +134,13 @@ def __init__( window_size: int, compute_on_all_ranks: bool = False, should_validate_update: bool = False, - fuse_state_tensors: bool = False, process_group: Optional[dist.ProcessGroup] = None, fused_update_limit: int = 0, allow_missing_label_with_zero_weight: bool = False, *args: Any, **kwargs: Any, ) -> None: - super().__init__( - process_group=process_group, - fuse_state_tensors=fuse_state_tensors, - *args, - **kwargs, - ) + super().__init__(process_group=process_group, *args, **kwargs) self._my_rank = my_rank self._n_tasks = n_tasks @@ -347,11 +341,7 @@ def __init__( # TODO(stellaya): consider to inherit from TorchMetrics.Metric or # TorchMetrics.MetricCollection. if ( - compute_mode - in [ - RecComputeMode.FUSED_TASKS_COMPUTATION, - RecComputeMode.FUSED_TASKS_AND_STATES_COMPUTATION, - ] + compute_mode == RecComputeMode.FUSED_TASKS_COMPUTATION and fused_update_limit > 0 ): raise ValueError( @@ -386,10 +376,7 @@ def __init__( f"Local window size must be larger than batch size. Got local window size {self._window_size} and batch size {self._batch_size}." ) - if compute_mode in [ - RecComputeMode.FUSED_TASKS_COMPUTATION, - RecComputeMode.FUSED_TASKS_AND_STATES_COMPUTATION, - ]: + if compute_mode == RecComputeMode.FUSED_TASKS_COMPUTATION: task_per_metric = len(self._tasks) self._tasks_iter = self._fused_tasks_iter else: @@ -398,11 +385,7 @@ def __init__( for task_config in ( [self._tasks] - if compute_mode - in [ - RecComputeMode.FUSED_TASKS_COMPUTATION, - RecComputeMode.FUSED_TASKS_AND_STATES_COMPUTATION, - ] + if compute_mode == RecComputeMode.FUSED_TASKS_COMPUTATION else self._tasks ): # pyre-ignore @@ -411,16 +394,13 @@ def __init__( # according to https://github.com/python/mypy/issues/3048. # pyre-fixme[45]: Cannot instantiate abstract class `RecMetricCoputation`. metric_computation = self._computation_class( - my_rank=my_rank, - batch_size=batch_size, - n_tasks=task_per_metric, - window_size=self._window_size, - compute_on_all_ranks=compute_on_all_ranks, - should_validate_update=self._should_validate_update, - fuse_state_tensors=( - compute_mode == RecComputeMode.FUSED_TASKS_AND_STATES_COMPUTATION - ), - process_group=process_group, + my_rank, + batch_size, + task_per_metric, + self._window_size, + compute_on_all_ranks, + self._should_validate_update, + process_group, **{**kwargs, **self._get_task_kwargs(task_config)}, ) required_inputs = self._get_task_required_inputs(task_config) @@ -547,10 +527,7 @@ def _update( **kwargs: Dict[str, Any], ) -> None: with torch.no_grad(): - if self._compute_mode in [ - RecComputeMode.FUSED_TASKS_COMPUTATION, - RecComputeMode.FUSED_TASKS_AND_STATES_COMPUTATION, - ]: + if self._compute_mode == RecComputeMode.FUSED_TASKS_COMPUTATION: task_names = [task.name for task in self._tasks] if not isinstance(predictions, torch.Tensor): diff --git a/torchrec/metrics/recall_session.py b/torchrec/metrics/recall_session.py index 3733e472d..9b93ec7ac 100644 --- a/torchrec/metrics/recall_session.py +++ b/torchrec/metrics/recall_session.py @@ -235,10 +235,7 @@ def __init__( process_group: Optional[dist.ProcessGroup] = None, **kwargs: Any, ) -> None: - if compute_mode in [ - RecComputeMode.FUSED_TASKS_COMPUTATION, - RecComputeMode.FUSED_TASKS_AND_STATES_COMPUTATION, - ]: + if compute_mode == RecComputeMode.FUSED_TASKS_COMPUTATION: raise RecMetricException( "Fused computation is not supported for recall session-level metrics" ) diff --git a/torchrec/metrics/test_utils/__init__.py b/torchrec/metrics/test_utils/__init__.py index 0a1085195..46b37d103 100644 --- a/torchrec/metrics/test_utils/__init__.py +++ b/torchrec/metrics/test_utils/__init__.py @@ -291,10 +291,7 @@ def get_target_rec_metric_value( labels, predictions, weights, _ = parse_task_model_outputs( tasks, model_outs[i] ) - if target_compute_mode in [ - RecComputeMode.FUSED_TASKS_COMPUTATION, - RecComputeMode.FUSED_TASKS_AND_STATES_COMPUTATION, - ]: + if target_compute_mode == RecComputeMode.FUSED_TASKS_COMPUTATION: labels = torch.stack(list(labels.values())) predictions = torch.stack(list(predictions.values())) weights = torch.stack(list(weights.values())) diff --git a/torchrec/metrics/tests/test_accuracy.py b/torchrec/metrics/tests/test_accuracy.py index fa46b3e87..0c4814825 100644 --- a/torchrec/metrics/tests/test_accuracy.py +++ b/torchrec/metrics/tests/test_accuracy.py @@ -52,7 +52,7 @@ class AccuracyMetricTest(unittest.TestCase): clazz: Type[RecMetric] = AccuracyMetric task_name: str = "accuracy" - def test_accuracy_unfused(self) -> None: + def test_unfused_accuracy(self) -> None: rec_metric_value_test_launcher( target_clazz=AccuracyMetric, target_compute_mode=RecComputeMode.UNFUSED_TASKS_COMPUTATION, @@ -66,7 +66,7 @@ def test_accuracy_unfused(self) -> None: entry_point=metric_test_helper, ) - def test_accuracy_fused_tasks(self) -> None: + def test_fused_accuracy(self) -> None: rec_metric_value_test_launcher( target_clazz=AccuracyMetric, target_compute_mode=RecComputeMode.FUSED_TASKS_COMPUTATION, @@ -80,20 +80,6 @@ def test_accuracy_fused_tasks(self) -> None: entry_point=metric_test_helper, ) - def test_accuracy_fused_tasks_and_states(self) -> None: - rec_metric_value_test_launcher( - target_clazz=AccuracyMetric, - target_compute_mode=RecComputeMode.FUSED_TASKS_AND_STATES_COMPUTATION, - test_clazz=TestAccuracyMetric, - metric_name=AccuracyMetricTest.task_name, - task_names=["t1", "t2", "t3"], - fused_update_limit=0, - compute_on_all_ranks=False, - should_validate_update=False, - world_size=WORLD_SIZE, - entry_point=metric_test_helper, - ) - class AccuracyMetricValueTest(unittest.TestCase): r"""This set of tests verify the computation logic of accuracy in several diff --git a/torchrec/metrics/tests/test_cali_free_ne.py b/torchrec/metrics/tests/test_cali_free_ne.py index 328dd7931..968f02677 100644 --- a/torchrec/metrics/tests/test_cali_free_ne.py +++ b/torchrec/metrics/tests/test_cali_free_ne.py @@ -89,7 +89,7 @@ def test_cali_free_ne_unfused(self) -> None: entry_point=metric_test_helper, ) - def test_cali_free_ne_fused_tasks(self) -> None: + def test_cali_free_ne_fused(self) -> None: rec_metric_value_test_launcher( target_clazz=CaliFreeNEMetric, target_compute_mode=RecComputeMode.FUSED_TASKS_COMPUTATION, @@ -103,21 +103,7 @@ def test_cali_free_ne_fused_tasks(self) -> None: entry_point=metric_test_helper, ) - def test_cali_free_ne_fused_tasks_and_states(self) -> None: - rec_metric_value_test_launcher( - target_clazz=CaliFreeNEMetric, - target_compute_mode=RecComputeMode.FUSED_TASKS_AND_STATES_COMPUTATION, - test_clazz=TestCaliFreeNEMetric, - metric_name=CaliFreeNEMetricTest.task_name, - task_names=["t1", "t2", "t3"], - fused_update_limit=0, - compute_on_all_ranks=False, - should_validate_update=False, - world_size=WORLD_SIZE, - entry_point=metric_test_helper, - ) - - def test_cali_free_ne_update_unfused(self) -> None: + def test_cali_free_ne_update_fused(self) -> None: rec_metric_value_test_launcher( target_clazz=CaliFreeNEMetric, target_compute_mode=RecComputeMode.UNFUSED_TASKS_COMPUTATION, diff --git a/torchrec/metrics/tests/test_calibration.py b/torchrec/metrics/tests/test_calibration.py index 6a2304485..d2f32a78c 100644 --- a/torchrec/metrics/tests/test_calibration.py +++ b/torchrec/metrics/tests/test_calibration.py @@ -52,7 +52,7 @@ class CalibrationMetricTest(unittest.TestCase): clazz: Type[RecMetric] = CalibrationMetric task_name: str = "calibration" - def test_calibration_unfused(self) -> None: + def test_unfused_calibration(self) -> None: rec_metric_value_test_launcher( target_clazz=CalibrationMetric, target_compute_mode=RecComputeMode.UNFUSED_TASKS_COMPUTATION, @@ -66,7 +66,7 @@ def test_calibration_unfused(self) -> None: entry_point=metric_test_helper, ) - def test_calibration_fused_tasks(self) -> None: + def test_fused_calibration(self) -> None: rec_metric_value_test_launcher( target_clazz=CalibrationMetric, target_compute_mode=RecComputeMode.FUSED_TASKS_COMPUTATION, @@ -80,20 +80,6 @@ def test_calibration_fused_tasks(self) -> None: entry_point=metric_test_helper, ) - def test_calibration_fused_tasks_and_states(self) -> None: - rec_metric_value_test_launcher( - target_clazz=CalibrationMetric, - target_compute_mode=RecComputeMode.FUSED_TASKS_AND_STATES_COMPUTATION, - test_clazz=TestCalibrationMetric, - metric_name=CalibrationMetricTest.task_name, - task_names=["t1", "t2", "t3"], - fused_update_limit=0, - compute_on_all_ranks=False, - should_validate_update=False, - world_size=WORLD_SIZE, - entry_point=metric_test_helper, - ) - class CalibrationGPUSyncTest(unittest.TestCase): clazz: Type[RecMetric] = CalibrationMetric diff --git a/torchrec/metrics/tests/test_ctr.py b/torchrec/metrics/tests/test_ctr.py index efd45752c..faaa865e3 100644 --- a/torchrec/metrics/tests/test_ctr.py +++ b/torchrec/metrics/tests/test_ctr.py @@ -46,7 +46,7 @@ class CTRMetricTest(unittest.TestCase): clazz: Type[RecMetric] = CTRMetric task_name: str = "ctr" - def test_ctr_unfused(self) -> None: + def test_unfused_ctr(self) -> None: rec_metric_value_test_launcher( target_clazz=CTRMetric, target_compute_mode=RecComputeMode.UNFUSED_TASKS_COMPUTATION, @@ -60,7 +60,7 @@ def test_ctr_unfused(self) -> None: entry_point=metric_test_helper, ) - def test_ctr_fused_tasks(self) -> None: + def test_fused_ctr(self) -> None: rec_metric_value_test_launcher( target_clazz=CTRMetric, target_compute_mode=RecComputeMode.FUSED_TASKS_COMPUTATION, @@ -74,20 +74,6 @@ def test_ctr_fused_tasks(self) -> None: entry_point=metric_test_helper, ) - def test_ctr_fused_tasks_and_states(self) -> None: - rec_metric_value_test_launcher( - target_clazz=CTRMetric, - target_compute_mode=RecComputeMode.FUSED_TASKS_AND_STATES_COMPUTATION, - test_clazz=TestCTRMetric, - metric_name=CTRMetricTest.task_name, - task_names=["t1", "t2", "t3"], - fused_update_limit=0, - compute_on_all_ranks=False, - should_validate_update=False, - world_size=WORLD_SIZE, - entry_point=metric_test_helper, - ) - class CTRGPUSyncTest(unittest.TestCase): clazz: Type[RecMetric] = CTRMetric diff --git a/torchrec/metrics/tests/test_hindsight_target_pr.py b/torchrec/metrics/tests/test_hindsight_target_pr.py index 5cc9e406d..2fd9102c8 100644 --- a/torchrec/metrics/tests/test_hindsight_target_pr.py +++ b/torchrec/metrics/tests/test_hindsight_target_pr.py @@ -126,7 +126,7 @@ class TestHindsightTargetPRMetricTest(unittest.TestCase): precision_task_name: str = "hindsight_target_precision" recall_task_name: str = "hindsight_target_recall" - def test_hindsight_target_precision_unfused(self) -> None: + def test_unfused_hindsight_target_precision(self) -> None: rec_metric_value_test_launcher( target_clazz=HindsightTargetPRMetric, target_compute_mode=RecComputeMode.UNFUSED_TASKS_COMPUTATION, @@ -140,7 +140,7 @@ def test_hindsight_target_precision_unfused(self) -> None: entry_point=metric_test_helper, ) - def test_hindsight_target_recall_unfused(self) -> None: + def test_unfused_hindsight_target_recall(self) -> None: rec_metric_value_test_launcher( target_clazz=HindsightTargetPRMetric, target_compute_mode=RecComputeMode.UNFUSED_TASKS_COMPUTATION, diff --git a/torchrec/metrics/tests/test_mse.py b/torchrec/metrics/tests/test_mse.py index 2e4dec541..1827bcc44 100644 --- a/torchrec/metrics/tests/test_mse.py +++ b/torchrec/metrics/tests/test_mse.py @@ -70,7 +70,7 @@ class MSEMetricTest(unittest.TestCase): task_name: str = "mse" rmse_task_name: str = "rmse" - def test_mse_unfused(self) -> None: + def test_unfused_mse(self) -> None: rec_metric_value_test_launcher( target_clazz=MSEMetric, target_compute_mode=RecComputeMode.UNFUSED_TASKS_COMPUTATION, @@ -84,7 +84,7 @@ def test_mse_unfused(self) -> None: entry_point=metric_test_helper, ) - def test_mse_fused_tasks(self) -> None: + def test_fused_mse(self) -> None: rec_metric_value_test_launcher( target_clazz=MSEMetric, target_compute_mode=RecComputeMode.FUSED_TASKS_COMPUTATION, @@ -98,21 +98,7 @@ def test_mse_fused_tasks(self) -> None: entry_point=metric_test_helper, ) - def test_mse_fused_tasks_and_states(self) -> None: - rec_metric_value_test_launcher( - target_clazz=MSEMetric, - target_compute_mode=RecComputeMode.FUSED_TASKS_AND_STATES_COMPUTATION, - test_clazz=TestMSEMetric, - metric_name=MSEMetricTest.task_name, - task_names=["t1", "t2", "t3"], - fused_update_limit=0, - compute_on_all_ranks=False, - should_validate_update=False, - world_size=WORLD_SIZE, - entry_point=metric_test_helper, - ) - - def test_rmse_unfused(self) -> None: + def test_unfused_rmse(self) -> None: rec_metric_value_test_launcher( target_clazz=MSEMetric, target_compute_mode=RecComputeMode.UNFUSED_TASKS_COMPUTATION, @@ -126,7 +112,7 @@ def test_rmse_unfused(self) -> None: entry_point=metric_test_helper, ) - def test_rmse_fused_tasks(self) -> None: + def test_fused_rmse(self) -> None: rec_metric_value_test_launcher( target_clazz=MSEMetric, target_compute_mode=RecComputeMode.FUSED_TASKS_COMPUTATION, @@ -140,20 +126,6 @@ def test_rmse_fused_tasks(self) -> None: entry_point=metric_test_helper, ) - def test_rmse_fused_tasks_and_states(self) -> None: - rec_metric_value_test_launcher( - target_clazz=MSEMetric, - target_compute_mode=RecComputeMode.FUSED_TASKS_AND_STATES_COMPUTATION, - test_clazz=TestRMSEMetric, - metric_name=MSEMetricTest.rmse_task_name, - task_names=["t1", "t2", "t3"], - fused_update_limit=0, - compute_on_all_ranks=False, - should_validate_update=False, - world_size=WORLD_SIZE, - entry_point=metric_test_helper, - ) - class MSEGPUSyncTest(unittest.TestCase): clazz: Type[RecMetric] = MSEMetric diff --git a/torchrec/metrics/tests/test_ndcg.py b/torchrec/metrics/tests/test_ndcg.py index 9948e9a02..add23b885 100644 --- a/torchrec/metrics/tests/test_ndcg.py +++ b/torchrec/metrics/tests/test_ndcg.py @@ -703,43 +703,3 @@ def test_multitask_exp(self) -> None: equal_nan=True, msg=f"Actual: {actual_metric}, Expected: {expected_metric}", ) - - def test_multitask_exp_fused_tasks_and_states(self) -> None: - """ - Test NDCG with multiple tasks. - """ - model_output = get_test_case_multitask() - metric = self.generate_metric( - world_size=WORLD_SIZE, - my_rank=0, - batch_size=BATCH_SIZE, - tasks=[DefaultTaskInfo0, DefaultTaskInfo1, DefaultTaskInfo2], - exponential_gain=True, - session_key=SESSION_KEY, - compute_mode=RecComputeMode.FUSED_TASKS_AND_STATES_COMPUTATION, - ) - - metric.update( - predictions=model_output["predictions"], - labels=model_output["labels"], - weights=model_output["weights"], - required_inputs={SESSION_KEY: model_output["session_ids"]}, - ) - output = metric.compute() - actual_metric = torch.stack( - [ - output[f"ndcg-{task.name}|lifetime_ndcg"] - for task in [DefaultTaskInfo0, DefaultTaskInfo1, DefaultTaskInfo2] - ] - ) - expected_metric = model_output["expected_ndcg_exp"] - - torch.testing.assert_close( - actual_metric, - expected_metric, - atol=1e-4, - rtol=1e-4, - check_dtype=False, - equal_nan=True, - msg=f"Actual: {actual_metric}, Expected: {expected_metric}", - ) diff --git a/torchrec/metrics/tests/test_ne.py b/torchrec/metrics/tests/test_ne.py index 4a5a5359d..bd1db6ab5 100644 --- a/torchrec/metrics/tests/test_ne.py +++ b/torchrec/metrics/tests/test_ne.py @@ -117,7 +117,7 @@ def test_ne_unfused(self) -> None: entry_point=metric_test_helper, ) - def test_ne_fused_tasks(self) -> None: + def test_ne_fused(self) -> None: rec_metric_value_test_launcher( target_clazz=NEMetric, target_compute_mode=RecComputeMode.FUSED_TASKS_COMPUTATION, @@ -131,20 +131,6 @@ def test_ne_fused_tasks(self) -> None: entry_point=metric_test_helper, ) - def test_ne_fused_tasks_and_states(self) -> None: - rec_metric_value_test_launcher( - target_clazz=NEMetric, - target_compute_mode=RecComputeMode.FUSED_TASKS_AND_STATES_COMPUTATION, - test_clazz=TestNEMetric, - metric_name=NEMetricTest.task_name, - task_names=["t1", "t2", "t3"], - fused_update_limit=0, - compute_on_all_ranks=False, - should_validate_update=False, - world_size=WORLD_SIZE, - entry_point=metric_test_helper, - ) - def test_ne_update_fused(self) -> None: rec_metric_value_test_launcher( target_clazz=NEMetric, @@ -207,7 +193,7 @@ def test_logloss_unfused(self) -> None: entry_point=self._logloss_metric_test_helper, ) - def test_logloss_fused_tasks(self) -> None: + def test_logloss_fused(self) -> None: rec_metric_value_test_launcher( target_clazz=NEMetric, target_compute_mode=RecComputeMode.FUSED_TASKS_COMPUTATION, @@ -221,20 +207,6 @@ def test_logloss_fused_tasks(self) -> None: entry_point=self._logloss_metric_test_helper, ) - def test_logloss_fused_tasks_and_states(self) -> None: - rec_metric_value_test_launcher( - target_clazz=NEMetric, - target_compute_mode=RecComputeMode.FUSED_TASKS_AND_STATES_COMPUTATION, - metric_name="logloss", - test_clazz=TestLoglossMetric, - task_names=["t1", "t2", "t3"], - fused_update_limit=0, - compute_on_all_ranks=False, - should_validate_update=False, - world_size=WORLD_SIZE, - entry_point=self._logloss_metric_test_helper, - ) - def test_logloss_update_fused(self) -> None: rec_metric_value_test_launcher( target_clazz=NEMetric, diff --git a/torchrec/metrics/tests/test_precision.py b/torchrec/metrics/tests/test_precision.py index 8a58485f6..977eed95b 100644 --- a/torchrec/metrics/tests/test_precision.py +++ b/torchrec/metrics/tests/test_precision.py @@ -53,7 +53,7 @@ class PrecisionMetricTest(unittest.TestCase): target_clazz: Type[RecMetric] = PrecisionMetric task_name: str = "precision" - def test_precision_unfused(self) -> None: + def test_unfused_precision(self) -> None: rec_metric_value_test_launcher( target_clazz=PrecisionMetric, target_compute_mode=RecComputeMode.UNFUSED_TASKS_COMPUTATION, @@ -67,7 +67,7 @@ def test_precision_unfused(self) -> None: entry_point=metric_test_helper, ) - def test_precision_fused_tasks(self) -> None: + def test_fused_precision(self) -> None: rec_metric_value_test_launcher( target_clazz=PrecisionMetric, target_compute_mode=RecComputeMode.FUSED_TASKS_COMPUTATION, @@ -81,20 +81,6 @@ def test_precision_fused_tasks(self) -> None: entry_point=metric_test_helper, ) - def test_precision_fused_tasks_and_states(self) -> None: - rec_metric_value_test_launcher( - target_clazz=PrecisionMetric, - target_compute_mode=RecComputeMode.FUSED_TASKS_AND_STATES_COMPUTATION, - test_clazz=TestPrecisionMetric, - metric_name=PrecisionMetricTest.task_name, - task_names=["t1", "t2", "t3"], - fused_update_limit=0, - compute_on_all_ranks=False, - should_validate_update=False, - world_size=WORLD_SIZE, - entry_point=metric_test_helper, - ) - class PrecisionMetricValueTest(unittest.TestCase): r"""This set of tests verify the computation logic of precision in several diff --git a/torchrec/metrics/tests/test_recall.py b/torchrec/metrics/tests/test_recall.py index d09faf464..eb53048b6 100644 --- a/torchrec/metrics/tests/test_recall.py +++ b/torchrec/metrics/tests/test_recall.py @@ -53,7 +53,7 @@ class RecallMetricTest(unittest.TestCase): clazz: Type[RecMetric] = RecallMetric task_name: str = "recall" - def test_recall_unfused(self) -> None: + def test_unfused_recall(self) -> None: rec_metric_value_test_launcher( target_clazz=RecallMetric, target_compute_mode=RecComputeMode.UNFUSED_TASKS_COMPUTATION, @@ -67,7 +67,7 @@ def test_recall_unfused(self) -> None: entry_point=metric_test_helper, ) - def test_recall_fused_tasks(self) -> None: + def test_fused_recall(self) -> None: rec_metric_value_test_launcher( target_clazz=RecallMetric, target_compute_mode=RecComputeMode.FUSED_TASKS_COMPUTATION, @@ -81,20 +81,6 @@ def test_recall_fused_tasks(self) -> None: entry_point=metric_test_helper, ) - def test_recall_fused_tasks_and_states(self) -> None: - rec_metric_value_test_launcher( - target_clazz=RecallMetric, - target_compute_mode=RecComputeMode.FUSED_TASKS_AND_STATES_COMPUTATION, - test_clazz=TestRecallMetric, - metric_name=RecallMetricTest.task_name, - task_names=["t1", "t2", "t3"], - fused_update_limit=0, - compute_on_all_ranks=False, - should_validate_update=False, - world_size=WORLD_SIZE, - entry_point=metric_test_helper, - ) - class RecallMetricValueTest(unittest.TestCase): r"""This set of tests verify the computation logic of recall in several diff --git a/torchrec/metrics/tests/test_serving_calibration.py b/torchrec/metrics/tests/test_serving_calibration.py index 810a69bfb..b251ba202 100644 --- a/torchrec/metrics/tests/test_serving_calibration.py +++ b/torchrec/metrics/tests/test_serving_calibration.py @@ -52,7 +52,7 @@ class ServingCalibrationMetricTest(unittest.TestCase): clazz: Type[RecMetric] = ServingCalibrationMetric task_name: str = "calibration" - def test_calibration_unfused(self) -> None: + def test_unfused_calibration(self) -> None: rec_metric_value_test_launcher( target_clazz=ServingCalibrationMetric, target_compute_mode=RecComputeMode.UNFUSED_TASKS_COMPUTATION, @@ -66,7 +66,7 @@ def test_calibration_unfused(self) -> None: entry_point=metric_test_helper, ) - def test_calibration_fused_tasks(self) -> None: + def test_fused_calibration(self) -> None: rec_metric_value_test_launcher( target_clazz=ServingCalibrationMetric, target_compute_mode=RecComputeMode.FUSED_TASKS_COMPUTATION, @@ -80,20 +80,6 @@ def test_calibration_fused_tasks(self) -> None: entry_point=metric_test_helper, ) - def test_calibration_fused_tasks_and_states(self) -> None: - rec_metric_value_test_launcher( - target_clazz=ServingCalibrationMetric, - target_compute_mode=RecComputeMode.FUSED_TASKS_AND_STATES_COMPUTATION, - test_clazz=TestServingCalibrationMetric, - metric_name=ServingCalibrationMetricTest.task_name, - task_names=["t1", "t2", "t3"], - fused_update_limit=0, - compute_on_all_ranks=False, - should_validate_update=False, - world_size=WORLD_SIZE, - entry_point=metric_test_helper, - ) - # TODO - Serving Calibration uses Calibration naming inconsistently class ServingCalibrationGPUSyncTest(unittest.TestCase): diff --git a/torchrec/metrics/tests/test_tower_qps.py b/torchrec/metrics/tests/test_tower_qps.py index 7dd91010d..97050e09e 100644 --- a/torchrec/metrics/tests/test_tower_qps.py +++ b/torchrec/metrics/tests/test_tower_qps.py @@ -161,7 +161,7 @@ class TowerQPSMetricTest(unittest.TestCase): ) update_wrapper(_test_tower_qps, metric_test_helper) - def test_tower_qps_during_warmup_unfused(self) -> None: + def test_unfused_tower_qps_during_warmup(self) -> None: rec_metric_value_test_launcher( target_clazz=TowerQPSMetric, target_compute_mode=RecComputeMode.UNFUSED_TASKS_COMPUTATION, @@ -175,7 +175,7 @@ def test_tower_qps_during_warmup_unfused(self) -> None: entry_point=self._test_tower_qps, ) - def test_tower_qps_unfused(self) -> None: + def test_unfused_tower_qps(self) -> None: rec_metric_value_test_launcher( target_clazz=TowerQPSMetric, target_compute_mode=RecComputeMode.UNFUSED_TASKS_COMPUTATION, @@ -190,7 +190,7 @@ def test_tower_qps_unfused(self) -> None: test_nsteps=DURING_WARMUP_NSTEPS, ) - def test_tower_qps_fused_tasks(self) -> None: + def test_fused_tower_qps(self) -> None: rec_metric_value_test_launcher( target_clazz=TowerQPSMetric, target_compute_mode=RecComputeMode.FUSED_TASKS_COMPUTATION, @@ -205,22 +205,7 @@ def test_tower_qps_fused_tasks(self) -> None: test_nsteps=AFTER_WARMUP_NSTEPS, ) - def test_tower_qps_fused_tasks_and_states(self) -> None: - rec_metric_value_test_launcher( - target_clazz=TowerQPSMetric, - target_compute_mode=RecComputeMode.FUSED_TASKS_AND_STATES_COMPUTATION, - test_clazz=TestTowerQPSMetric, - metric_name=TowerQPSMetricTest.task_names, - task_names=["t1", "t2", "t3"], - fused_update_limit=0, - compute_on_all_ranks=False, - should_validate_update=False, - world_size=WORLD_SIZE, - entry_point=self._test_tower_qps, - test_nsteps=AFTER_WARMUP_NSTEPS, - ) - - def test_check_update_tower_qps_unfused(self) -> None: + def test_unfused_check_update_tower_qps(self) -> None: rec_metric_value_test_launcher( target_clazz=TowerQPSMetric, target_compute_mode=RecComputeMode.UNFUSED_TASKS_COMPUTATION, @@ -235,7 +220,7 @@ def test_check_update_tower_qps_unfused(self) -> None: test_nsteps=AFTER_WARMUP_NSTEPS, ) - def test_check_update_tower_qps_fused_tasks(self) -> None: + def test_fused_check_update_tower_qps(self) -> None: rec_metric_value_test_launcher( target_clazz=TowerQPSMetric, target_compute_mode=RecComputeMode.FUSED_TASKS_COMPUTATION, diff --git a/torchrec/metrics/tests/test_unweighted_ne.py b/torchrec/metrics/tests/test_unweighted_ne.py index 5a18178d0..d80c10ae6 100644 --- a/torchrec/metrics/tests/test_unweighted_ne.py +++ b/torchrec/metrics/tests/test_unweighted_ne.py @@ -88,7 +88,7 @@ def test_unweighted_ne_unfused(self) -> None: entry_point=metric_test_helper, ) - def test_unweighted_ne_fused_tasks(self) -> None: + def test_unweighted_ne_fused(self) -> None: rec_metric_value_test_launcher( target_clazz=UnweightedNEMetric, target_compute_mode=RecComputeMode.FUSED_TASKS_COMPUTATION, @@ -102,21 +102,7 @@ def test_unweighted_ne_fused_tasks(self) -> None: entry_point=metric_test_helper, ) - def test_unweighted_ne_fused_tasks_and_states(self) -> None: - rec_metric_value_test_launcher( - target_clazz=UnweightedNEMetric, - target_compute_mode=RecComputeMode.FUSED_TASKS_AND_STATES_COMPUTATION, - test_clazz=TestUnweightedNEMetric, - metric_name=UnweightedNEMetricTest.task_name, - task_names=["t1", "t2", "t3"], - fused_update_limit=0, - compute_on_all_ranks=False, - should_validate_update=False, - world_size=WORLD_SIZE, - entry_point=metric_test_helper, - ) - - def test_unweighted_ne_update_unfused(self) -> None: + def test_unweighted_ne_update_fused(self) -> None: rec_metric_value_test_launcher( target_clazz=UnweightedNEMetric, target_compute_mode=RecComputeMode.UNFUSED_TASKS_COMPUTATION, diff --git a/torchrec/metrics/tests/test_weighted_avg.py b/torchrec/metrics/tests/test_weighted_avg.py index 226c06748..40009c28d 100644 --- a/torchrec/metrics/tests/test_weighted_avg.py +++ b/torchrec/metrics/tests/test_weighted_avg.py @@ -58,7 +58,7 @@ def test_weighted_avg_unfused(self) -> None: entry_point=metric_test_helper, ) - def test_weighted_avg_fused_tasks(self) -> None: + def test_weighted_avg_fused(self) -> None: rec_metric_value_test_launcher( target_clazz=WeightedAvgMetric, target_compute_mode=RecComputeMode.FUSED_TASKS_COMPUTATION, @@ -72,21 +72,7 @@ def test_weighted_avg_fused_tasks(self) -> None: entry_point=metric_test_helper, ) - def test_weighted_avg_fused_tasks_and_states(self) -> None: - rec_metric_value_test_launcher( - target_clazz=WeightedAvgMetric, - target_compute_mode=RecComputeMode.FUSED_TASKS_AND_STATES_COMPUTATION, - test_clazz=TestWeightedAvgMetric, - metric_name=WeightedAvgMetricTest.task_name, - task_names=["t1", "t2", "t3"], - fused_update_limit=0, - compute_on_all_ranks=False, - should_validate_update=False, - world_size=WORLD_SIZE, - entry_point=metric_test_helper, - ) - - def test_weighted_avg_update_unfused(self) -> None: + def test_weighted_avg_update_fused(self) -> None: rec_metric_value_test_launcher( target_clazz=WeightedAvgMetric, target_compute_mode=RecComputeMode.UNFUSED_TASKS_COMPUTATION, diff --git a/torchrec/metrics/tower_qps.py b/torchrec/metrics/tower_qps.py index 4411dcba4..4a8b7681c 100644 --- a/torchrec/metrics/tower_qps.py +++ b/torchrec/metrics/tower_qps.py @@ -222,10 +222,7 @@ def update( **kwargs: Dict[str, Any], ) -> None: with torch.no_grad(): - if self._compute_mode in [ - RecComputeMode.FUSED_TASKS_COMPUTATION, - RecComputeMode.FUSED_TASKS_AND_STATES_COMPUTATION, - ]: + if self._compute_mode == RecComputeMode.FUSED_TASKS_COMPUTATION: if not isinstance(labels, torch.Tensor): raise RecMetricException( "Fused computation only support where 'labels' is a tensor"