Skip to content

Commit d4e0781

Browse files
authored
Merge pull request #64 from fschlatt/main
Fix pretty printing for skipped inference datasets
2 parents 52303b1 + d6e2315 commit d4e0781

File tree

2 files changed: +15 additions, −9 deletions

lightning_ir/base/module.py

Lines changed: 13 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,7 @@
1414
from lightning.pytorch.trainer.states import RunningStage
1515
from transformers import BatchEncoding
1616

17-
from ..data import RankBatch, RunDataset, SearchBatch, TrainBatch
18-
from ..data.dataset import IRDataset
17+
from ..data import IRDataset, RankBatch, RunDataset, SearchBatch, TrainBatch
1918
from ..loss.loss import InBatchLossFunction, LossFunction
2019
from .config import LightningIRConfig
2120
from .model import LightningIRModel, LightningIROutput
@@ -235,7 +234,8 @@ def validation_step(
235234
if self.evaluation_metrics is None:
236235
return output
237236

238-
dataset_id = self.get_dataset_id(dataloader_idx)
237+
dataset = self.get_dataset(dataloader_idx)
238+
dataset_id = str(dataloader_idx) if dataset is None else self.get_dataset_id(dataset)
239239
metrics = self.validate(output, batch)
240240
for key, value in metrics.items():
241241
key = f"{dataset_id}/{key}"
@@ -290,7 +290,7 @@ def get_dataset(self, dataloader_idx: int) -> IRDataset | None:
290290
dataloaders = [dataloaders]
291291
return dataloaders[dataloader_idx].dataset
292292

293-
def get_dataset_id(self, dataloader_idx: int) -> str:
293+
def get_dataset_id(self, dataset: IRDataset) -> str:
294294
"""Gets the dataset id from the dataloader index for logging.
295295
296296
.. _ir-datasets: https://ir-datasets.com/
@@ -300,9 +300,6 @@ def get_dataset_id(self, dataloader_idx: int) -> str:
300300
:return: path to run file, ir-datasets_ dataset id, or dataloader index
301301
:rtype: str
302302
"""
303-
dataset = self.get_dataset(dataloader_idx)
304-
if dataset is None:
305-
return str(dataloader_idx)
306303
if isinstance(dataset, RunDataset) and dataset.run_path is not None:
307304
dataset_id = dataset.run_path.name
308305
else:
@@ -420,7 +417,15 @@ def on_validation_end(self) -> None:
420417
df.columns.name = None
421418

422419
# bring into correct order when skipping inference datasets
423-
dataset_ids = [self.get_dataset_id(i) for i in range(df.shape[0])]
420+
datamodule = getattr(self.trainer, "datamodule", None)
421+
if datamodule is not None and hasattr(datamodule, "inference_datasets"):
422+
inference_datasets = datamodule.inference_datasets
423+
if len(inference_datasets) != df.shape[0]:
424+
raise ValueError(
425+
"Number of inference datasets does not match number of dataloaders. "
426+
"Check if the dataloaders are correctly configured."
427+
)
428+
dataset_ids = [self.get_dataset_id(dataset) for dataset in inference_datasets]
424429
df = df.reindex(dataset_ids)
425430

426431
trainer.print(df)

lightning_ir/data/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,12 +6,13 @@
66

77
from .data import DocSample, IndexBatch, QuerySample, RankBatch, RankSample, SearchBatch, TrainBatch
88
from .datamodule import LightningIRDataModule
9-
from .dataset import DocDataset, QueryDataset, RunDataset, TupleDataset
9+
from .dataset import DocDataset, IRDataset, QueryDataset, RunDataset, TupleDataset
1010

1111
__all__ = [
1212
"DocDataset",
1313
"DocSample",
1414
"IndexBatch",
15+
"IRDataset",
1516
"LightningIRDataModule",
1617
"QueryDataset",
1718
"QuerySample",

0 commit comments

Comments (0)