@@ -3,6 +3,7 @@
 import torch
 from commode_utils.losses import SequenceCrossEntropyLoss
 from commode_utils.metrics import SequentialF1Score, ClassificationMetrics
+from commode_utils.metrics.chrF import ChrF
 from commode_utils.modules import LSTMDecoderStep, Decoder
 from omegaconf import DictConfig
 from pytorch_lightning import LightningModule
@@ -41,6 +42,10 @@ def __init__(
             f"{holdout}_f1": SequentialF1Score(pad_idx=self.__pad_idx, eos_idx=eos_idx, ignore_idx=ignore_idx)
             for holdout in ["train", "val", "test"]
         }
+        id2label = {v: k for k, v in vocabulary.label_to_id.items()}
+        metrics.update(
+            {f"{holdout}_chrf": ChrF(id2label, ignore_idx + [self.__pad_idx, eos_idx]) for holdout in ["val", "test"]}
+        )
         self.__metrics = MetricCollection(metrics)

         self._encoder = self._get_encoder(model_config)
@@ -102,18 +107,18 @@ def _shared_step(self, batch: BatchedLabeledPathContext, step: str) -> Dict:
         target_sequence = batch.labels if step == "train" else None
         # [seq length; batch size; vocab size]
         logits, _ = self.logits_from_batch(batch, target_sequence)
-        loss = self.__loss(logits[1:], batch.labels[1:])
+        result = {f"{step}/loss": self.__loss(logits[1:], batch.labels[1:])}

         with torch.no_grad():
             prediction = logits.argmax(-1)
             metric: ClassificationMetrics = self.__metrics[f"{step}_f1"](prediction, batch.labels)
+            result.update(
+                {f"{step}/f1": metric.f1_score, f"{step}/precision": metric.precision, f"{step}/recall": metric.recall}
+            )
+            if step != "train":
+                result[f"{step}/chrf"] = self.__metrics[f"{step}_chrf"](prediction, batch.labels)

-        return {
-            f"{step}/loss": loss,
-            f"{step}/f1": metric.f1_score,
-            f"{step}/precision": metric.precision,
-            f"{step}/recall": metric.recall,
-        }
+        return result

     def training_step(self, batch: BatchedLabeledPathContext, batch_idx: int) -> Dict:  # type: ignore
         result = self._shared_step(batch, "train")
@@ -143,6 +148,9 @@ def _shared_epoch_end(self, step_outputs: EPOCH_OUTPUT, step: str):
             f"{step}/recall": metric.recall,
         }
         self.__metrics[f"{step}_f1"].reset()
+        if step != "train":
+            log[f"{step}/chrf"] = self.__metrics[f"{step}_chrf"].compute()
+            self.__metrics[f"{step}_chrf"].reset()
         self.log_dict(log, on_step=False, on_epoch=True)

     def training_epoch_end(self, step_outputs: EPOCH_OUTPUT):
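Below is a minimal standalone sketch of the metric lifecycle this commit wires into the module: construct the chrF metric from an id-to-label mapping, update it once per batch, aggregate at epoch end, then reset. It assumes the ChrF constructor arguments and torchmetrics-style call/compute/reset interface exactly as they appear in the diff above; the vocabulary and tensors are illustrative placeholders, not values from the repository.

# Sketch only; assumes ChrF takes an id-to-label dict plus a list of token ids
# to ignore, and behaves like a stateful torchmetrics Metric, as used above.
import torch

from commode_utils.metrics.chrF import ChrF

# Illustrative label vocabulary; real ids come from vocabulary.label_to_id.
label_to_id = {"<PAD>": 0, "<SOS>": 1, "<EOS>": 2, "get": 3, "name": 4}
id2label = {v: k for k, v in label_to_id.items()}

# Service tokens should not contribute to the chrF score.
chrf = ChrF(id2label, [label_to_id["<SOS>"], label_to_id["<PAD>"], label_to_id["<EOS>"]])

# [seq length; batch size] token ids, i.e. the shape produced by logits.argmax(-1).
prediction = torch.tensor([[1], [3], [4], [2]])
labels = torch.tensor([[1], [3], [4], [2]])

batch_score = chrf(prediction, labels)  # update internal state, return the batch-level score
epoch_score = chrf.compute()            # aggregate over every batch seen since the last reset
chrf.reset()                            # clear state before the next epoch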