@@ -64,15 +64,8 @@ def __init__(
64
64
def setup (self , trainer : Trainer , pl_module : BiEncoderModule , stage : str ) -> None :
65
65
if stage != "test" :
66
66
raise ValueError ("IndexCallback can only be used in test stage" )
67
-
68
- def on_test_start (self , trainer : Trainer , pl_module : BiEncoderModule ) -> None :
69
- dataloaders = trainer .test_dataloaders
70
- if dataloaders is None :
71
- raise ValueError ("No test_dataloaders found" )
72
- datasets = [dataloader .dataset for dataloader in dataloaders ]
73
- if not all (isinstance (dataset , DocDataset ) for dataset in datasets ):
74
- raise ValueError ("Expected DocDatasets for indexing" )
75
67
if not self .overwrite :
68
+ datasets = list (trainer .datamodule .inference_datasets )
76
69
for dataset in datasets :
77
70
index_dir = self .get_index_dir (pl_module , dataset )
78
71
if index_dir .exists ():
@@ -81,6 +74,14 @@ def on_test_start(self, trainer: Trainer, pl_module: BiEncoderModule) -> None:
81
74
f"Index dir { index_dir } already exists. Skipping this dataset. Set overwrite=True to overwrite"
82
75
)
83
76
77
def on_test_start(self, trainer: Trainer, pl_module: BiEncoderModule) -> None:
    """Validate the test setup before indexing starts.

    Ensures the trainer actually has test dataloaders and that every one of
    them serves a ``DocDataset`` — indexing only makes sense over documents.

    Raises:
        ValueError: If no test dataloaders are attached, or if any attached
            dataloader's dataset is not a ``DocDataset``.
    """
    loaders = trainer.test_dataloaders
    if loaders is None:
        raise ValueError("No test_dataloaders found")
    # Reject the whole run if even a single dataloader is not doc-based.
    for loader in loaders:
        if not isinstance(loader.dataset, DocDataset):
            raise ValueError("Expected DocDatasets for indexing")
84
+
84
85
def get_index_dir (self , pl_module : BiEncoderModule , dataset : DocDataset ) -> Path :
85
86
index_dir = self .index_dir
86
87
if index_dir is None :
@@ -112,6 +113,13 @@ def log_to_pg(self, info: Dict[str, Any], trainer: Trainer):
112
113
if pg is not None :
113
114
pg .set_postfix (info )
114
115
116
def on_test_batch_start(
    self, trainer: Trainer, pl_module: LightningModule, batch: Any, batch_idx: int, dataloader_idx: int = 0
) -> None:
    """Lazily create the indexer when a dataloader yields its first batch.

    Batch index 0 marks the start of a new dataloader, so a fresh indexer is
    built for it via ``get_indexer`` and stored on ``self.indexer``; later
    batches of the same dataloader reuse that indexer. Delegates to the parent
    hook afterwards.
    """
    first_batch = batch_idx == 0
    if first_batch:
        # New dataloader → new index target; replace any previous indexer.
        self.indexer = self.get_indexer(trainer, pl_module, dataloader_idx)
    super().on_test_batch_start(trainer, pl_module, batch, batch_idx, dataloader_idx)
122
+
115
123
def on_test_batch_end (
116
124
self ,
117
125
trainer : Trainer ,
@@ -121,11 +129,6 @@ def on_test_batch_end(
121
129
batch_idx : int ,
122
130
dataloader_idx : int = 0 ,
123
131
) -> None :
124
- if batch_idx == 0 :
125
- if hasattr (self , "indexer" ):
126
- self .indexer .save ()
127
- self .indexer = self .get_indexer (trainer , pl_module , dataloader_idx )
128
-
129
132
batch = self .gather (pl_module , batch )
130
133
outputs = self .gather (pl_module , outputs )
131
134
@@ -140,6 +143,9 @@ def on_test_batch_end(
140
143
},
141
144
trainer ,
142
145
)
146
+ if batch_idx == trainer .num_test_batches [dataloader_idx ] - 1 :
147
+ assert hasattr (self , "indexer" )
148
+ self .indexer .save ()
143
149
return super ().on_test_batch_end (trainer , pl_module , outputs , batch , batch_idx , dataloader_idx )
144
150
145
151
def on_test_end (self , trainer : Trainer , pl_module : LightningModule ) -> None :
0 commit comments