diff --git a/torchrec/metrics/tower_qps.py b/torchrec/metrics/tower_qps.py index 4411dcba4..de79eb4d1 100644 --- a/torchrec/metrics/tower_qps.py +++ b/torchrec/metrics/tower_qps.py @@ -227,9 +227,14 @@ def update( RecComputeMode.FUSED_TASKS_AND_STATES_COMPUTATION, ]: if not isinstance(labels, torch.Tensor): - raise RecMetricException( - "Fused computation only support where 'labels' is a tensor" - ) + try: + labels = torch.stack( + [labels[task.name] for task in self._tasks] + ) + except Exception as e: + raise RecMetricException( + f"Failed to convert labels to tensor for fused computation: {e}" + ) labels = labels.view(-1, self._batch_size) if self._should_validate_update: # Set the default value to be all True. When weights is None, it's considered @@ -241,9 +246,14 @@ def update( ) if weights is not None: if not isinstance(weights, torch.Tensor): - raise RecMetricException( - "Fused computation only support where 'weights' is a tensor" - ) + try: + weights = torch.stack( + [weights[task.name] for task in self._tasks] + ) + except Exception as e: + raise RecMetricException( + f"Failed to convert weights to tensor for fused computation: {e}" + ) has_valid_weights = torch.gt( torch.count_nonzero( weights.view(-1, self._batch_size), dim=-1