Skip to content

Commit f6031e3

Browse files
authored
fix incremental save for PyTorch 2.6 (#1928)
1 parent c7af97d commit f6031e3

File tree

2 files changed

+11
-4
lines changed

2 files changed

+11
-4
lines changed

litgpt/utils.py

Lines changed: 5 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -168,12 +168,12 @@ def __init__(self, tensor, saver, protocol_version=5):
168168
if reduce_args[0] == torch._utils._rebuild_tensor_v2:
169169
# for Tensors with Python attributes
170170
(a0, a1, (storage, *a2_other), *other_reduce_args) = reduce_args
171-
assert isinstance(storage, torch.storage.TypedStorage), "Please check for updates"
171+
assert isinstance(storage, (torch.storage.TypedStorage, torch.storage.UntypedStorage)), "Please check for updates"
172172
storage_proxy = SavingProxyForStorage(storage, saver, protocol_version=protocol_version)
173173
self.reduce_args = (a0, a1, (storage_proxy, *a2_other), *other_reduce_args)
174174
else:
175175
(storage, *other_reduce_args) = reduce_args
176-
assert isinstance(storage, torch.storage.TypedStorage), "Please check for updates"
176+
assert isinstance(storage, (torch.storage.TypedStorage, torch.storage.UntypedStorage)), "Please check for updates"
177177
storage_proxy = SavingProxyForStorage(storage, saver, protocol_version=protocol_version)
178178
self.reduce_args = (storage_proxy, *other_reduce_args)
179179

@@ -245,21 +245,22 @@ def __init__(self, name):
245245
self.zipfile = torch._C.PyTorchFileWriter(str(name))
246246
self.has_saved = False
247247
self.next_key = 0
248+
self.protocol_version = 2
248249

249250
def __enter__(self):
250251
return self
251252

252253
def store_early(self, tensor):
253254
if isinstance(tensor, torch.Tensor):
254-
return SavingProxyForTensor(tensor, self)
255+
return SavingProxyForTensor(tensor, self, protocol_version=self.protocol_version)
255256
raise TypeError(f"can only store tensors early, not {type(tensor)}")
256257

257258
def save(self, obj):
258259
if self.has_saved:
259260
raise RuntimeError("have already saved")
260261
# Write the pickle data for `obj`
261262
data_buf = BytesIO()
262-
pickler = IncrementalPyTorchPickler(self, data_buf, protocol=5)
263+
pickler = IncrementalPyTorchPickler(self, data_buf, protocol=self.protocol_version)
263264
pickler.dump(obj)
264265
data_value = data_buf.getvalue()
265266
self.zipfile.write_record("data.pkl", data_value, len(data_value))

tests/test_utils.py

Lines changed: 6 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -122,6 +122,12 @@ def test_incremental_write(tmp_path):
122122
for k, v_expected in sd_expected.items():
123123
v_actual = sd_actual[k]
124124
torch.testing.assert_close(v_expected, v_actual)
125+
sd_actual = torch.load(fn, weights_only=True)
126+
assert sd_actual.keys() == sd_expected.keys()
127+
assert sd_actual["0"].someattr == 1 # requires PyTorch 2.0+
128+
for k, v_expected in sd_expected.items():
129+
v_actual = sd_actual[k]
130+
torch.testing.assert_close(v_expected, v_actual)
125131

126132

127133
@pytest.mark.parametrize("B", (1, 2))

0 commit comments

Comments
 (0)