|
12 | 12 | # See the License for the specific language governing permissions and |
13 | 13 | # limitations under the License. |
14 | 14 | # this is just a bypass for this module name collision with built-in one |
| 15 | +from copy import deepcopy |
15 | 16 | from typing import Any, Dict, Iterable, Optional, Sequence, Tuple, Union |
16 | 17 |
|
17 | 18 | from torch import Tensor, nn |
@@ -167,6 +168,33 @@ def reset(self) -> None: |
167 | 168 | metric.reset() |
168 | 169 | super().reset() |
169 | 170 |
|
| 171 | + @staticmethod |
| 172 | + def _check_arg(arg: Optional[str], name: str) -> Optional[str]: |
| 173 | + if arg is None or isinstance(arg, str): |
| 174 | + return arg |
| 175 | + raise ValueError(f"Expected input `{name}` to be a string, but got {type(arg)}") |
| 176 | + |
| 177 | + def clone(self, prefix: Optional[str] = None, postfix: Optional[str] = None) -> "MultitaskWrapper": |
| 178 | + """Make a copy of the metric. |
| 179 | +
|
| 180 | + Args: |
| 181 | + prefix: a string to append in front of the metric keys |
| 182 | + postfix: a string to append after the keys of the output dict. |
| 183 | +
|
| 184 | + """ |
| 185 | + multitask_copy = deepcopy(self) |
| 186 | + if prefix is not None: |
| 187 | + prefix = self._check_arg(prefix, "prefix") |
| 188 | + multitask_copy.task_metrics = nn.ModuleDict( |
| 189 | + {prefix + key: value for key, value in multitask_copy.task_metrics.items()} |
| 190 | + ) |
| 191 | + if postfix is not None: |
| 192 | + postfix = self._check_arg(postfix, "postfix") |
| 193 | + multitask_copy.task_metrics = nn.ModuleDict( |
| 194 | + {key + postfix: value for key, value in multitask_copy.task_metrics.items()} |
| 195 | + ) |
| 196 | + return multitask_copy |
| 197 | + |
170 | 198 | def plot( |
171 | 199 | self, val: Optional[Union[Dict, Sequence[Dict]]] = None, axes: Optional[Sequence[_AX_TYPE]] = None |
172 | 200 | ) -> Sequence[_PLOT_OUT_TYPE]: |
|
0 commit comments