Skip to content
13 changes: 8 additions & 5 deletions aiocache/serializers/serializers.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import logging
import pickle # noqa: S403
from abc import ABC, abstractmethod
from typing import Any, Optional

logger = logging.getLogger(__name__)
Expand All @@ -20,7 +21,7 @@
_NOT_SET = object()


class BaseSerializer:
class BaseSerializer(ABC):

DEFAULT_ENCODING: Optional[str] = "utf-8"

Expand All @@ -29,12 +30,14 @@ def __init__(self, *args, encoding=_NOT_SET, **kwargs):
super().__init__(*args, **kwargs)

# TODO(PY38): Positional-only
def dumps(self, value: Any) -> str:
raise NotImplementedError("dumps method must be implemented")
@abstractmethod
def dumps(self, value: Any) -> Any:
"""Serialise the value to be stored in the backend."""

# TODO(PY38): Positional-only
def loads(self, value: str) -> Any:
raise NotImplementedError("loads method must be implemented")
@abstractmethod
def loads(self, value: Any) -> Any:
"""Decode the value retrieved from the backend."""


class NullSerializer(BaseSerializer):
Expand Down
19 changes: 14 additions & 5 deletions examples/marshmallow_serializer_class.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import random
import string
import asyncio
from typing import Any

from marshmallow import fields, Schema, post_load

Expand All @@ -21,16 +22,12 @@ def __eq__(self, obj):
return self.__dict__ == obj.__dict__


class MarshmallowSerializer(Schema, BaseSerializer): # type: ignore[misc]
class RandomSchema(Schema):
int_type = fields.Integer()
str_type = fields.String()
dict_type = fields.Dict()
list_type = fields.List(fields.Integer())

# marshmallow Schema class doesn't play nicely with multiple inheritance and won't call
# BaseSerializer.__init__
encoding = 'utf-8'

@post_load
def build_my_type(self, data, **kwargs):
return RandomModel(**data)
Expand All @@ -39,6 +36,18 @@ class Meta:
strict = True


class MarshmallowSerializer(BaseSerializer):
def __init__(self, *args: Any, **kwargs: Any):
super().__init__(*args, **kwargs)
self.schema = RandomSchema()

def dumps(self, value: Any) -> str:
return self.schema.dumps(value)

def loads(self, value: str) -> Any:
return self.schema.loads(value)


cache = Cache(serializer=MarshmallowSerializer(), namespace="main")


Expand Down
22 changes: 14 additions & 8 deletions tests/acceptance/test_serializers.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import pickle
import random
from typing import Any

import pytest
from marshmallow import Schema, fields, post_load
Expand Down Expand Up @@ -29,15 +30,8 @@ def __eq__(self, obj):
return self.__dict__ == obj.__dict__


class MyTypeSchema(Schema, BaseSerializer):
class MySchema(Schema):
r = fields.Integer()
encoding = "utf-8"

def dumps(self, *args, **kwargs):
return super().dumps(*args, **kwargs)

def loads(self, *args, **kwargs):
return super().loads(*args, **kwargs)

@post_load
def build_my_type(self, data, **kwargs):
Expand All @@ -47,6 +41,18 @@ class Meta:
strict = True


class MyTypeSchema(BaseSerializer):
def __init__(self, *args: Any, **kwargs: Any):
super().__init__(*args, **kwargs)
self.schema = MySchema()

def dumps(self, value: Any) -> str:
return self.schema.dumps(value)

def loads(self, value: str) -> Any:
return self.schema.loads(value)


class TestNullSerializer:
TYPES = (1, 2.0, "hi", True, ["1", 1], {"key": "value"}, MyType())

Expand Down
23 changes: 4 additions & 19 deletions tests/ut/test_serializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,33 +20,18 @@
JSON_TYPES = [1, 2.0, "hi", True, ["1", 1], {"key": "value"}]


class TestBaseSerializer:
class TestNullSerializer:
def test_init(self):
serializer = BaseSerializer()
serializer = NullSerializer()
assert isinstance(serializer, BaseSerializer)
assert serializer.DEFAULT_ENCODING == "utf-8"
assert serializer.encoding == "utf-8"

def test_init_encoding(self):
serializer = BaseSerializer(encoding="whatever")
serializer = NullSerializer(encoding="whatever")
assert serializer.DEFAULT_ENCODING == "utf-8"
assert serializer.encoding == "whatever"

def test_dumps(self):
with pytest.raises(NotImplementedError):
BaseSerializer().dumps("")

def test_loads(self):
with pytest.raises(NotImplementedError):
BaseSerializer().loads("")


class TestNullSerializer:
def test_init(self):
serializer = NullSerializer()
assert isinstance(serializer, BaseSerializer)
assert serializer.DEFAULT_ENCODING == "utf-8"
assert serializer.encoding == "utf-8"

@pytest.mark.parametrize("obj", TYPES)
def test_set_types(self, obj):
assert NullSerializer().dumps(obj) is obj
Expand Down