Skip to content
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions aiocache/backends/memcached.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,13 @@


class MemcachedBackend(BaseCache[bytes]):
def __init__(self, endpoint="127.0.0.1", port=11211, pool_size=2, **kwargs):
def __init__(self, host="127.0.0.1", port=11211, pool_size=2, **kwargs):
super().__init__(**kwargs)
self.endpoint = endpoint
self.host = host
self.port = port
self.pool_size = int(pool_size)
self.client = aiomcache.Client(
self.endpoint, self.port, pool_size=self.pool_size
self.host, self.port, pool_size=self.pool_size
)

async def _get(self, key, encoding="utf-8", _conn=None):
Expand Down Expand Up @@ -153,4 +153,4 @@ def parse_uri_path(cls, path):
return {}

def __repr__(self): # pragma: no cover
return "MemcachedCache ({}:{})".format(self.endpoint, self.port)
return "MemcachedCache ({}:{})".format(self.host, self.port)
46 changes: 9 additions & 37 deletions aiocache/backends/redis.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import itertools
import warnings
from typing import Any, Callable, Optional, TYPE_CHECKING

import redis.asyncio as redis
Expand Down Expand Up @@ -38,41 +37,19 @@ class RedisBackend(BaseCache[str]):

def __init__(
self,
endpoint="127.0.0.1",
port=6379,
db=0,
password=None,
pool_min_size=_NOT_SET,
pool_max_size=None,
create_connection_timeout=None,
client: redis.Redis,
**kwargs,
):
super().__init__(**kwargs)
if pool_min_size is not _NOT_SET:
warnings.warn(
"Parameter 'pool_min_size' is deprecated since aiocache 0.12",
DeprecationWarning, stacklevel=2
)

self.endpoint = endpoint
self.port = int(port)
self.db = int(db)
self.password = password
# TODO: Remove int() call some time after adding type annotations.
self.pool_max_size = None if pool_max_size is None else int(pool_max_size)
self.create_connection_timeout = (
float(create_connection_timeout) if create_connection_timeout else None
)

# NOTE: decoding can't be controlled on API level after switching to
# redis, we need to disable decoding on global/connection level
# (decode_responses=False), because some of the values are saved as
# bytes directly, like pickle serialized values, which may raise an
# exception when decoded with 'utf-8'.
self.client = redis.Redis(host=self.endpoint, port=self.port, db=self.db,
password=self.password, decode_responses=False,
socket_connect_timeout=self.create_connection_timeout,
max_connections=self.pool_max_size)
if client.connection_pool.connection_kwargs['decode_responses']:
raise ValueError("redis client must be constructed with decode_responses set to False")
self.client = client

async def _get(self, key, encoding="utf-8", _conn=None):
value = await self.client.get(key)
Expand Down Expand Up @@ -175,9 +152,6 @@ async def _raw(self, command, *args, encoding="utf-8", _conn=None, **kwargs):
async def _redlock_release(self, key, value):
return await self._raw("eval", self.RELEASE_SCRIPT, 1, key, value)

async def _close(self, *args, _conn=None, **kwargs):
await self.client.close()

def build_key(self, key: str, namespace: Optional[str] = None) -> str:
return self._str_build_key(key, namespace)

Expand All @@ -196,24 +170,21 @@ class RedisCache(RedisBackend):
the backend. Default is an empty string, "".
:param timeout: int or float in seconds specifying maximum timeout for the operations to last.
By default its 5.
:param endpoint: str with the endpoint to connect to. Default is "127.0.0.1".
:param port: int with the port to connect to. Default is 6379.
:param db: int indicating database to use. Default is 0.
:param password: str indicating password to use. Default is None.
:param pool_max_size: int maximum pool size for the redis connections pool. Default is None.
:param create_connection_timeout: int timeout for the creation of connection. Default is None
:param client: redis.Redis which is an active client for working with redis
"""

NAME = "redis"

def __init__(
self,
client: redis.Redis,
serializer: Optional["BaseSerializer"] = None,
namespace: str = "",
key_builder: Callable[[str, str], str] = lambda k, ns: f"{ns}:{k}" if ns else k,
**kwargs: Any,
):
super().__init__(
client=client,
serializer=serializer or JsonSerializer(),
namespace=namespace,
key_builder=key_builder,
Expand All @@ -237,4 +208,5 @@ def parse_uri_path(cls, path):
return options

def __repr__(self): # pragma: no cover
return "RedisCache ({}:{})".format(self.endpoint, self.port)
connection_kwargs = self.client.connection_pool.connection_kwargs
return "RedisCache ({}:{})".format(connection_kwargs['host'], connection_kwargs['port'])
28 changes: 22 additions & 6 deletions aiocache/factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
from copy import deepcopy
from typing import Dict

import redis.asyncio as redis

from aiocache import AIOCACHE_CACHES
from aiocache.base import BaseCache
from aiocache.exceptions import InvalidCacheType
Expand All @@ -18,6 +20,7 @@ def _class_from_string(class_path):


def _create_cache(cache, serializer=None, plugins=None, **kwargs):
kwargs = deepcopy(kwargs)
if serializer is not None:
cls = serializer.pop("class")
cls = _class_from_string(cls) if isinstance(cls, str) else cls
Expand All @@ -29,10 +32,17 @@ def _create_cache(cache, serializer=None, plugins=None, **kwargs):
cls = plugin.pop("class")
cls = _class_from_string(cls) if isinstance(cls, str) else cls
plugins_instances.append(cls(**plugin))

cache = _class_from_string(cache) if isinstance(cache, str) else cache
instance = cache(serializer=serializer, plugins=plugins_instances, **kwargs)
return instance
if cache == AIOCACHE_CACHES.get("redis"):
return cache(
serializer=serializer,
plugins=plugins_instances,
namespace=kwargs.pop('namespace', ''),

Check failure

Code scanning / CodeQL

Modification of parameter with default

This expression mutates a [default value](1).
ttl=kwargs.pop('ttl', None),

Check failure

Code scanning / CodeQL

Modification of parameter with default

This expression mutates a [default value](1).
client=redis.Redis(**kwargs)
)
else:
return cache(serializer=serializer, plugins=plugins_instances, **kwargs)


class Cache:
Expand Down Expand Up @@ -112,15 +122,21 @@ def from_url(cls, url):
kwargs.update(cache_class.parse_uri_path(parsed_url.path))

if parsed_url.hostname:
kwargs["endpoint"] = parsed_url.hostname
kwargs["host"] = parsed_url.hostname

if parsed_url.port:
kwargs["port"] = parsed_url.port

if parsed_url.password:
kwargs["password"] = parsed_url.password

return Cache(cache_class, **kwargs)
for arg in ['max_connections', 'socket_connect_timeout']:
if arg in kwargs:
kwargs[arg] = int(kwargs[arg])
if cache_class == cls.REDIS:
return Cache(cache_class, client=redis.Redis(**kwargs))
else:
return Cache(cache_class, **kwargs)


class CacheHandler:
Expand Down Expand Up @@ -214,7 +230,7 @@ def set_config(self, config):
},
'redis_alt': {
'cache': "aiocache.RedisCache",
'endpoint': "127.0.0.10",
'host': "127.0.0.10",
'port': 6378,
'serializer': {
'class': "aiocache.serializers.PickleSerializer"
Expand Down
14 changes: 8 additions & 6 deletions examples/cached_alias_config.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import asyncio

import redis.asyncio as redis

from aiocache import caches, Cache
from aiocache.serializers import StringSerializer, PickleSerializer

Expand All @@ -12,9 +14,9 @@
},
'redis_alt': {
'cache': "aiocache.RedisCache",
'endpoint': "127.0.0.1",
'host': "127.0.0.1",
'port': 6379,
'timeout': 1,
'socket_connect_timeout': 1,
'serializer': {
'class': "aiocache.serializers.PickleSerializer"
},
Expand Down Expand Up @@ -45,17 +47,17 @@ async def alt_cache():
assert isinstance(cache, Cache.REDIS)
assert isinstance(cache.serializer, PickleSerializer)
assert len(cache.plugins) == 2
assert cache.endpoint == "127.0.0.1"
assert cache.timeout == 1
assert cache.port == 6379
assert cache.client.connection_pool.connection_kwargs['host'] == "127.0.0.1"
assert cache.client.connection_pool.connection_kwargs['socket_connect_timeout'] == 1
assert cache.client.connection_pool.connection_kwargs['port'] == 6379
await cache.close()


async def test_alias():
await default_cache()
await alt_cache()

cache = Cache(Cache.REDIS)
cache = Cache(Cache.REDIS, client=redis.Redis() )
await cache.delete("key")
await cache.close()

Expand Down
5 changes: 3 additions & 2 deletions examples/cached_decorator.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import asyncio

from collections import namedtuple
import redis.asyncio as redis

from aiocache import cached, Cache
from aiocache.serializers import PickleSerializer
Expand All @@ -10,13 +11,13 @@

@cached(
ttl=10, cache=Cache.REDIS, key_builder=lambda *args, **kw: "key",
serializer=PickleSerializer(), port=6379, namespace="main")
serializer=PickleSerializer(), namespace="main", client = redis.Redis())
async def cached_call():
return Result("content", 200)


async def test_cached():
async with Cache(Cache.REDIS, endpoint="127.0.0.1", port=6379, namespace="main") as cache:
async with Cache(Cache.REDIS, namespace="main", client=redis.Redis()) as cache:
await cached_call()
exists = await cache.exists("key")
assert exists is True
Expand Down
9 changes: 6 additions & 3 deletions examples/multicached_decorator.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import asyncio

import redis.asyncio as redis

from aiocache import multi_cached, Cache

DICT = {
Expand All @@ -9,18 +11,19 @@
'd': "W"
}

cache = Cache(Cache.REDIS, namespace="main", client=redis.Redis())


@multi_cached("ids", cache=Cache.REDIS, namespace="main")
@multi_cached("ids", cache=Cache.REDIS, namespace="main", client=cache.client)
async def multi_cached_ids(ids=None):
return {id_: DICT[id_] for id_ in ids}


@multi_cached("keys", cache=Cache.REDIS, namespace="main")
@multi_cached("keys", cache=Cache.REDIS, namespace="main", client=cache.client)
async def multi_cached_keys(keys=None):
return {id_: DICT[id_] for id_ in keys}


cache = Cache(Cache.REDIS, endpoint="127.0.0.1", port=6379, namespace="main")


async def test_multi_cached():
Expand Down
5 changes: 3 additions & 2 deletions examples/optimistic_lock.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,13 @@
import logging
import random

import redis.asyncio as redis

from aiocache import Cache
from aiocache.lock import OptimisticLock, OptimisticLockError


logger = logging.getLogger(__name__)
cache = Cache(Cache.REDIS, endpoint='127.0.0.1', port=6379, namespace='main')
cache = Cache(Cache.REDIS, namespace="main", client=redis.Redis())


async def expensive_function():
Expand Down
6 changes: 4 additions & 2 deletions examples/python_object.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
import asyncio

from collections import namedtuple
import redis.asyncio as redis


from aiocache import Cache
from aiocache.serializers import PickleSerializer


MyObject = namedtuple("MyObject", ["x", "y"])
cache = Cache(Cache.REDIS, serializer=PickleSerializer(), namespace="main")
cache = Cache(Cache.REDIS, serializer=PickleSerializer(), namespace="main", client=redis.Redis())


async def complex_object():
Expand Down
5 changes: 3 additions & 2 deletions examples/redlock.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
import asyncio
import logging

import redis.asyncio as redis

from aiocache import Cache
from aiocache.lock import RedLock


logger = logging.getLogger(__name__)
cache = Cache(Cache.REDIS, endpoint='127.0.0.1', port=6379, namespace='main')
cache = Cache(Cache.REDIS, namespace='main', client=redis.Redis())


async def expensive_function():
Expand Down
4 changes: 3 additions & 1 deletion examples/serializer_class.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import asyncio
import zlib

import redis.asyncio as redis

from aiocache import Cache
from aiocache.serializers import BaseSerializer

Expand All @@ -25,7 +27,7 @@ def loads(self, value):
return decompressed


cache = Cache(Cache.REDIS, serializer=CompressionSerializer(), namespace="main")
cache = Cache(Cache.REDIS, serializer=CompressionSerializer(), namespace="main", client=redis.Redis())


async def serializer():
Expand Down
4 changes: 3 additions & 1 deletion examples/serializer_function.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import asyncio
import json

import redis.asyncio as redis

from marshmallow import Schema, fields, post_load
from aiocache import Cache

Expand Down Expand Up @@ -28,7 +30,7 @@ def loads(value):
return MyTypeSchema().loads(value)


cache = Cache(Cache.REDIS, namespace="main")
cache = Cache(Cache.REDIS, namespace="main", client=redis.Redis())


async def serializer_function():
Expand Down
3 changes: 2 additions & 1 deletion examples/simple_redis.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,9 @@

from aiocache import Cache

import redis.asyncio as redis

cache = Cache(Cache.REDIS, endpoint="127.0.0.1", port=6379, namespace="main")
cache = Cache(Cache.REDIS, namespace="main" , client=redis.Redis() )


async def redis():
Expand Down
4 changes: 2 additions & 2 deletions tests/acceptance/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,8 @@ def reset_caches():


@pytest.fixture
async def redis_cache():
async with Cache(Cache.REDIS, namespace="test") as cache:
async def redis_cache(redis_client):
async with Cache(Cache.REDIS, namespace="test", client=redis_client) as cache:
yield cache
await asyncio.gather(*(cache.delete(k) for k in (*Keys, KEY_LOCK)))

Expand Down
Loading