Middleware and antiflood

middleware_and_antiflood.py
import asyncio
import os

from aiogram import Bot, Dispatcher, executor, types
from aiogram.contrib.fsm_storage.redis import RedisStorage2
from aiogram.dispatcher import DEFAULT_RATE_LIMIT
from aiogram.dispatcher.handler import CancelHandler, current_handler
from aiogram.dispatcher.middlewares import BaseMiddleware
from aiogram.utils.exceptions import Throttled
  9
 10TOKEN = 'BOT_TOKEN_HERE'
 11
 12# In this example Redis storage is used
 13storage = RedisStorage2(db=5)
 14
 15bot = Bot(token=TOKEN)
 16dp = Dispatcher(bot, storage=storage)
 17
 18
 19def rate_limit(limit: int, key=None):
 20    """
 21    Decorator for configuring rate limit and key in different functions.
 22
 23    :param limit:
 24    :param key:
 25    :return:
 26    """
 27
 28    def decorator(func):
 29        setattr(func, 'throttling_rate_limit', limit)
 30        if key:
 31            setattr(func, 'throttling_key', key)
 32        return func
 33
 34    return decorator
 35
 36
 37class ThrottlingMiddleware(BaseMiddleware):
 38    """
 39    Simple middleware
 40    """
 41
 42    def __init__(self, limit=DEFAULT_RATE_LIMIT, key_prefix='antiflood_'):
 43        self.rate_limit = limit
 44        self.prefix = key_prefix
 45        super(ThrottlingMiddleware, self).__init__()
 46
 47    async def on_process_message(self, message: types.Message, data: dict):
 48        """
 49        This handler is called when dispatcher receives a message
 50
 51        :param message:
 52        """
 53        # Get current handler
 54        handler = current_handler.get()
 55
 56        # Get dispatcher from context
 57        dispatcher = Dispatcher.get_current()
 58        # If handler was configured, get rate limit and key from handler
 59        if handler:
 60            limit = getattr(handler, 'throttling_rate_limit', self.rate_limit)
 61            key = getattr(handler, 'throttling_key', f"{self.prefix}_{handler.__name__}")
 62        else:
 63            limit = self.rate_limit
 64            key = f"{self.prefix}_message"
 65
 66        # Use Dispatcher.throttle method.
 67        try:
 68            await dispatcher.throttle(key, rate=limit)
 69        except Throttled as t:
 70            # Execute action
 71            await self.message_throttled(message, t)
 72
 73            # Cancel current handler
 74            raise CancelHandler()
 75
 76    async def message_throttled(self, message: types.Message, throttled: Throttled):
 77        """
 78        Notify user only on first exceed and notify about unlocking only on last exceed
 79
 80        :param message:
 81        :param throttled:
 82        """
 83        handler = current_handler.get()
 84        dispatcher = Dispatcher.get_current()
 85        if handler:
 86            key = getattr(handler, 'throttling_key', f"{self.prefix}_{handler.__name__}")
 87        else:
 88            key = f"{self.prefix}_message"
 89
 90        # Calculate how many time is left till the block ends
 91        delta = throttled.rate - throttled.delta
 92
 93        # Prevent flooding
 94        if throttled.exceeded_count <= 2:
 95            await message.reply('Too many requests! ')
 96
 97        # Sleep.
 98        await asyncio.sleep(delta)
 99
100        # Check lock status
101        thr = await dispatcher.check_key(key)
102
103        # If current message is not last with current key - do not send message
104        if thr.exceeded_count == throttled.exceeded_count:
105            await message.reply('Unlocked.')
106
107
@dp.message_handler(commands=['start'])
@rate_limit(5, key='start')  # optional: per-handler override of the middleware's default limit and key
async def cmd_test(message: types.Message):
    """Answer the ``start`` command; ``rate_limit`` allows one use every 5 seconds."""
    answer = 'Test passed! You can use this command every 5 seconds.'
    await message.reply(answer)
113
114
if __name__ == '__main__':
    # Install the anti-flood middleware on the dispatcher before polling.
    antiflood = ThrottlingMiddleware()
    dp.middleware.setup(antiflood)

    # Receive updates via long-polling.
    executor.start_polling(dp)