Skip to content

Commit

Permalink
fix incremental save for PyTorch 2.6 (#1928)
Browse files Browse the repository at this point in the history
  • Loading branch information
t-vi authored Feb 3, 2025
1 parent c7af97d commit f6031e3
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 4 deletions.
9 changes: 5 additions & 4 deletions litgpt/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,12 +168,12 @@ def __init__(self, tensor, saver, protocol_version=5):
if reduce_args[0] == torch._utils._rebuild_tensor_v2:
# for Tensors with Python attributes
(a0, a1, (storage, *a2_other), *other_reduce_args) = reduce_args
- assert isinstance(storage, torch.storage.TypedStorage), "Please check for updates"
+ assert isinstance(storage, (torch.storage.TypedStorage, torch.storage.UntypedStorage)), "Please check for updates"
storage_proxy = SavingProxyForStorage(storage, saver, protocol_version=protocol_version)
self.reduce_args = (a0, a1, (storage_proxy, *a2_other), *other_reduce_args)
else:
(storage, *other_reduce_args) = reduce_args
- assert isinstance(storage, torch.storage.TypedStorage), "Please check for updates"
+ assert isinstance(storage, (torch.storage.TypedStorage, torch.storage.UntypedStorage)), "Please check for updates"
storage_proxy = SavingProxyForStorage(storage, saver, protocol_version=protocol_version)
self.reduce_args = (storage_proxy, *other_reduce_args)

Expand Down Expand Up @@ -245,21 +245,22 @@ def __init__(self, name):
self.zipfile = torch._C.PyTorchFileWriter(str(name))
self.has_saved = False
self.next_key = 0
+ self.protocol_version = 2

def __enter__(self):
return self

def store_early(self, tensor):
if isinstance(tensor, torch.Tensor):
- return SavingProxyForTensor(tensor, self)
+ return SavingProxyForTensor(tensor, self, protocol_version=self.protocol_version)
raise TypeError(f"can only store tensors early, not {type(tensor)}")

def save(self, obj):
if self.has_saved:
raise RuntimeError("have already saved")
# Write the pickle data for `obj`
data_buf = BytesIO()
- pickler = IncrementalPyTorchPickler(self, data_buf, protocol=5)
+ pickler = IncrementalPyTorchPickler(self, data_buf, protocol=self.protocol_version)
pickler.dump(obj)
data_value = data_buf.getvalue()
self.zipfile.write_record("data.pkl", data_value, len(data_value))
Expand Down
6 changes: 6 additions & 0 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,12 @@ def test_incremental_write(tmp_path):
for k, v_expected in sd_expected.items():
v_actual = sd_actual[k]
torch.testing.assert_close(v_expected, v_actual)
+ sd_actual = torch.load(fn, weights_only=True)
+ assert sd_actual.keys() == sd_expected.keys()
+ assert sd_actual["0"].someattr == 1 # requires PyTorch 2.0+
+ for k, v_expected in sd_expected.items():
+ v_actual = sd_actual[k]
+ torch.testing.assert_close(v_expected, v_actual)


@pytest.mark.parametrize("B", (1, 2))
Expand Down

0 comments on commit f6031e3

Please sign in to comment.