Skip to content

Commit f001f48

Browse files
iamzainhudafacebook-github-bot
authored andcommitted
mark metricmodule methods as experimental
Summary: tsia Reviewed By: kausv Differential Revision: D75014609
1 parent e76130c commit f001f48

File tree

1 file changed

+3
-0
lines changed

1 file changed

+3
-0
lines changed

torchrec/metrics/metric_module.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,7 @@
6868
from torchrec.metrics.unweighted_ne import UnweightedNEMetric
6969
from torchrec.metrics.weighted_avg import WeightedAvgMetric
7070
from torchrec.metrics.xauc import XAUCMetric
71+
from torchrec.utils.experimental import experimental
7172

7273

7374
logger: logging.Logger = logging.getLogger(__name__)
@@ -395,6 +396,7 @@ def _get_metric_states(
395396

396397
return state_aggregated
397398

399+
@experimental
398400
def get_pre_compute_states(
399401
self, pg: Union[dist.ProcessGroup, DeviceMesh], reduce_metrics: bool = True
400402
) -> Dict[str, Dict[str, Dict[str, torch.Tensor]]]:
@@ -438,6 +440,7 @@ def get_pre_compute_states(
438440

439441
return aggregated_states
440442

443+
@experimental
441444
def load_pre_compute_states(
442445
self, source: Dict[str, Dict[str, Dict[str, torch.Tensor]]]
443446
) -> None:

0 commit comments

Comments
 (0)