Skip to content

Commit 0796b89

Browse files
committed
fix overwriting
1 parent d786809 commit 0796b89

File tree

1 file changed

+19
-13
lines changed

1 file changed

+19
-13
lines changed

lightning_ir/lightning_utils/callbacks.py

+19 -13
Original file line numberDiff line numberDiff line change
@@ -64,15 +64,8 @@ def __init__(
6464
def setup(self, trainer: Trainer, pl_module: BiEncoderModule, stage: str) -> None:
6565
if stage != "test":
6666
raise ValueError("IndexCallback can only be used in test stage")
67-
68-
def on_test_start(self, trainer: Trainer, pl_module: BiEncoderModule) -> None:
69-
dataloaders = trainer.test_dataloaders
70-
if dataloaders is None:
71-
raise ValueError("No test_dataloaders found")
72-
datasets = [dataloader.dataset for dataloader in dataloaders]
73-
if not all(isinstance(dataset, DocDataset) for dataset in datasets):
74-
raise ValueError("Expected DocDatasets for indexing")
7567
if not self.overwrite:
68+
datasets = list(trainer.datamodule.inference_datasets)
7669
for dataset in datasets:
7770
index_dir = self.get_index_dir(pl_module, dataset)
7871
if index_dir.exists():
@@ -81,6 +74,14 @@ def on_test_start(self, trainer: Trainer, pl_module: BiEncoderModule) -> None:
8174
f"Index dir {index_dir} already exists. Skipping this dataset. Set overwrite=True to overwrite"
8275
)
8376

77+
def on_test_start(self, trainer: Trainer, pl_module: BiEncoderModule) -> None:
78+
dataloaders = trainer.test_dataloaders
79+
if dataloaders is None:
80+
raise ValueError("No test_dataloaders found")
81+
datasets = [dataloader.dataset for dataloader in dataloaders]
82+
if not all(isinstance(dataset, DocDataset) for dataset in datasets):
83+
raise ValueError("Expected DocDatasets for indexing")
84+
8485
def get_index_dir(self, pl_module: BiEncoderModule, dataset: DocDataset) -> Path:
8586
index_dir = self.index_dir
8687
if index_dir is None:
@@ -112,6 +113,13 @@ def log_to_pg(self, info: Dict[str, Any], trainer: Trainer):
112113
if pg is not None:
113114
pg.set_postfix(info)
114115

116+
def on_test_batch_start(
117+
self, trainer: Trainer, pl_module: LightningModule, batch: Any, batch_idx: int, dataloader_idx: int = 0
118+
) -> None:
119+
if batch_idx == 0:
120+
self.indexer = self.get_indexer(trainer, pl_module, dataloader_idx)
121+
super().on_test_batch_start(trainer, pl_module, batch, batch_idx, dataloader_idx)
122+
115123
def on_test_batch_end(
116124
self,
117125
trainer: Trainer,
@@ -121,11 +129,6 @@ def on_test_batch_end(
121129
batch_idx: int,
122130
dataloader_idx: int = 0,
123131
) -> None:
124-
if batch_idx == 0:
125-
if hasattr(self, "indexer"):
126-
self.indexer.save()
127-
self.indexer = self.get_indexer(trainer, pl_module, dataloader_idx)
128-
129132
batch = self.gather(pl_module, batch)
130133
outputs = self.gather(pl_module, outputs)
131134

@@ -140,6 +143,9 @@ def on_test_batch_end(
140143
},
141144
trainer,
142145
)
146+
if batch_idx == trainer.num_test_batches[dataloader_idx] - 1:
147+
assert hasattr(self, "indexer")
148+
self.indexer.save()
143149
return super().on_test_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx)
144150

145151
def on_test_end(self, trainer: Trainer, pl_module: LightningModule) -> None:

0 commit comments

Comments (0)