Source code for aiogram.fsm.storage.redis

import json
from abc import ABC, abstractmethod
from contextlib import asynccontextmanager
from typing import Any, AsyncGenerator, Callable, Dict, Literal, Optional, cast

from redis.asyncio.client import Redis
from redis.asyncio.connection import ConnectionPool
from redis.asyncio.lock import Lock
from redis.typing import ExpiryT

from aiogram.fsm.state import State
from aiogram.fsm.storage.base import (
    DEFAULT_DESTINY,
    BaseEventIsolation,
    BaseStorage,
    StateType,
    StorageKey,
)

# Fallback keyword arguments forwarded to ``Redis.lock()`` by RedisEventIsolation
# when the caller does not supply ``lock_kwargs``.
# NOTE(review): ``timeout`` is presumably the lock's maximum lifetime in seconds
# (redis-py semantics) — confirm against the redis-py Lock documentation.
DEFAULT_REDIS_LOCK_KWARGS = {"timeout": 60}
# Type aliases for the pluggable JSON deserializer / serializer accepted by RedisStorage.
_JsonLoads = Callable[..., Any]
_JsonDumps = Callable[..., str]


class KeyBuilder(ABC):
    """
    Abstract interface for objects that turn a storage key into a Redis key.
    """

    @abstractmethod
    def build(self, key: StorageKey, part: Literal["data", "state", "lock"]) -> str:
        """
        Compose the Redis key for one part of a record.

        Concrete subclasses must override this method.

        :param key: contextual key
        :param part: which part of the record the key addresses
        :return: string to be used as the key in Redis queries
        """
class DefaultKeyBuilder(KeyBuilder):
    """
    Simple Redis key builder with default prefix.

    Generates a separator-joined string (colon by default) made of the prefix,
    optional bot_id, optional business_connection_id, chat_id, optional
    thread_id, user_id, optional destiny and the record part.

    Format:
    :code:`<prefix>:<bot_id?>:<business_connection_id?>:<chat_id>:<thread_id?>:<user_id>:<destiny?>:<field>`

    Optional segments are omitted entirely (no placeholder is written), so the
    number of segments in a key varies with the configuration and the key content.
    """

    def __init__(
        self,
        *,
        prefix: str = "fsm",
        separator: str = ":",
        with_bot_id: bool = False,
        with_business_connection_id: bool = False,
        with_destiny: bool = False,
    ) -> None:
        """
        :param prefix: prefix for all records
        :param separator: string placed between the key segments
        :param with_bot_id: include Bot id in the key
        :param with_business_connection_id: include business connection id
            (only added when the key actually carries one)
        :param with_destiny: include a destiny key
        """
        self.prefix = prefix
        self.separator = separator
        self.with_bot_id = with_bot_id
        self.with_business_connection_id = with_business_connection_id
        self.with_destiny = with_destiny

    def build(self, key: StorageKey, part: Literal["data", "state", "lock"]) -> str:
        """
        Build the Redis key for the given contextual key and record part.

        :param key: contextual key
        :param part: part of the record
        :return: key to be used in Redis queries
        :raises ValueError: if the key uses a non-default destiny while the
            builder was created without ``with_destiny=True``
        """
        parts = [self.prefix]
        if self.with_bot_id:
            parts.append(str(key.bot_id))
        if self.with_business_connection_id and key.business_connection_id:
            parts.append(str(key.business_connection_id))
        parts.append(str(key.chat_id))
        # thread_id is always honoured when present; there is no flag for it.
        if key.thread_id:
            parts.append(str(key.thread_id))
        parts.append(str(key.user_id))
        if self.with_destiny:
            parts.append(key.destiny)
        elif key.destiny != DEFAULT_DESTINY:
            # Dropping a custom destiny silently would make distinct contexts
            # share one record, so refuse to build the key instead.
            raise ValueError(
                "Redis key builder is not configured to use key destiny other than the default.\n"
                "\n"
                "Probably, you should set `with_destiny=True` for DefaultKeyBuilder.\n"
                "E.g: `RedisStorage(redis, key_builder=DefaultKeyBuilder(with_destiny=True))`"
            )
        parts.append(part)
        return self.separator.join(parts)
class RedisStorage(BaseStorage):
    """
    FSM storage backed by Redis.

    Requires the :code:`redis` package to be installed (:code:`pip install redis`)
    """

    def __init__(
        self,
        redis: Redis,
        key_builder: Optional[KeyBuilder] = None,
        state_ttl: Optional[ExpiryT] = None,
        data_ttl: Optional[ExpiryT] = None,
        json_loads: _JsonLoads = json.loads,
        json_dumps: _JsonDumps = json.dumps,
    ) -> None:
        """
        :param redis: Instance of Redis connection
        :param key_builder: builder that helps to convert contextual key to string;
            a :class:`DefaultKeyBuilder` is created when omitted
        :param state_ttl: TTL for state records
        :param data_ttl: TTL for data records
        :param json_loads: callable used to deserialize stored data
        :param json_dumps: callable used to serialize data before storing
        """
        self.redis = redis
        self.key_builder = DefaultKeyBuilder() if key_builder is None else key_builder
        self.state_ttl = state_ttl
        self.data_ttl = data_ttl
        self.json_loads = json_loads
        self.json_dumps = json_dumps

    @classmethod
    def from_url(
        cls, url: str, connection_kwargs: Optional[Dict[str, Any]] = None, **kwargs: Any
    ) -> "RedisStorage":
        """
        Alternate constructor: build the storage from a connection string

        :param url: for example :code:`redis://user:password@host:port/db`
        :param connection_kwargs: extra options for the connection pool,
            see :code:`redis` docs
        :param kwargs: arguments to be passed to :class:`RedisStorage`
        :return: an instance of :class:`RedisStorage`
        """
        pool_options = {} if connection_kwargs is None else connection_kwargs
        client = Redis(connection_pool=ConnectionPool.from_url(url, **pool_options))
        return cls(redis=client, **kwargs)

    def create_isolation(self, **kwargs: Any) -> "RedisEventIsolation":
        """
        Create an event isolation that reuses this storage's Redis client
        and key builder

        :param kwargs: extra arguments passed to :class:`RedisEventIsolation`
        """
        return RedisEventIsolation(redis=self.redis, key_builder=self.key_builder, **kwargs)

    async def close(self) -> None:
        """
        Close the Redis client together with its connection pool
        """
        await self.redis.aclose(close_connection_pool=True)

    async def set_state(
        self,
        key: StorageKey,
        state: StateType = None,
    ) -> None:
        """
        Store the state for the key; :code:`None` removes the record

        :param key: storage key
        :param state: new state (a :class:`State` object, its string name or None)
        """
        state_key = self.key_builder.build(key, "state")
        if state is None:
            await self.redis.delete(state_key)
            return
        # State objects are persisted by their string name.
        raw_state = state.state if isinstance(state, State) else state
        await self.redis.set(state_key, cast(str, raw_state), ex=self.state_ttl)

    async def get_state(
        self,
        key: StorageKey,
    ) -> Optional[str]:
        """
        Read the state stored for the key

        :param key: storage key
        :return: stored state, decoded from UTF-8 when Redis returns bytes
        """
        stored = await self.redis.get(self.key_builder.build(key, "state"))
        if not isinstance(stored, bytes):
            return cast(Optional[str], stored)
        return stored.decode("utf-8")

    async def set_data(
        self,
        key: StorageKey,
        data: Dict[str, Any],
    ) -> None:
        """
        Store the data for the key; empty data removes the record

        :param key: storage key
        :param data: mapping to serialize with :code:`json_dumps`
        """
        data_key = self.key_builder.build(key, "data")
        if data:
            await self.redis.set(data_key, self.json_dumps(data), ex=self.data_ttl)
        else:
            await self.redis.delete(data_key)

    async def get_data(
        self,
        key: StorageKey,
    ) -> Dict[str, Any]:
        """
        Read the data stored for the key

        :param key: storage key
        :return: deserialized data, or an empty dict when nothing is stored
        """
        payload = await self.redis.get(self.key_builder.build(key, "data"))
        if payload is None:
            return {}
        text = payload.decode("utf-8") if isinstance(payload, bytes) else payload
        return cast(Dict[str, Any], self.json_loads(text))
class RedisEventIsolation(BaseEventIsolation):
    """
    Event isolation that serializes handling per storage key using a Redis lock.
    """

    def __init__(
        self,
        redis: Redis,
        key_builder: Optional[KeyBuilder] = None,
        lock_kwargs: Optional[Dict[str, Any]] = None,
    ) -> None:
        """
        :param redis: Instance of Redis connection
        :param key_builder: builder that helps to convert contextual key to string;
            a :class:`DefaultKeyBuilder` is created when omitted
        :param lock_kwargs: keyword arguments forwarded to :code:`redis.lock()`;
            a copy of :code:`DEFAULT_REDIS_LOCK_KWARGS` is used when omitted
        """
        if key_builder is None:
            key_builder = DefaultKeyBuilder()
        if lock_kwargs is None:
            # Copy the module-level default so that mutating ``self.lock_kwargs``
            # on one instance cannot leak into the defaults of every other one.
            lock_kwargs = dict(DEFAULT_REDIS_LOCK_KWARGS)
        self.redis = redis
        self.key_builder = key_builder
        self.lock_kwargs = lock_kwargs

    @classmethod
    def from_url(
        cls,
        url: str,
        connection_kwargs: Optional[Dict[str, Any]] = None,
        **kwargs: Any,
    ) -> "RedisEventIsolation":
        """
        Create an instance of :class:`RedisEventIsolation` from a connection string

        :param url: for example :code:`redis://user:password@host:port/db`
        :param connection_kwargs: see :code:`redis` docs
        :param kwargs: arguments to be passed to :class:`RedisEventIsolation`
        :return: an instance of :class:`RedisEventIsolation`
        """
        if connection_kwargs is None:
            connection_kwargs = {}
        pool = ConnectionPool.from_url(url, **connection_kwargs)
        redis = Redis(connection_pool=pool)
        return cls(redis=redis, **kwargs)

    @asynccontextmanager
    async def lock(
        self,
        key: StorageKey,
    ) -> AsyncGenerator[None, None]:
        """
        Hold the Redis lock for the given key while the ``async with`` body runs

        :param key: storage key whose ``lock`` record is used as the lock name
        """
        redis_key = self.key_builder.build(key, "lock")
        async with self.redis.lock(name=redis_key, **self.lock_kwargs, lock_class=Lock):
            yield None

    async def close(self) -> None:
        """
        Do nothing: the Redis client is not closed here.

        NOTE(review): an instance created via :meth:`from_url` builds its own
        client and pool, which this method leaves open — confirm whether the
        caller is expected to close ``self.redis`` explicitly in that case.
        """
        pass