Skip to content

Commit e192068

Browse files
committed
removed .to(device) for checkpoint loading
1 parent 4259de2 commit e192068

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

olmo/model.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1768,7 +1768,7 @@ def from_checkpoint(
17681768
)
17691769

17701770
model_config.init_device = device
1771-
model = OLMo(model_config).to(device)
1771+
model = OLMo(model_config)
17721772
load_model_and_optim_state(checkpoint_dir, model)
17731773
else:
17741774
# train_config.sharded_checkpointer == ShardedCheckpointerType.torch_new
@@ -1777,7 +1777,7 @@ def from_checkpoint(
17771777
# Initialize model on target device. In this case the state dict is loaded in-place
17781778
# so it's not necessary to start on CPU if the target device is a GPU.
17791779
model_config.init_device = device
1780-
model = OLMo(model_config).to(device)
1780+
model = OLMo(model_config)
17811781

17821782
# Load state dict in place.
17831783
load_model_state(checkpoint_dir, model)

0 commit comments

Comments
 (0)