Skip to content

Commit eb5b612

Browse files
authored
Chronos-2: Only pin_memory when device type is cuda (#431)
1 parent 086e660 commit eb5b612

File tree

1 file changed

+9
-2
lines changed

1 file changed

+9
-2
lines changed

src/chronos/chronos2/pipeline.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -608,7 +608,9 @@ def predict(
608608
output_patch_size=self.model_output_patch_size,
609609
mode=DatasetMode.TEST,
610610
)
611-
test_loader = DataLoader(test_dataset, batch_size=None, pin_memory=True, shuffle=False, drop_last=False)
611+
test_loader = DataLoader(
612+
test_dataset, batch_size=None, pin_memory=self.model.device.type == "cuda", shuffle=False, drop_last=False
613+
)
612614

613615
all_predictions: list[torch.Tensor] = []
614616
for batch in test_loader:
@@ -1122,7 +1124,12 @@ def embed(
11221124
mode=DatasetMode.TEST,
11231125
)
11241126
test_loader = DataLoader(
1125-
test_dataset, batch_size=None, num_workers=1, pin_memory=True, shuffle=False, drop_last=False
1127+
test_dataset,
1128+
batch_size=None,
1129+
num_workers=0,
1130+
pin_memory=self.model.device.type == "cuda",
1131+
shuffle=False,
1132+
drop_last=False,
11261133
)
11271134
all_embeds: list[torch.Tensor] = []
11281135
all_loc_scales: list[tuple[torch.Tensor, torch.Tensor]] = []

0 commit comments

Comments
 (0)