# Source code for aiogram.dispatcher.flags

from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Union, cast, overload

from magic_filter import AttrDict, MagicFilter

if TYPE_CHECKING:
    from aiogram.dispatcher.event.handler import HandlerObject


@dataclass(frozen=True)
class Flag:
    """Immutable ``name``/``value`` pair describing a single handler flag.

    Because the dataclass is frozen, instances cannot be mutated after
    construction and are hashable / comparable by field values.
    """

    # Key under which the flag is stored in a handler's flag mapping.
    name: str
    # Arbitrary payload associated with the flag name.
    value: Any


@dataclass(frozen=True)
class FlagDecorator:
    """Callable wrapper around a :class:`Flag`.

    Calling an instance either attaches the wrapped flag to a callable
    (decorator usage) or returns a new ``FlagDecorator`` that carries a
    different value for the same flag name.
    """

    flag: Flag

    @classmethod
    def _with_flag(cls, flag: Flag) -> "FlagDecorator":
        # Construct through ``cls`` so subclasses keep their own type.
        return cls(flag)

    def _with_value(self, value: Any) -> "FlagDecorator":
        # The dataclass is frozen, so a new Flag (and a new decorator) is built
        # instead of mutating ``self.flag``.
        new_flag = Flag(self.flag.name, value)
        return self._with_flag(new_flag)

    @overload
    def __call__(self, value: Callable[..., Any], /) -> Callable[..., Any]:  # type: ignore
        pass

    @overload
    def __call__(self, value: Any, /) -> "FlagDecorator":
        pass

    @overload
    def __call__(self, **kwargs: Any) -> "FlagDecorator":
        pass

    def __call__(
        self,
        value: Optional[Any] = None,
        **kwargs: Any,
    ) -> Union[Callable[..., Any], "FlagDecorator"]:
        """Apply the flag to a callable, or derive a flag with a new value.

        :param value: a callable to decorate, or the new flag value
        :param kwargs: keyword values wrapped in an ``AttrDict`` and used as the
            new flag value when ``value`` is not given
        :return: the same callable with its ``aiogram_flag`` mapping updated,
            or a new :class:`FlagDecorator` carrying the derived value
        :raises ValueError: when both ``value`` and ``kwargs`` are supplied
        """
        # Compare against None instead of relying on truthiness: falsy values
        # such as 0, False or "" are legitimate flag values, and combining them
        # with kwargs must be rejected rather than silently dropping kwargs.
        if value is not None and kwargs:
            raise ValueError("The arguments `value` and **kwargs can not be used together")

        if value is not None and callable(value):
            # Merge with flags already attached to the callable; on a name
            # clash the flag applied by this decorator wins (later dict key).
            value.aiogram_flag = {
                **extract_flags_from_object(value),
                self.flag.name: self.flag.value,
            }
            return cast(Callable[..., Any], value)
        return self._with_value(AttrDict(kwargs) if value is None else value)


if TYPE_CHECKING:
    # Typing-only declaration: TYPE_CHECKING is False at runtime, so this class
    # is never created when the module is imported.

    class _ChatActionFlagProtocol(FlagDecorator):
        """Static call signature used as the annotated type of ``FlagGenerator.chat_action``.

        Lets type checkers see the ``action``, ``interval`` and ``initial_sleep``
        keyword arguments instead of the generic ``FlagDecorator.__call__``
        overloads.
        """

        # The signature deliberately differs from FlagDecorator.__call__,
        # hence the ``override`` ignore below.
        def __call__(  # type: ignore[override]
            self,
            action: str = ...,
            interval: float = ...,
            initial_sleep: float = ...,
            **kwargs: Any,
        ) -> FlagDecorator:
            pass


class FlagGenerator:
    """Factory that turns attribute access into flag decorators.

    ``FlagGenerator().some_name`` returns a :class:`FlagDecorator` for a flag
    named ``some_name`` whose initial value is ``True``.
    """

    def __getattr__(self, name: str) -> FlagDecorator:
        # ``__getattr__`` only runs after normal attribute lookup has failed.
        if not name:
            # ``name[0]`` used to raise IndexError here, which breaks
            # ``hasattr``/``getattr(obj, "", default)``; report a missing
            # attribute instead.
            raise AttributeError("Flag name must NOT be empty")
        if name.startswith("_"):
            # Private/dunder lookups must fail like on a plain object instead
            # of silently producing a flag.
            raise AttributeError("Flag name must NOT start with underscore")
        return FlagDecorator(Flag(name, True))

    if TYPE_CHECKING:
        # Gives type checkers the richer call signature of ``chat_action``.
        chat_action: _ChatActionFlagProtocol


def extract_flags_from_object(obj: Any) -> Dict[str, Any]:
    """Return the flag mapping stored on ``obj`` under ``aiogram_flag``.

    :param obj: any object; decorated callables carry their flags here
    :return: the stored mapping itself (not a copy), or a fresh empty dict
        when ``obj`` has no ``aiogram_flag`` attribute
    """
    try:
        stored = obj.aiogram_flag
    except AttributeError:
        return {}
    return cast(Dict[str, Any], stored)


def extract_flags(handler: Union["HandlerObject", Dict[str, Any]]) -> Dict[str, Any]:
    """
    Extract flags from handler or middleware context data

    :param handler: handler object or data
    :return: dictionary with all handler flags, or an empty dict when the
        handler (or the data) does not expose any
    """
    # When given context data (a dict), unwrap the object stored under "handler".
    if isinstance(handler, dict) and "handler" in handler:
        handler = handler["handler"]
    if hasattr(handler, "flags"):
        return handler.flags
    return {}


def get_flag(
    handler: Union["HandlerObject", Dict[str, Any]],
    name: str,
    *,
    default: Optional[Any] = None,
) -> Any:
    """
    Get flag by name

    :param handler: handler object or data
    :param name: name of the flag
    :param default: value returned when the flag is absent (None)
    :return: value of the flag or default
    """
    flags = extract_flags(handler)
    return flags.get(name, default)


def check_flags(handler: Union["HandlerObject", Dict[str, Any]], magic: MagicFilter) -> Any:
    """
    Check flags via magic filter

    :param handler: handler object or data
    :param magic: magic filter expression evaluated against the handler flags
    :return: the result of magic filter check
    """
    flags = extract_flags(handler)
    # The flags mapping is wrapped in magic_filter's AttrDict before resolving.
    return magic.resolve(AttrDict(flags))