Skip to content

Commit a0d86b3

Browse files
iamzainhudafacebook-github-bot
authored andcommitted
mark metricmodule methods as experimental (#2983)
Summary: tsia Reviewed By: xunnanxu, kausv Differential Revision: D75014609
1 parent e0d4b6c commit a0d86b3

File tree

2 files changed

+4
-1
lines changed

2 files changed

+4
-1
lines changed

torchrec/metrics/metric_module.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,7 @@
6868
from torchrec.metrics.unweighted_ne import UnweightedNEMetric
6969
from torchrec.metrics.weighted_avg import WeightedAvgMetric
7070
from torchrec.metrics.xauc import XAUCMetric
71+
from torchrec.utils.experimental import experimental
7172

7273

7374
logger: logging.Logger = logging.getLogger(__name__)
@@ -394,6 +395,7 @@ def _get_metric_states(
394395

395396
return state_aggregated
396397

398+
@experimental
397399
def get_pre_compute_states(
398400
self, pg: Optional[Union[dist.ProcessGroup, DeviceMesh]] = None
399401
) -> Dict[str, Dict[str, Dict[str, Union[torch.Tensor, List[torch.Tensor]]]]]:
@@ -442,6 +444,7 @@ def get_pre_compute_states(
442444

443445
return aggregated_states
444446

447+
@experimental
445448
def load_pre_compute_states(
446449
self,
447450
source: Dict[

torchrec/metrics/tests/test_metric_module.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -643,7 +643,7 @@ def metric_module_gather_state(
643643
metric_module.update(test_batch)
644644

645645
computed_value = metric_module.compute()
646-
states = metric_module.get_pre_compute_states(pg=ctx.pg) # pyre-ignore[6]
646+
states = metric_module.get_pre_compute_states(pg=ctx.pg)
647647

648648
torch.distributed.barrier(ctx.pg)
649649
# Compare to computing metrics on metric module that loads from pre_compute_states

0 commit comments

Comments
 (0)