diff --git a/aiocache/serializers/serializers.py b/aiocache/serializers/serializers.py index 87fff45fb..0ef56b68f 100644 --- a/aiocache/serializers/serializers.py +++ b/aiocache/serializers/serializers.py @@ -1,5 +1,6 @@ import logging import pickle # noqa: S403 +from abc import ABC, abstractmethod from typing import Any, Optional logger = logging.getLogger(__name__) @@ -20,7 +21,7 @@ _NOT_SET = object() -class BaseSerializer: +class BaseSerializer(ABC): DEFAULT_ENCODING: Optional[str] = "utf-8" @@ -29,12 +30,14 @@ def __init__(self, *args, encoding=_NOT_SET, **kwargs): super().__init__(*args, **kwargs) # TODO(PY38): Positional-only - def dumps(self, value: Any) -> str: - raise NotImplementedError("dumps method must be implemented") + @abstractmethod + def dumps(self, value: Any) -> Any: + """Serialise the value to be stored in the backend.""" # TODO(PY38): Positional-only - def loads(self, value: str) -> Any: - raise NotImplementedError("loads method must be implemented") + @abstractmethod + def loads(self, value: Any) -> Any: + """Decode the value retrieved from the backend.""" class NullSerializer(BaseSerializer): diff --git a/examples/marshmallow_serializer_class.py b/examples/marshmallow_serializer_class.py index 5fd233c3b..f45a2ed75 100644 --- a/examples/marshmallow_serializer_class.py +++ b/examples/marshmallow_serializer_class.py @@ -1,6 +1,7 @@ import random import string import asyncio +from typing import Any from marshmallow import fields, Schema, post_load @@ -21,16 +22,12 @@ def __eq__(self, obj): return self.__dict__ == obj.__dict__ -class MarshmallowSerializer(Schema, BaseSerializer): # type: ignore[misc] +class RandomSchema(Schema): int_type = fields.Integer() str_type = fields.String() dict_type = fields.Dict() list_type = fields.List(fields.Integer()) - # marshmallow Schema class doesn't play nicely with multiple inheritance and won't call - # BaseSerializer.__init__ - encoding = 'utf-8' - @post_load def build_my_type(self, data, **kwargs): return RandomModel(**data) @@ -39,6 +36,18 @@ class Meta: strict = True +class MarshmallowSerializer(BaseSerializer): + def __init__(self, *args: Any, **kwargs: Any): + super().__init__(*args, **kwargs) + self.schema = RandomSchema() + + def dumps(self, value: Any) -> str: + return self.schema.dumps(value) + + def loads(self, value: str) -> Any: + return self.schema.loads(value) + + cache = Cache(serializer=MarshmallowSerializer(), namespace="main") diff --git a/tests/acceptance/test_serializers.py b/tests/acceptance/test_serializers.py index bd59f0b0f..694f0a8b6 100644 --- a/tests/acceptance/test_serializers.py +++ b/tests/acceptance/test_serializers.py @@ -1,5 +1,6 @@ import pickle import random +from typing import Any import pytest from marshmallow import Schema, fields, post_load @@ -29,15 +30,8 @@ def __eq__(self, obj): return self.__dict__ == obj.__dict__ -class MyTypeSchema(Schema, BaseSerializer): +class MySchema(Schema): r = fields.Integer() - encoding = "utf-8" - - def dumps(self, *args, **kwargs): - return super().dumps(*args, **kwargs) - - def loads(self, *args, **kwargs): - return super().loads(*args, **kwargs) @post_load def build_my_type(self, data, **kwargs): @@ -47,6 +41,18 @@ class Meta: strict = True +class MyTypeSchema(BaseSerializer): + def __init__(self, *args: Any, **kwargs: Any): + super().__init__(*args, **kwargs) + self.schema = MySchema() + + def dumps(self, value: Any) -> str: + return self.schema.dumps(value) + + def loads(self, value: str) -> Any: + return self.schema.loads(value) + + class TestNullSerializer: TYPES = (1, 2.0, "hi", True, ["1", 1], {"key": "value"}, MyType()) diff --git a/tests/ut/test_serializers.py b/tests/ut/test_serializers.py index dd30f256e..33835531b 100644 --- a/tests/ut/test_serializers.py +++ b/tests/ut/test_serializers.py @@ -20,33 +20,18 @@ JSON_TYPES = [1, 2.0, "hi", True, ["1", 1], {"key": "value"}] -class TestBaseSerializer: +class TestNullSerializer: def test_init(self): - serializer = BaseSerializer() + serializer = NullSerializer() + assert isinstance(serializer, BaseSerializer) assert serializer.DEFAULT_ENCODING == "utf-8" assert serializer.encoding == "utf-8" def test_init_encoding(self): - serializer = BaseSerializer(encoding="whatever") + serializer = NullSerializer(encoding="whatever") assert serializer.DEFAULT_ENCODING == "utf-8" assert serializer.encoding == "whatever" - def test_dumps(self): - with pytest.raises(NotImplementedError): - BaseSerializer().dumps("") - - def test_loads(self): - with pytest.raises(NotImplementedError): - BaseSerializer().loads("") - - -class TestNullSerializer: - def test_init(self): - serializer = NullSerializer() - assert isinstance(serializer, BaseSerializer) - assert serializer.DEFAULT_ENCODING == "utf-8" - assert serializer.encoding == "utf-8" - @pytest.mark.parametrize("obj", TYPES) def test_set_types(self, obj): assert NullSerializer().dumps(obj) is obj