Skip to content

Commit

Permalink
Replace redis_key_func with prefix argument
Browse files Browse the repository at this point in the history
  • Loading branch information
alex-oleshkevich committed Aug 15, 2022
1 parent 187c2f7 commit f36677a
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 28 deletions.
41 changes: 20 additions & 21 deletions starsessions/backends/redis.py
Original file line number Diff line number Diff line change
@@ -1,45 +1,44 @@
import aioredis
import functools
import typing

from starsessions.backends.base import SessionBackend
from starsessions.exceptions import ImproperlyConfigured


def prefix_factory(prefix: str, key: str) -> str:
return prefix + key


class RedisBackend(SessionBackend):
"""Stores session data in a Redis server."""

def __init__(
self,
url: typing.Optional[str] = None,
connection: typing.Optional[aioredis.Redis] = None,
redis_key_func: typing.Optional[typing.Callable[[str], str]] = None,
prefix: typing.Union[typing.Callable[[str], str], str] = '',
) -> None:
"""
Initializes redis session backend.
Initializes redis session backend. Either `url` or `connection` required. To namespace keys in Redis use
`prefix` argument. It can be a string or callable that accepts a single string argument and returns new Redis
key as string.
Args:
url (str, optional): Redis URL. Defaults to None.
connection (aioredis.Redis, optional): aioredis connection. Defaults to None.
redis_key_func (typing.Callable[[str], str], optional): Customize redis key name. Defaults to None.
:param url: Redis URL. Defaults to None.
:param connection: aioredis connection. Defaults to None
:param prefix: Redis key name prefix or factory.
"""
if not (url or connection):
raise ImproperlyConfigured("Either 'url' or 'connection' arguments must be provided.")

self._connection: aioredis.Redis = connection or aioredis.from_url(url) # type: ignore[no-untyped-call]

if redis_key_func and not callable(redis_key_func):
raise ImproperlyConfigured("The redis_key_func needs to be a callable that takes a single string argument.")

self._redis_key_func = redis_key_func
if isinstance(prefix, str):
prefix = functools.partial(prefix_factory, prefix)

def get_redis_key(self, session_id: str) -> str:
if self._redis_key_func:
return self._redis_key_func(session_id)
else:
return session_id
self.prefix: typing.Callable[[str], str] = prefix
self._connection: aioredis.Redis = connection or aioredis.from_url(url) # type: ignore[no-untyped-call]

async def read(self, session_id: str, lifetime: int) -> bytes:
value = await self._connection.get(self.get_redis_key(session_id))
value = await self._connection.get(self.prefix(session_id))
if value is None:
return b''
return value # type: ignore
Expand All @@ -49,12 +48,12 @@ async def write(self, session_id: str, data: bytes, lifetime: int) -> str:
# We cannot know the final session duration so set here something close to reality.
# FIXME: we want something better here
lifetime = max(lifetime, 3600) # 1h
await self._connection.set(self.get_redis_key(session_id), data, ex=lifetime)
await self._connection.set(self.prefix(session_id), data, ex=lifetime)
return session_id

async def remove(self, session_id: str) -> None:
await self._connection.delete(self.get_redis_key(session_id))
await self._connection.delete(self.prefix(session_id))

async def exists(self, session_id: str) -> bool:
result: int = await self._connection.exists(self.get_redis_key(session_id))
result: int = await self._connection.exists(self.prefix(session_id))
return result > 0
9 changes: 2 additions & 7 deletions tests/backends/test_redis.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,11 @@ def redis_key_callable(session_id: str) -> str:
return f"this:is:a:redis:key:{session_id}"


@pytest.fixture(params=[None, redis_key_callable], ids=["default", "using redis_key_callable"])
@pytest.fixture(params=['prefix_', redis_key_callable], ids=["using string", "using redis_key_callable"])
def redis_backend(request: SubRequest) -> SessionBackend:
redis_key = request.param
url = os.environ.get("REDIS_URL", "redis://localhost")
return RedisBackend(url, redis_key_func=redis_key)
return RedisBackend(url, prefix=redis_key)


@pytest.mark.asyncio
Expand Down Expand Up @@ -46,8 +46,3 @@ async def test_redis_exists(redis_backend: SessionBackend) -> None:
@pytest.mark.asyncio
async def test_redis_empty_session(redis_backend: SessionBackend) -> None:
assert await redis_backend.read("unknown_session_id", lifetime=60) == b''


def test_improperly_configured_redis_key() -> None:
with pytest.raises(Exception):
RedisBackend(redis_key_func="a_random_string") # type: ignore[arg-type]

0 comments on commit f36677a

Please sign in to comment.