diff --git a/python/xoscar/serialization/core.pyx b/python/xoscar/serialization/core.pyx index 745276f9..290c9af6 100644 --- a/python/xoscar/serialization/core.pyx +++ b/python/xoscar/serialization/core.pyx @@ -51,7 +51,7 @@ except (ImportError, AttributeError): from .._utils import NamedType from .._utils cimport TypeDispatcher -from .pyfury import get_fury +from .pyfury import get_fury, register_class_to_fury BUFFER_PICKLE_PROTOCOL = max(pickle.DEFAULT_PROTOCOL, 5) cdef bint HAS_PICKLE_BUFFER = pickle.HIGHEST_PROTOCOL >= 5 @@ -218,47 +218,64 @@ def buffered(func): def pickle_buffers(obj): cdef list buffers = [None] - fury = get_fury() - if fury is not None: + if HAS_PICKLE_BUFFER: + def buffer_cb(x): - try: - buffers.append(memoryview(x)) - except TypeError: - buffers.append(x.to_buffer()) + x = x.raw() + if x.ndim > 1: + # ravel n-d memoryview + x = x.cast(x.format) + buffers.append(memoryview(x)) - buffers[0] = b"__fury__" - buffers.append(None) - buffers[1] = fury.serialize( + buffers[0] = cloudpickle.dumps( obj, buffer_callback=buffer_cb, + protocol=BUFFER_PICKLE_PROTOCOL, ) - else: - if HAS_PICKLE_BUFFER: - def buffer_cb(x): - x = x.raw() - if x.ndim > 1: - # ravel n-d memoryview - x = x.cast(x.format) - buffers.append(memoryview(x)) - - buffers[0] = cloudpickle.dumps( - obj, - buffer_callback=buffer_cb, - protocol=BUFFER_PICKLE_PROTOCOL, - ) - else: - buffers[0] = cloudpickle.dumps(obj) + else: # pragma: no cover + buffers[0] = cloudpickle.dumps(obj) return buffers def unpickle_buffers(list buffers): - if buffers[0] == b"__fury__": - fury = get_fury() - if fury is None: - raise Exception("fury is not installed.") - result = fury.deserialize(buffers[1], buffers[2:]) - else: - result = cloudpickle.loads(buffers[0], buffers=buffers[1:]) + result = cloudpickle.loads(buffers[0], buffers=buffers[1:]) + + # as pandas prior to 1.1.0 use _data instead of _mgr to hold BlockManager, + # deserializing from high versions may produce mal-functioned pandas objects, + # thus the patch is needed + if _PANDAS_HAS_MGR: + return result + else: # pragma: no cover + if hasattr(result, "_mgr") and isinstance(result, (pd.DataFrame, pd.Series)): + result._data = getattr(result, "_mgr") + delattr(result, "_mgr") + return result + + +def fury_serialize_buffers(obj): + cdef list buffers = [None] + + fury = get_fury() + if fury is None: + raise Exception(f"fury is not installed.") + def buffer_cb(x): + try: + buffers.append(memoryview(x)) + except TypeError: + buffers.append(x.to_buffer()) + + buffers[0] = fury.serialize( + obj, + buffer_callback=buffer_cb, + ) + return buffers + + +def fury_deserialize_buffers(list buffers): + fury = get_fury() + if fury is None: + raise Exception("fury is not installed.") + result = fury.deserialize(buffers[0], buffers[1:]) # as pandas prior to 1.1.0 use _data instead of _mgr to hold BlockManager, # deserializing from high versions may produce mal-functioned pandas objects, @@ -288,6 +305,28 @@ cdef class PickleSerializer(Serializer): return unpickle_buffers(subs) +cdef class FurySerializer(Serializer): + serializer_id = 100 + + cpdef serial(self, obj: Any, dict context): + cdef uint64_t obj_id + obj_id = _fast_id(obj) + if obj_id in context: + return Placeholder(obj_id) + context[obj_id] = obj + + return (), fury_serialize_buffers(obj), True + + cpdef deserial(self, tuple serialized, dict context, list subs): + return fury_deserialize_buffers(subs) + + @classmethod + def register(cls, obj_type, name=None): + if register_class_to_fury(obj_type): + # Only register type to FurySerializer if fury is enabled. + super().register(obj_type, name) + + cdef set _primitive_types = { type(None), bool, diff --git a/python/xoscar/serialization/pyfury.py b/python/xoscar/serialization/pyfury.py index c192f554..6c749fbe 100644 --- a/python/xoscar/serialization/pyfury.py +++ b/python/xoscar/serialization/pyfury.py @@ -3,15 +3,17 @@ _fury = threading.local() _fury_not_installed = object() -_register_class_list = set() +_register_classes = set() -def register_classes(*args): +def register_class_to_fury(obj_type): instance = get_fury() if instance is not None: - _register_class_list.update(args) - for c in _register_class_list: + _register_classes.add(obj_type) + for c in _register_classes: instance.register_class(c) + return True + return False def get_fury(): @@ -26,9 +28,9 @@ def get_fury(): import pyfury _fury.instance = instance = pyfury.Fury( - language=pyfury.Language.PYTHON, require_class_registration=False + language=pyfury.Language.PYTHON, ref_tracking=True ) - for c in _register_class_list: # pragma: no cover + for c in _register_classes: # pragma: no cover instance.register_class(c) print("pyfury is enabled.") except ImportError: # pragma: no cover diff --git a/python/xoscar/serialization/tests/test_serial.py b/python/xoscar/serialization/tests/test_serial.py index cc293640..b8d13a5e 100644 --- a/python/xoscar/serialization/tests/test_serial.py +++ b/python/xoscar/serialization/tests/test_serial.py @@ -184,7 +184,7 @@ def test_arrow(): @pytest.mark.skipif(pyfury is None, reason="need pyfury to run the cases") def test_arrow_fury(): os.environ["USE_FURY"] = "1" - from ..pyfury import register_classes + from ..core import FurySerializer try: test_df = pd.DataFrame( @@ -194,7 +194,8 @@ def test_arrow_fury(): "c": np.random.randint(0, 100, size=(1000,)), } ) - register_classes(pa.RecordBatch, pa.Table) + FurySerializer.register(pa.RecordBatch) + FurySerializer.register(pa.Table) test_vals = [ pa.RecordBatch.from_pandas(test_df), pa.Table.from_pandas(test_df),