@@ -168,12 +168,12 @@ def __init__(self, tensor, saver, protocol_version=5):
168
168
if reduce_args [0 ] == torch ._utils ._rebuild_tensor_v2 :
169
169
# for Tensors with Python attributes
170
170
(a0 , a1 , (storage , * a2_other ), * other_reduce_args ) = reduce_args
171
- assert isinstance (storage , torch .storage .TypedStorage ), "Please check for updates"
171
+ assert isinstance (storage , ( torch .storage .TypedStorage , torch . storage . UntypedStorage ) ), "Please check for updates"
172
172
storage_proxy = SavingProxyForStorage (storage , saver , protocol_version = protocol_version )
173
173
self .reduce_args = (a0 , a1 , (storage_proxy , * a2_other ), * other_reduce_args )
174
174
else :
175
175
(storage , * other_reduce_args ) = reduce_args
176
- assert isinstance (storage , torch .storage .TypedStorage ), "Please check for updates"
176
+ assert isinstance (storage , ( torch .storage .TypedStorage , torch . storage . UntypedStorage ) ), "Please check for updates"
177
177
storage_proxy = SavingProxyForStorage (storage , saver , protocol_version = protocol_version )
178
178
self .reduce_args = (storage_proxy , * other_reduce_args )
179
179
@@ -245,21 +245,22 @@ def __init__(self, name):
245
245
self .zipfile = torch ._C .PyTorchFileWriter (str (name ))
246
246
self .has_saved = False
247
247
self .next_key = 0
248
+ self .protocol_version = 2
248
249
249
250
def __enter__ (self ):
250
251
return self
251
252
252
253
def store_early (self , tensor ):
253
254
if isinstance (tensor , torch .Tensor ):
254
- return SavingProxyForTensor (tensor , self )
255
+ return SavingProxyForTensor (tensor , self , protocol_version = self . protocol_version )
255
256
raise TypeError (f"can only store tensors early, not { type (tensor )} " )
256
257
257
258
def save (self , obj ):
258
259
if self .has_saved :
259
260
raise RuntimeError ("have already saved" )
260
261
# Write the pickle data for `obj`
261
262
data_buf = BytesIO ()
262
- pickler = IncrementalPyTorchPickler (self , data_buf , protocol = 5 )
263
+ pickler = IncrementalPyTorchPickler (self , data_buf , protocol = self . protocol_version )
263
264
pickler .dump (obj )
264
265
data_value = data_buf .getvalue ()
265
266
self .zipfile .write_record ("data.pkl" , data_value , len (data_value ))
0 commit comments