Skip to content

Commit 4527aaf

Browse files
authored
Prefix postfix args in clone for multitask (#2330)
* feature * tests * docstring
1 parent 35bdeb6 commit 4527aaf

File tree

2 files changed

+43
-0
lines changed

2 files changed

+43
-0
lines changed

src/torchmetrics/wrappers/multitask.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414
# this is just a bypass for this module name collision with built-in one
15+
from copy import deepcopy
1516
from typing import Any, Dict, Iterable, Optional, Sequence, Tuple, Union
1617

1718
from torch import Tensor, nn
@@ -167,6 +168,33 @@ def reset(self) -> None:
167168
metric.reset()
168169
super().reset()
169170

171+
@staticmethod
172+
def _check_arg(arg: Optional[str], name: str) -> Optional[str]:
173+
if arg is None or isinstance(arg, str):
174+
return arg
175+
raise ValueError(f"Expected input `{name}` to be a string, but got {type(arg)}")
176+
177+
def clone(self, prefix: Optional[str] = None, postfix: Optional[str] = None) -> "MultitaskWrapper":
178+
"""Make a copy of the metric.
179+
180+
Args:
181+
prefix: a string to append in front of the metric keys
182+
postfix: a string to append after the keys of the output dict.
183+
184+
"""
185+
multitask_copy = deepcopy(self)
186+
if prefix is not None:
187+
prefix = self._check_arg(prefix, "prefix")
188+
multitask_copy.task_metrics = nn.ModuleDict(
189+
{prefix + key: value for key, value in multitask_copy.task_metrics.items()}
190+
)
191+
if postfix is not None:
192+
postfix = self._check_arg(postfix, "postfix")
193+
multitask_copy.task_metrics = nn.ModuleDict(
194+
{key + postfix: value for key, value in multitask_copy.task_metrics.items()}
195+
)
196+
return multitask_copy
197+
170198
def plot(
171199
self, val: Optional[Union[Dict, Sequence[Dict]]] = None, axes: Optional[Sequence[_AX_TYPE]] = None
172200
) -> Sequence[_PLOT_OUT_TYPE]:

tests/unittests/wrappers/test_multitask.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -207,3 +207,18 @@ def test_nested_multitask_wrapper():
207207
multitask_results = multitask_metrics.compute()
208208

209209
assert _dict_results_same_as_individual_results(classification_results, regression_results, multitask_results)
210+
211+
212+
def test_clone_with_prefix_and_postfix():
213+
"""Check that the clone method works with prefix and postfix arguments."""
214+
multitask_metrics = MultitaskWrapper({"Classification": BinaryAccuracy(), "Regression": MeanSquaredError()})
215+
cloned_metrics_with_prefix = multitask_metrics.clone(prefix="prefix_")
216+
cloned_metrics_with_postfix = multitask_metrics.clone(postfix="_postfix")
217+
218+
# Check if the cloned metrics have the expected keys
219+
assert set(cloned_metrics_with_prefix.task_metrics.keys()) == {"prefix_Classification", "prefix_Regression"}
220+
assert set(cloned_metrics_with_postfix.task_metrics.keys()) == {"Classification_postfix", "Regression_postfix"}
221+
222+
# Check if the cloned metrics have the expected values
223+
assert isinstance(cloned_metrics_with_prefix.task_metrics["prefix_Classification"], BinaryAccuracy)
224+
assert isinstance(cloned_metrics_with_prefix.task_metrics["prefix_Regression"], MeanSquaredError)

0 commit comments

Comments
 (0)