14
14
from lightning .pytorch .trainer .states import RunningStage
15
15
from transformers import BatchEncoding
16
16
17
- from ..data import RankBatch , RunDataset , SearchBatch , TrainBatch
18
- from ..data .dataset import IRDataset
17
+ from ..data import IRDataset , RankBatch , RunDataset , SearchBatch , TrainBatch
19
18
from ..loss .loss import InBatchLossFunction , LossFunction
20
19
from .config import LightningIRConfig
21
20
from .model import LightningIRModel , LightningIROutput
@@ -235,7 +234,8 @@ def validation_step(
235
234
if self .evaluation_metrics is None :
236
235
return output
237
236
238
- dataset_id = self .get_dataset_id (dataloader_idx )
237
+ dataset = self .get_dataset (dataloader_idx )
238
+ dataset_id = str (dataloader_idx ) if dataset is None else self .get_dataset_id (dataset )
239
239
metrics = self .validate (output , batch )
240
240
for key , value in metrics .items ():
241
241
key = f"{ dataset_id } /{ key } "
@@ -290,7 +290,7 @@ def get_dataset(self, dataloader_idx: int) -> IRDataset | None:
290
290
dataloaders = [dataloaders ]
291
291
return dataloaders [dataloader_idx ].dataset
292
292
293
- def get_dataset_id (self , dataloader_idx : int ) -> str :
293
+ def get_dataset_id (self , dataset : IRDataset ) -> str :
294
294
"""Gets the dataset id from the dataloader index for logging.
295
295
296
296
.. _ir-datasets: https://ir-datasets.com/
@@ -300,9 +300,6 @@ def get_dataset_id(self, dataloader_idx: int) -> str:
300
300
:return: path to run file, ir-datasets_ dataset id, or dataloader index
301
301
:rtype: str
302
302
"""
303
- dataset = self .get_dataset (dataloader_idx )
304
- if dataset is None :
305
- return str (dataloader_idx )
306
303
if isinstance (dataset , RunDataset ) and dataset .run_path is not None :
307
304
dataset_id = dataset .run_path .name
308
305
else :
@@ -420,7 +417,15 @@ def on_validation_end(self) -> None:
420
417
df .columns .name = None
421
418
422
419
# bring into correct order when skipping inference datasets
423
- dataset_ids = [self .get_dataset_id (i ) for i in range (df .shape [0 ])]
420
+ datamodule = getattr (self .trainer , "datamodule" , None )
421
+ if datamodule is not None and hasattr (datamodule , "inference_datasets" ):
422
+ inference_datasets = datamodule .inference_datasets
423
+ if len (inference_datasets ) != df .shape [0 ]:
424
+ raise ValueError (
425
+ "Number of inference datasets does not match number of dataloaders. "
426
+ "Check if the dataloaders are correctly configured."
427
+ )
428
+ dataset_ids = [self .get_dataset_id (dataset ) for dataset in inference_datasets ]
424
429
df = df .reindex (dataset_ids )
425
430
426
431
trainer .print (df )
0 commit comments