Description
Previous RFC: #26229
Tensor serialization is common in machine learning workloads. Pickle protocol 5 (the pickle5 proposal, PEP 574) supports out-of-band buffers, which offers an opportunity to speed up serialization of read-only tensors by avoiding copies. The reinforcement learning framework verl and the LLM inference framework vLLM both already support this feature.
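To make the zero-copy idea concrete, here is a minimal standard-library sketch of pickle protocol 5 out-of-band buffers. It uses a `bytearray` as a stand-in for a tensor's storage (an assumption made to keep the example free of third-party dependencies); the variable names are illustrative only and are not part of Ray.

```python
import pickle

# Stand-in for a large tensor's underlying storage (1 MiB of zeros).
payload = bytearray(1 << 20)

# In-band: the payload bytes are copied into the pickle stream itself.
in_band = pickle.dumps(payload, protocol=5)

# Out-of-band: wrapping the payload in PickleBuffer and passing a
# buffer_callback keeps the data out of the stream. Only a small
# placeholder is pickled, and the buffer is handed to the callback.
buffers = []
out_of_band = pickle.dumps(
    pickle.PickleBuffer(payload), protocol=5, buffer_callback=buffers.append
)

# Deserialization receives the buffers separately and reuses their memory.
restored = pickle.loads(out_of_band, buffers=buffers)
restored_view = memoryview(restored)

# Writing to the original storage is visible through the restored object,
# which shows that no copy was made.
payload[0] = 7
shares_memory = restored_view[0] == 7
```

In this sketch `out_of_band` is only a few bytes long, while `in_band` is larger than the payload itself.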
@tianyi-ge and I aim to implement an API based on annotations or decorators to provide serialization performance optimization for read-only tensors. We envision two solutions:
- Solution 1: For discrete tensors, users explicitly specify a list of read-only tensor arguments. Zero-copy serialization is then applied only to these designated read-only tensors.
@ray.remote(read_only_args=['a', 'b'])
def transfer_big_tensor(a, b, c):
    print(a, b)  # read-only, ideally zero-copy
    c += 1  # mutable
@edoakes and @jjyao noted limitations with this approach. For example, it's highly likely that users would want to pass read-only objects nested within other objects, which would be difficult or impossible to express using this method.
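The name-based interface of Solution 1 can be illustrated without Ray. The sketch below is a plain-Python approximation: `read_only_args` is a hypothetical decorator (not an existing Ray API) that looks up the named arguments and replaces each with a read-only `memoryview` over the same memory, using `bytearray` values in place of tensors.

```python
import functools
import inspect


def read_only_args(*names):
    """Hypothetical stand-in for the proposed read_only_args option."""

    def decorate(fn):
        sig = inspect.signature(fn)
        unknown = set(names) - set(sig.parameters)
        if unknown:
            raise ValueError(f"unknown argument names: {sorted(unknown)}")

        @functools.wraps(fn)
        def wrapper(*args, **kwargs):
            bound = sig.bind(*args, **kwargs)
            for name in names:
                if name in bound.arguments:
                    # Swap in a read-only view of the same memory (no copy).
                    view = memoryview(bound.arguments[name]).toreadonly()
                    bound.arguments[name] = view
            return fn(*bound.args, **bound.kwargs)

        return wrapper

    return decorate


@read_only_args("a", "b")
def transfer_big_buffer(a, b, c):
    c[0] += 1  # mutable argument, modified in place
    return a, b, c
```

Because arguments are selected by top-level parameter name, a read-only buffer stored inside `c` (for example as a dict value) cannot be designated this way, which is the limitation raised above.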
- Solution 2: For scenarios where read-only tensors are nested within complex data structures, the following steps are required.
- Define a wrapper class for read-only tensors.
import torch

class ReadOnlyTensor:
    def __init__(self, tensor: torch.Tensor):
        if not isinstance(tensor, torch.Tensor):
            raise TypeError("Input must be a torch.Tensor")
        # Detach from autograd and move to contiguous CPU memory
        self._tensor = tensor.detach().cpu().contiguous()
        self._numpy_view = None

    def to_numpy(self):
        # Lazily create a NumPy view that shares memory with the tensor
        if self._numpy_view is None:
            arr = self._tensor.numpy()
            arr.flags.writeable = False  # mark the view read-only
            self._numpy_view = arr
        return self._numpy_view

    def to_tensor(self):
        return self._tensor
- Extract read-only tensors recursively and apply zero-copy serialization.
# obj contains mutable tensors and read-only tensors
# serialize mutable part to bytes, offload read-only tensors as pickle5 OOB buffers
mutable_bytes, readonly_buffers = zero_copy_serialize(obj)
- Reconstruct the full data structure during deserialization.
# deserialize mutable part from bytes, attach read-only tensors from pickle5 OOB buffers
restored_obj = zero_copy_deserialize(type(obj), mutable_bytes, readonly_buffers)
This second solution is more general but significantly more complex to implement.
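The three steps of Solution 2 can be sketched with the standard library alone. In this sketch, `ReadOnlyBlob` is a hypothetical byte-buffer analogue of `ReadOnlyTensor`, and the bodies of `zero_copy_serialize` and `zero_copy_deserialize` are assumptions about one possible implementation, not the authors' design. The recursive extraction is delegated to the pickler: any nested `ReadOnlyBlob` exposes its data through `__reduce_ex__` as a `PickleBuffer`, which pickle protocol 5 routes to `buffer_callback` instead of copying it into the byte stream.

```python
import pickle


class ReadOnlyBlob:
    """Hypothetical stdlib analogue of ReadOnlyTensor wrapping a byte buffer."""

    def __init__(self, data):
        self._view = memoryview(data).toreadonly()

    def view(self):
        return self._view

    def __reduce_ex__(self, protocol):
        if protocol >= 5:
            # Eligible for out-of-band transfer when a buffer_callback is set.
            return ReadOnlyBlob, (pickle.PickleBuffer(self._view),)
        # Older protocols fall back to copying the bytes in-band.
        return ReadOnlyBlob, (self._view.tobytes(),)


def zero_copy_serialize(obj):
    """Return (mutable_bytes, readonly_buffers) for an arbitrary nested object."""
    readonly_buffers = []
    mutable_bytes = pickle.dumps(
        obj, protocol=5, buffer_callback=readonly_buffers.append
    )
    return mutable_bytes, readonly_buffers


def zero_copy_deserialize(cls, mutable_bytes, readonly_buffers):
    """Rebuild the object, reattaching the out-of-band buffers without copying."""
    obj = pickle.loads(mutable_bytes, buffers=readonly_buffers)
    if not isinstance(obj, cls):
        raise TypeError(f"expected {cls.__name__}, got {type(obj).__name__}")
    return obj


# A nested structure mixing a large read-only payload with mutable state.
storage = bytearray(1 << 16)
obj = {"frozen": [ReadOnlyBlob(storage)], "counter": [0]}

mutable_bytes, readonly_buffers = zero_copy_serialize(obj)
restored_obj = zero_copy_deserialize(type(obj), mutable_bytes, readonly_buffers)
```

The mutable part (`counter`) travels inside `mutable_bytes` as usual, while the large payload appears only in `readonly_buffers`. A real implementation would also need to move those buffers between processes, for example through shared memory, which this sketch does not cover.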
We would like to self-assign this issue.
Use case
Any scenario requiring the serialization of tensors intended to be read-only.
Example usage:
import torch
import ray
from ray._common import ReadOnlyTensor  # proposed class, not yet available in Ray

@ray.remote(try_fast_serial=True)  # proposed option
class MixedTensorObject:
    """Example class containing both read-only and mutable tensors"""
    readonly_part: ReadOnlyTensor  # read-only tensor
    mutable_part: torch.Tensor  # mutable tensor

    def __init__(self, readonly_part: ReadOnlyTensor, mutable_part: torch.Tensor):
        self.readonly_part = readonly_part
        self.mutable_part = mutable_part

    def sum_readonly_part(self):
        return self.readonly_part.sum()

    def increment_mutable_part(self):
        self.mutable_part += 1
        return self.mutable_part

# large_tensor and small_tensor are assumed to be defined by the caller
obj = MixedTensorObject.remote(large_tensor, small_tensor)
# Actor methods are invoked with .remote(); ray.get() resolves the result
assert ray.get(obj.sum_readonly_part.remote()) == large_tensor.sum()
# torch.equal compares whole tensors and returns a single bool
assert torch.equal(ray.get(obj.increment_mutable_part.remote()), small_tensor + 1)