query_dm = WrapperDataModule(self.query_ds, batch_size=1)
track_dm = WrapperDataModule(self.track_ds, batch_size=1)

Change batch_size=1 to batch_size=self.batchsize in both calls.
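A minimal sketch of the requested change follows. Both data modules read the batch size from self.batchsize instead of the hard-coded 1. The WrapperDataModule class and the owner class Matcher are hypothetical stand-ins, since the real definitions are not shown. This sketch assumes only that WrapperDataModule takes the dataset as its first positional argument and accepts a batch_size keyword, as in the original calls.

```python
from dataclasses import dataclass
from typing import Any, Sequence


@dataclass
class WrapperDataModule:
    """Hypothetical stand-in for the real WrapperDataModule: it only stores
    the dataset and the batch size so the change can be shown in isolation."""
    dataset: Sequence[Any]
    batch_size: int = 1


class Matcher:
    """Hypothetical owner class; only query_ds, track_ds and batchsize are assumed."""

    def __init__(self, query_ds: Sequence[Any], track_ds: Sequence[Any], batchsize: int):
        self.query_ds = query_ds
        self.track_ds = track_ds
        self.batchsize = batchsize

    def build_datamodules(self):
        # Both data modules now take their batch size from self.batchsize
        # rather than the previous hard-coded value of 1.
        query_dm = WrapperDataModule(self.query_ds, batch_size=self.batchsize)
        track_dm = WrapperDataModule(self.track_ds, batch_size=self.batchsize)
        return query_dm, track_dm


if __name__ == "__main__":
    m = Matcher(list(range(10)), list(range(4)), batchsize=8)
    query_dm, track_dm = m.build_datamodules()
    print(query_dm.batch_size, track_dm.batch_size)  # prints "8 8"
```

Reading the value from one attribute keeps the query and track batch sizes in sync. If the two ever need to differ, separate attributes would be needed. That would go beyond the change requested here.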