
Commit 71ff0d6

Chronos-2: Add option to disable DataParallel (#434)
*Issue #, if available:*

*Description of changes:*

By submitting this pull request, I confirm that you can use, modify, copy, and redistribute this contribution, under the terms of your choice.
1 parent eb5b612 commit 71ff0d6

1 file changed: +8 −0 lines changed

src/chronos/chronos2/pipeline.py

Lines changed: 8 additions & 0 deletions
@@ -114,6 +114,7 @@ def fit(
         finetuned_ckpt_name: str = "finetuned-ckpt",
         callbacks: list["TrainerCallback"] | None = None,
         remove_printer_callback: bool = False,
+        disable_data_parallel: bool = True,
         **extra_trainer_kwargs,
     ) -> "Chronos2Pipeline":
         """
@@ -158,6 +159,8 @@ def fit(
             A list of `TrainerCallback`s which will be forwarded to the HuggingFace `Trainer`
         remove_printer_callback
             If True, all instances of `PrinterCallback` are removed from callbacks
+        disable_data_parallel
+            If True, ensures that DataParallel is disabled and training happens on a single GPU
         **extra_trainer_kwargs
             Extra kwargs are directly forwarded to `TrainingArguments`
 
@@ -319,6 +322,11 @@ def fit(
 
         training_args = TrainingArguments(**training_kwargs)
 
+        if disable_data_parallel and not use_cpu:
+            # This is a hack to disable the default `transformers` behavior of using DataParallel
+            training_args._n_gpu = 1
+            assert training_args.n_gpu == 1  # Ensure that the hack worked
+
         trainer = Chronos2Trainer(
             model=model,
             args=training_args,

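And a hedged usage sketch of the new flag, also not part of the commit: it assumes `Chronos2Pipeline.from_pretrained` follows the loading convention of the other Chronos pipelines, and the model ID and training-data argument are placeholders, since this diff only shows `fit`'s trailing keyword parameters.

from chronos import Chronos2Pipeline

pipeline = Chronos2Pipeline.from_pretrained("amazon/chronos-2")  # placeholder model ID

training_data = ...  # placeholder: fit's data arguments are not visible in this diff

pipeline.fit(
    training_data,
    finetuned_ckpt_name="finetuned-ckpt",
    disable_data_parallel=True,  # new flag: keep finetuning on a single GPU
)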