"""Middleware and antiflood.

Example file: middleware_and_antiflood.py
"""
import asyncio
import os

from aiogram import Bot, Dispatcher, executor, types
from aiogram.contrib.fsm_storage.redis import RedisStorage2
from aiogram.dispatcher import DEFAULT_RATE_LIMIT
from aiogram.dispatcher.handler import CancelHandler, current_handler
from aiogram.dispatcher.middlewares import BaseMiddleware
from aiogram.utils.exceptions import Throttled

# Bot token: taken from the BOT_TOKEN environment variable so the secret does
# not have to be committed to source; falls back to the original placeholder.
TOKEN = os.environ.get('BOT_TOKEN', 'BOT_TOKEN_HERE')

# In this example Redis storage is used (passed to the Dispatcher below).
storage = RedisStorage2(db=5)

bot = Bot(token=TOKEN)
dp = Dispatcher(bot, storage=storage)


def rate_limit(limit: int, key=None):
    """
    Decorator that attaches throttling settings to a handler.

    The settings are read back by ``ThrottlingMiddleware``: ``limit`` is passed
    as the ``rate`` to ``Dispatcher.throttle`` (the ``/start`` example uses 5,
    i.e. once every 5 seconds), and ``key`` replaces the auto-generated key.

    :param limit: rate limit to store as ``throttling_rate_limit``
    :param key: optional custom throttling key; when falsy no
        ``throttling_key`` attribute is set and the middleware derives a key
        from the handler name
    :return: a decorator that sets the attributes and returns the same
        function object
    """

    def attach(func):
        func.throttling_rate_limit = limit
        if key:
            func.throttling_key = key
        return func

    return attach


class ThrottlingMiddleware(BaseMiddleware):
    """
    Simple antiflood middleware for incoming messages.

    Every message is passed through ``Dispatcher.throttle`` under a key derived
    from the handler about to process it. When the limit is exceeded the user
    is warned, the handler is cancelled, and an "Unlocked." notice is sent once
    the block has expired.
    """

    def __init__(self, limit=DEFAULT_RATE_LIMIT, key_prefix='antiflood_'):
        """
        :param limit: default rate limit for handlers not configured with
            ``rate_limit``
        :param key_prefix: prefix for auto-generated throttling keys
        """
        self.rate_limit = limit
        self.prefix = key_prefix
        super().__init__()

    def _throttle_key(self, handler):
        """
        Return the throttling key for ``handler``.

        Single source of truth for the key, so that the throttle call and the
        later unlock check always look at the same bucket. Uses the key set by
        ``rate_limit`` when present, otherwise ``<prefix>_<handler name>``, or
        ``<prefix>_message`` when there is no handler.
        """
        if handler:
            return getattr(handler, 'throttling_key', f"{self.prefix}_{handler.__name__}")
        return f"{self.prefix}_message"

    async def on_process_message(self, message: types.Message, data: dict):
        """
        Throttle an incoming message before its handler runs.

        :param message: incoming message
        :param data: middleware data dict from the dispatcher (unused here)
        :raises CancelHandler: when the rate limit for the key is exceeded
        """
        handler = current_handler.get()
        dispatcher = Dispatcher.get_current()

        # A per-handler limit set via ``rate_limit`` overrides the default.
        if handler:
            limit = getattr(handler, 'throttling_rate_limit', self.rate_limit)
        else:
            limit = self.rate_limit
        key = self._throttle_key(handler)

        try:
            await dispatcher.throttle(key, rate=limit)
        except Throttled as t:
            await self.message_throttled(message, t)
            # Stop processing so the throttled message never reaches its handler.
            raise CancelHandler()

    async def message_throttled(self, message: types.Message, throttled: Throttled):
        """
        Notify the user about throttling without flooding the chat.

        Warns only on the first exceed (``exceeded_count <= 2``), then sleeps
        until the block ends and replies "Unlocked." only if no newer message
        exceeded the same key in the meantime.

        :param message: the throttled message
        :param throttled: the ``Throttled`` exception raised by the dispatcher
        """
        dispatcher = Dispatcher.get_current()
        key = self._throttle_key(current_handler.get())

        # How much time is left until the block ends.
        delta = throttled.rate - throttled.delta

        # Prevent flooding: do not warn on every exceeded message.
        if throttled.exceeded_count <= 2:
            await message.reply('Too many requests! ')

        await asyncio.sleep(delta)

        # Re-check the key: if a later message exceeded the limit while we were
        # sleeping, the counter moved on and that message sends the notice.
        thr = await dispatcher.check_key(key)
        if thr.exceeded_count == throttled.exceeded_count:
            await message.reply('Unlocked.')


@dp.message_handler(commands=['start'])
@rate_limit(5, 'start')  # optional: configures throttling for this handler only
async def cmd_test(message: types.Message):
    """
    Handle the ``/start`` command.

    Throttled by ``rate_limit(5, 'start')``, so it can be used once every
    5 seconds.
    """
    answer = 'Test passed! You can use this command every 5 seconds.'
    await message.reply(answer)


if __name__ == '__main__':
    # Register the antiflood middleware with its default limit and key prefix,
    # then start the bot with long polling.
    antiflood = ThrottlingMiddleware()
    dp.middleware.setup(antiflood)
    executor.start_polling(dp)