"""
This module provides Redis storages for the finite-state machine, based on the `redis <https://pypi.org/project/redis/>`_ driver.
"""
import asyncio
import logging
import typing
from ...dispatcher.storage import BaseStorage
from ...utils import json
from ...utils.deprecated import deprecated
if typing.TYPE_CHECKING:
import aioredis
# Last key segment passed to ``RedisStorage2.generate_key`` for each record kind.
STATE_KEY = "state"          # current FSM state
STATE_DATA_KEY = "data"      # JSON-encoded state data
STATE_BUCKET_KEY = "bucket"  # JSON-encoded bucket
class RedisStorage(BaseStorage):
    """
    Simple Redis-based storage for FSM.

    Every chat/user pair is kept as a single JSON document under the key
    ``fsm:{chat}:{user}`` with the fields ``state``, ``data`` and ``bucket``.

    Usage:

    .. code-block:: python3

        storage = RedisStorage('localhost', 6379, db=5)
        dp = Dispatcher(bot, storage=storage)

    And need to close Redis connection when shutdown

    .. code-block:: python3

        await dp.storage.close()
        await dp.storage.wait_closed()

    """

    @deprecated("`RedisStorage` will be removed in aiogram v3.0. "
                "Use `RedisStorage2` instead.", stacklevel=3)
    def __init__(self, host='localhost', port=6379, db=None, password=None, ssl=None, loop=None, **kwargs):
        """
        Remember connection settings; the connection itself is opened lazily by :meth:`redis`.

        :param host: Redis host
        :param port: Redis port
        :param db: database index
        :param password: Redis password
        :param ssl: SSL setting forwarded to ``aioredis.create_connection``
        :param loop: accepted for backward compatibility, not used
        :param kwargs: extra arguments forwarded to ``aioredis.create_connection``
        """
        self._host = host
        self._port = port
        self._db = db
        self._password = password
        self._ssl = ssl
        self._kwargs = kwargs

        self._redis: typing.Optional["aioredis.RedisConnection"] = None
        # Serializes connect/close so concurrent coroutines do not open several connections.
        self._connection_lock = asyncio.Lock()

    async def close(self):
        """
        Close the Redis connection if it is open.
        """
        async with self._connection_lock:
            if self._redis and not self._redis.closed:
                self._redis.close()

    async def wait_closed(self):
        """
        Wait until the Redis connection is closed.

        :return: result of the connection's ``wait_closed()`` or ``True`` when there is no connection
        """
        async with self._connection_lock:
            if self._redis:
                return await self._redis.wait_closed()
            return True

    async def redis(self) -> "aioredis.RedisConnection":
        """
        Get Redis connection, (re)creating it when it is missing or closed.

        :return: open connection
        """
        # Imported lazily so the module can be imported without aioredis installed.
        import aioredis

        # asyncio.Lock is coroutine-safe (not thread-safe): it keeps concurrent
        # callers from creating more than one connection.
        async with self._connection_lock:
            if self._redis is None or self._redis.closed:
                self._redis = await aioredis.create_connection((self._host, self._port),
                                                               db=self._db, password=self._password, ssl=self._ssl,
                                                               **self._kwargs)
        return self._redis

    async def get_record(self, *,
                         chat: typing.Union[str, int, None] = None,
                         user: typing.Union[str, int, None] = None) -> typing.Dict:
        """
        Get record from storage

        :param chat:
        :param user:
        :return: decoded record, or ``{'state': None, 'data': {}}`` when nothing is stored
        """
        chat, user = self.check_address(chat=chat, user=user)
        addr = f"fsm:{chat}:{user}"

        conn = await self.redis()
        data = await conn.execute('GET', addr)
        if data is None:
            return {'state': None, 'data': {}}
        return json.loads(data)

    async def set_record(self, *, chat: typing.Union[str, int, None] = None, user: typing.Union[str, int, None] = None,
                         state=None, data=None, bucket=None):
        """
        Write record to storage

        The key is deleted instead of written when the state is ``None`` and
        both ``data`` and ``bucket`` are empty.

        :param bucket:
        :param chat:
        :param user:
        :param state:
        :param data:
        :return:
        """
        if data is None:
            data = {}
        if bucket is None:
            bucket = {}

        chat, user = self.check_address(chat=chat, user=user)
        addr = f"fsm:{chat}:{user}"

        conn = await self.redis()
        if state is None and data == bucket == {}:
            await conn.execute('DEL', addr)
        else:
            record = {'state': state, 'data': data, 'bucket': bucket}
            await conn.execute('SET', addr, json.dumps(record))

    async def get_state(self, *, chat: typing.Union[str, int, None] = None, user: typing.Union[str, int, None] = None,
                        default: typing.Optional[str] = None) -> typing.Optional[str]:
        """
        Get the stored state, falling back to the resolved ``default`` when none is stored.
        """
        record = await self.get_record(chat=chat, user=user)
        state = record.get('state')
        # Records always carry a 'state' key (possibly None), so the fallback
        # has to be applied on the value rather than on key presence.
        if state is None:
            return self.resolve_state(default)
        return state

    async def get_data(self, *, chat: typing.Union[str, int, None] = None, user: typing.Union[str, int, None] = None,
                       default: typing.Optional[str] = None) -> typing.Dict:
        """
        Get the stored data dict (``default`` is accepted but not used).
        """
        record = await self.get_record(chat=chat, user=user)
        return record['data']

    async def set_state(self, *,
                        chat: typing.Union[str, int, None] = None,
                        user: typing.Union[str, int, None] = None,
                        state: typing.Optional[typing.AnyStr] = None):
        """
        Store a new state, keeping the existing data and bucket.
        """
        record = await self.get_record(chat=chat, user=user)
        state = self.resolve_state(state)
        # Pass the bucket through: set_record would otherwise reset it to {}.
        await self.set_record(chat=chat, user=user, state=state, data=record['data'],
                              bucket=record.get('bucket'))

    async def set_data(self, *, chat: typing.Union[str, int, None] = None, user: typing.Union[str, int, None] = None,
                       data: typing.Dict = None):
        """
        Replace the stored data, keeping the existing state and bucket.
        """
        record = await self.get_record(chat=chat, user=user)
        await self.set_record(chat=chat, user=user, state=record['state'], data=data,
                              bucket=record.get('bucket'))

    async def update_data(self, *, chat: typing.Union[str, int, None] = None, user: typing.Union[str, int, None] = None,
                          data: typing.Dict = None, **kwargs):
        """
        Merge ``data`` and ``kwargs`` into the stored data, keeping state and bucket.
        """
        if data is None:
            data = {}
        record = await self.get_record(chat=chat, user=user)
        record_data = record.get('data', {})
        record_data.update(data, **kwargs)
        await self.set_record(chat=chat, user=user, state=record['state'], data=record_data,
                              bucket=record.get('bucket'))

    async def get_states_list(self) -> typing.List[typing.Tuple[str, str]]:
        """
        Get list of all stored chat's and user's

        :return: list of tuples where first element is chat id and second is user id
        """
        conn = await self.redis()
        result = []

        keys = await conn.execute('KEYS', 'fsm:*')
        for item in keys:
            # Keys look like b'fsm:{chat}:{user}'; the last two parts are the address.
            *_, chat, user = item.decode('utf-8').split(':')
            result.append((chat, user))

        return result

    async def reset_all(self, full=True):
        """
        Reset states in DB

        :param full: clean DB or clean only states
        :return:
        """
        conn = await self.redis()

        if full:
            await conn.execute('FLUSHDB')
        else:
            keys = await conn.execute('KEYS', 'fsm:*')
            # DEL requires at least one key; skip it when nothing is stored.
            if keys:
                await conn.execute('DEL', *keys)

    def has_bucket(self):
        """
        This storage supports buckets.
        """
        return True

    async def get_bucket(self, *, chat: typing.Union[str, int, None] = None, user: typing.Union[str, int, None] = None,
                         default: typing.Optional[str] = None) -> typing.Dict:
        """
        Get the stored bucket, or an empty dict when the record has none (``default`` is not used).
        """
        record = await self.get_record(chat=chat, user=user)
        return record.get('bucket', {})

    async def set_bucket(self, *, chat: typing.Union[str, int, None] = None, user: typing.Union[str, int, None] = None,
                         bucket: typing.Dict = None):
        """
        Replace the stored bucket, keeping the existing state and data.
        """
        record = await self.get_record(chat=chat, user=user)
        await self.set_record(chat=chat, user=user, state=record['state'], data=record['data'], bucket=bucket)

    async def update_bucket(self, *, chat: typing.Union[str, int, None] = None,
                            user: typing.Union[str, int, None] = None,
                            bucket: typing.Dict = None, **kwargs):
        """
        Merge ``bucket`` and ``kwargs`` into the stored bucket, keeping state and data.
        """
        record = await self.get_record(chat=chat, user=user)
        record_bucket = record.get('bucket', {})
        if bucket is None:
            bucket = {}
        record_bucket.update(bucket, **kwargs)
        # Write the merged bucket back as the bucket and leave the data untouched.
        await self.set_record(chat=chat, user=user, state=record['state'], data=record['data'],
                              bucket=record_bucket)
class RedisStorage2(BaseStorage):
    """
    Redis-based storage for FSM.

    Works with Redis connection pool and customizable keys prefix.
    State, data and bucket are stored under separate keys of the form
    ``{prefix}:{chat}:{user}:{kind}``, each with its own optional TTL.

    Usage:

    .. code-block:: python3

        storage = RedisStorage2('localhost', 6379, db=5, pool_size=10, prefix='my_fsm_key')
        dp = Dispatcher(bot, storage=storage)

    And need to close Redis connection when shutdown

    .. code-block:: python3

        await dp.storage.close()

    """

    def __init__(
            self,
            host: str = "localhost",
            port: int = 6379,
            db: typing.Optional[int] = None,
            password: typing.Optional[str] = None,
            ssl: typing.Optional[bool] = None,
            pool_size: int = 10,
            loop: typing.Optional[asyncio.AbstractEventLoop] = None,
            prefix: str = "fsm",
            state_ttl: typing.Optional[int] = None,
            data_ttl: typing.Optional[int] = None,
            bucket_ttl: typing.Optional[int] = None,
            **kwargs,
    ):
        """
        Create the Redis client.

        :param host: Redis host
        :param port: Redis port
        :param db: database index
        :param password: Redis password
        :param ssl: SSL setting forwarded to the client
        :param pool_size: passed to the client as ``max_connections``
        :param loop: accepted for backward compatibility, not used
        :param prefix: first segment of every generated key
        :param state_ttl: expiry passed as ``ex`` when writing state
        :param data_ttl: expiry passed as ``ex`` when writing data
        :param bucket_ttl: expiry passed as ``ex`` when writing bucket
        :param kwargs: extra arguments forwarded to ``redis.asyncio.Redis``
        """
        # Imported lazily so the module can be imported without redis installed.
        from redis.asyncio import Redis

        self._redis: typing.Optional[Redis] = Redis(
            host=host,
            port=port,
            db=db,
            password=password,
            ssl=ssl,
            max_connections=pool_size,
            decode_responses=True,
            **kwargs,
        )

        # Kept as a tuple so generate_key can concatenate it with the key parts.
        self._prefix = (prefix,)
        self._state_ttl = state_ttl
        self._data_ttl = data_ttl
        self._bucket_ttl = bucket_ttl

    @deprecated("This method will be removed in aiogram v3.0. "
                "You should use your own instance of Redis.", stacklevel=3)
    async def redis(self) -> "aioredis.Redis":
        """
        Return the underlying Redis client.
        """
        return self._redis

    def generate_key(self, *parts):
        """
        Join the prefix and the stringified ``parts`` with ``:``.
        """
        return ':'.join(self._prefix + tuple(map(str, parts)))

    async def close(self):
        """
        Close the Redis client.
        """
        await self._redis.close()

    async def wait_closed(self):
        """
        No-op: :meth:`close` already awaits the client shutdown.
        """
        pass

    async def get_state(self, *, chat: typing.Union[str, int, None] = None, user: typing.Union[str, int, None] = None,
                        default: typing.Optional[str] = None) -> typing.Optional[str]:
        """
        Get the stored state, or the resolved ``default`` when the stored value is missing or empty.
        """
        chat, user = self.check_address(chat=chat, user=user)
        key = self.generate_key(chat, user, STATE_KEY)
        return await self._redis.get(key) or self.resolve_state(default)

    async def get_data(self, *, chat: typing.Union[str, int, None] = None, user: typing.Union[str, int, None] = None,
                       default: typing.Optional[dict] = None) -> typing.Dict:
        """
        Get the stored data dict, or ``default`` (``{}`` when falsy) if nothing is stored.
        """
        chat, user = self.check_address(chat=chat, user=user)
        key = self.generate_key(chat, user, STATE_DATA_KEY)
        raw_result = await self._redis.get(key)
        if raw_result:
            return json.loads(raw_result)
        return default or {}

    async def set_state(self, *, chat: typing.Union[str, int, None] = None, user: typing.Union[str, int, None] = None,
                        state: typing.Optional[typing.AnyStr] = None):
        """
        Store the state, or delete its key when ``state`` is ``None``.
        """
        chat, user = self.check_address(chat=chat, user=user)
        key = self.generate_key(chat, user, STATE_KEY)
        if state is None:
            await self._redis.delete(key)
        else:
            await self._redis.set(key, self.resolve_state(state), ex=self._state_ttl)

    async def set_data(self, *, chat: typing.Union[str, int, None] = None, user: typing.Union[str, int, None] = None,
                       data: typing.Dict = None):
        """
        Store the data as JSON, or delete its key when ``data`` is empty or ``None``.
        """
        chat, user = self.check_address(chat=chat, user=user)
        key = self.generate_key(chat, user, STATE_DATA_KEY)
        if data:
            await self._redis.set(key, json.dumps(data), ex=self._data_ttl)
        else:
            await self._redis.delete(key)

    async def update_data(self, *, chat: typing.Union[str, int, None] = None, user: typing.Union[str, int, None] = None,
                          data: typing.Dict = None, **kwargs):
        """
        Merge ``data`` and ``kwargs`` into the stored data and write it back.
        """
        if data is None:
            data = {}
        temp_data = await self.get_data(chat=chat, user=user, default={})
        temp_data.update(data, **kwargs)
        await self.set_data(chat=chat, user=user, data=temp_data)

    def has_bucket(self):
        """
        This storage supports buckets.
        """
        return True

    async def get_bucket(self, *, chat: typing.Union[str, int, None] = None, user: typing.Union[str, int, None] = None,
                         default: typing.Optional[dict] = None) -> typing.Dict:
        """
        Get the stored bucket, or ``default`` (``{}`` when falsy) if nothing is stored.
        """
        chat, user = self.check_address(chat=chat, user=user)
        key = self.generate_key(chat, user, STATE_BUCKET_KEY)
        raw_result = await self._redis.get(key)
        if raw_result:
            return json.loads(raw_result)
        return default or {}

    async def set_bucket(self, *, chat: typing.Union[str, int, None] = None, user: typing.Union[str, int, None] = None,
                         bucket: typing.Dict = None):
        """
        Store the bucket as JSON, or delete its key when ``bucket`` is empty or ``None``.
        """
        chat, user = self.check_address(chat=chat, user=user)
        key = self.generate_key(chat, user, STATE_BUCKET_KEY)
        if bucket:
            await self._redis.set(key, json.dumps(bucket), ex=self._bucket_ttl)
        else:
            await self._redis.delete(key)

    async def update_bucket(self, *, chat: typing.Union[str, int, None] = None,
                            user: typing.Union[str, int, None] = None,
                            bucket: typing.Dict = None, **kwargs):
        """
        Merge ``bucket`` and ``kwargs`` into the stored bucket and write it back.
        """
        if bucket is None:
            bucket = {}
        temp_bucket = await self.get_bucket(chat=chat, user=user)
        temp_bucket.update(bucket, **kwargs)
        await self.set_bucket(chat=chat, user=user, bucket=temp_bucket)

    async def reset_all(self, full=True):
        """
        Reset states in DB

        :param full: clean DB or clean only states
        :return:
        """
        if full:
            await self._redis.flushdb()
        else:
            keys = await self._redis.keys(self.generate_key('*'))
            # DEL requires at least one key; skip it when nothing matches the prefix.
            if keys:
                await self._redis.delete(*keys)

    async def get_states_list(self) -> typing.List[typing.Tuple[str, str]]:
        """
        Get list of all stored chat's and user's

        :return: list of tuples where first element is chat id and second is user id
        """
        result = []

        keys = await self._redis.keys(self.generate_key('*', '*', STATE_KEY))
        for item in keys:
            # Keys end with '{chat}:{user}:state'; anything before belongs to the prefix.
            *_, chat, user, _ = item.split(':')
            result.append((chat, user))

        return result

    async def import_redis1(self, redis1):
        """
        Copy all records from a ``RedisStorage`` instance into this storage.

        :param redis1: instance of RedisStorage
        """
        await migrate_redis1_to_redis2(redis1, self)
async def migrate_redis1_to_redis2(storage1: RedisStorage, storage2: RedisStorage2):
    """
    Helper for migrating from RedisStorage to RedisStorage2

    Copies state, data and bucket of every chat/user pair listed by
    ``storage1.get_states_list()`` into ``storage2``.

    :param storage1: instance of RedisStorage
    :param storage2: instance of RedisStorage2
    :raises TypeError: when either argument is not of the expected storage type
    :return:
    """
    if not isinstance(storage1, RedisStorage):  # better than assertion
        raise TypeError(f"{type(storage1)} is not RedisStorage instance.")
    # The destination must be the new storage class; RedisStorage2 does not
    # inherit from RedisStorage, so checking against RedisStorage rejected it.
    if not isinstance(storage2, RedisStorage2):
        raise TypeError(f"{type(storage2)} is not RedisStorage2 instance.")

    log = logging.getLogger('aiogram.RedisStorage')

    for chat, user in await storage1.get_states_list():
        state = await storage1.get_state(chat=chat, user=user)
        await storage2.set_state(chat=chat, user=user, state=state)

        data = await storage1.get_data(chat=chat, user=user)
        await storage2.set_data(chat=chat, user=user, data=data)

        bucket = await storage1.get_bucket(chat=chat, user=user)
        await storage2.set_bucket(chat=chat, user=user, bucket=bucket)

        log.info("Migrated user %s in chat %s", user, chat)