Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 20 additions & 4 deletions aiocache/backends/redis.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,9 @@ def __init__(
pool_min_size=_NOT_SET,
pool_max_size=None,
create_connection_timeout=None,
ssl=False,
connection_pool_class=None,
connection_pool_kwargs=None,
**kwargs,
):
super().__init__(**kwargs)
Expand All @@ -60,15 +63,28 @@ def __init__(
float(create_connection_timeout) if create_connection_timeout else None
)

connection_pool_kwargs = connection_pool_kwargs or {}

if ssl:
connection_pool_kwargs["connection_class"] = redis.SSLConnection

# NOTE: decoding can't be controlled on API level after switching to
# redis, we need to disable decoding on global/connection level
# (decode_responses=False), because some of the values are saved as
# bytes directly, like pickle serialized values, which may raise an
# exception when decoded with 'utf-8'.
self.client = redis.Redis(host=self.endpoint, port=self.port, db=self.db,
password=self.password, decode_responses=False,
socket_connect_timeout=self.create_connection_timeout,
max_connections=self.pool_max_size)
connection_pool_class = connection_pool_class or redis.ConnectionPool
connection_pool = connection_pool_class(
host=self.endpoint, port=self.port, db=self.db,
password=self.password, decode_responses=False,
socket_connect_timeout=self.create_connection_timeout,
max_connections=self.pool_max_size,
**connection_pool_kwargs
)
self.client = redis.Redis(connection_pool=connection_pool)

# needed for consistency with how Redis creation of connection_pool works
self.client.auto_close_connection_pool = True

async def _get(self, key, encoding="utf-8", _conn=None):
value = await self.client.get(key)
Expand Down
19 changes: 13 additions & 6 deletions tests/ut/backends/test_redis.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,33 +39,39 @@ class TestRedisBackend:
"max_connections": None,
}

@patch("redis.asyncio.ConnectionPool", name="connection_pool_class", autospec=True)
@patch("redis.asyncio.Redis", name="mock_class", autospec=True)
def test_setup(self, mock_class):
def test_setup(self, mock_class, connection_pool_class):
redis_backend = RedisBackend()
kwargs = self.default_redis_kwargs.copy()
mock_class.assert_called_with(**kwargs)
connection_pool_class.assert_called_with(**kwargs)
mock_class.assert_called_with(connection_pool=connection_pool_class.return_value)

assert redis_backend.endpoint == "127.0.0.1"
assert redis_backend.port == 6379
assert redis_backend.db == 0
assert redis_backend.password is None
assert redis_backend.pool_max_size is None

@patch("redis.asyncio.ConnectionPool", name="connection_pool_class", autospec=True)
@patch("redis.asyncio.Redis", name="mock_class", autospec=True)
def test_setup_override(self, mock_class):
def test_setup_override(self, mock_class, connection_pool_class):
override = {"db": 2, "password": "pass"}
redis_backend = RedisBackend(**override)

kwargs = self.default_redis_kwargs.copy()
kwargs.update(override)
mock_class.assert_called_with(**kwargs)
connection_pool_class.assert_called_with(**kwargs)
mock_class.assert_called_with(connection_pool=connection_pool_class.return_value)

assert redis_backend.endpoint == "127.0.0.1"
assert redis_backend.port == 6379
assert redis_backend.db == 2
assert redis_backend.password == "pass"

@patch("redis.asyncio.ConnectionPool", name="connection_pool_class", autospec=True)
@patch("redis.asyncio.Redis", name="mock_class", autospec=True)
def test_setup_casts(self, mock_class):
def test_setup_casts(self, mock_class, connection_pool_class):
override = {
"db": "2",
"port": "6379",
Expand All @@ -81,7 +87,8 @@ def test_setup_casts(self, mock_class):
"max_connections": 10,
"socket_connect_timeout": 1.5,
})
mock_class.assert_called_with(**kwargs)
connection_pool_class.assert_called_with(**kwargs)
mock_class.assert_called_with(connection_pool=connection_pool_class.return_value)

assert redis_backend.db == 2
assert redis_backend.port == 6379
Expand Down