from __future__ import annotations

from collections.abc import Callable
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, cast, overload

from magic_filter import AttrDict, MagicFilter

if TYPE_CHECKING:
    from aiogram.dispatcher.event.handler import HandlerObject


@dataclass(frozen=True)
class Flag:
    """Immutable pair of a flag name and the value attached to it."""

    # Key under which the value is stored in a callable's ``aiogram_flag`` dict.
    name: str
    # Arbitrary payload associated with the flag.
    value: Any


@dataclass(frozen=True)
class FlagDecorator:
    """
    Immutable wrapper around a :class:`Flag` that can be used as a decorator.

    Calling an instance either marks a callable with the wrapped flag or
    returns a new decorator that carries a different value for the same
    flag name (see :meth:`__call__`).
    """

    flag: Flag

    @classmethod
    def _with_flag(cls, flag: Flag) -> FlagDecorator:
        """Build a new decorator of the same class around *flag*."""
        return cls(flag)

    def _with_value(self, value: Any) -> FlagDecorator:
        """Return a new decorator with the same flag name and a new *value*."""
        new_flag = Flag(self.flag.name, value)
        return self._with_flag(new_flag)

    @overload
    def __call__(self, value: Callable[..., Any], /) -> Callable[..., Any]:  # type: ignore
        pass

    @overload
    def __call__(self, value: Any, /) -> FlagDecorator:
        pass

    @overload
    def __call__(self, **kwargs: Any) -> FlagDecorator:
        pass

    def __call__(
        self,
        value: Any | None = None,
        **kwargs: Any,
    ) -> Callable[..., Any] | FlagDecorator:
        """
        Apply the flag to a callable or derive a decorator with a new value

        :param value: a callable to mark with the flag, or any other object
            to use as the new flag value
        :param kwargs: keyword arguments wrapped into an ``AttrDict`` and used
            as the new flag value when ``value`` is not given
        :return: the same callable with its ``aiogram_flag`` attribute updated
            when ``value`` is callable, otherwise a new :class:`FlagDecorator`
        :raises ValueError: if ``value`` and ``kwargs`` are passed together
        """
        # Compare against None instead of relying on truthiness: a falsy
        # value (0, False, "") combined with keyword arguments must be
        # rejected as well, otherwise the kwargs would be silently dropped.
        if value is not None and kwargs:
            msg = "The arguments `value` and **kwargs can not be used together"
            raise ValueError(msg)

        # A callable is always treated as the decoration target, never as a
        # flag value.
        if value is not None and callable(value):
            # Keep flags already attached to the callable; the current flag
            # overrides an existing entry with the same name.
            value.aiogram_flag = {
                **extract_flags_from_object(value),
                self.flag.name: self.flag.value,
            }
            return cast(Callable[..., Any], value)
        return self._with_value(AttrDict(kwargs) if value is None else value)


if TYPE_CHECKING:

    class _ChatActionFlagProtocol(FlagDecorator):
        # Typing-only signature for the ``chat_action`` flag; this class is
        # never defined at runtime.
        def __call__(  # type: ignore[override]
            self,
            action: str = ...,
            interval: float = ...,
            initial_sleep: float = ...,
            **kwargs: Any,
        ) -> FlagDecorator:
            pass


class FlagGenerator:
    """Factory that creates a :class:`FlagDecorator` for any attribute name."""

    def __getattr__(self, name: str) -> FlagDecorator:
        """
        Create a decorator for the flag called *name* with the value ``True``

        :param name: name of the flag
        :return: decorator wrapping ``Flag(name, True)``
        :raises AttributeError: if the name starts with an underscore
        """
        if name[0] == "_":
            msg = "Flag name must NOT start with underscore"
            raise AttributeError(msg)
        return FlagDecorator(Flag(name, True))

    if TYPE_CHECKING:
        chat_action: _ChatActionFlagProtocol


def extract_flags_from_object(obj: Any) -> dict[str, Any]:
    """
    Read the flags stored on an object

    :param obj: any object, typically a decorated callable
    :return: the object's ``aiogram_flag`` attribute, or an empty dictionary
        when the attribute is missing
    """
    if not hasattr(obj, "aiogram_flag"):
        return {}
    return cast(dict[str, Any], obj.aiogram_flag)
def extract_flags(handler: HandlerObject | dict[str, Any]) -> dict[str, Any]:
    """
    Extract flags from handler or middleware context data

    :param handler: handler object or data
    :return: dictionary with all handler flags, or an empty dictionary when
        the resolved object has no ``flags`` attribute
    """
    # Middleware context data is a dict that carries the handler object under
    # the "handler" key; unwrap it before looking for flags.
    if isinstance(handler, dict) and "handler" in handler:
        handler = handler["handler"]
    if hasattr(handler, "flags"):
        return handler.flags
    return {}
def get_flag(
    handler: HandlerObject | dict[str, Any],
    name: str,
    *,
    default: Any | None = None,
) -> Any:
    """
    Get flag by name

    :param handler: handler object or data
    :param name: name of the flag
    :param default: default value (None)
    :return: value of the flag or default
    """
    flags = extract_flags(handler)
    return flags.get(name, default)
def check_flags(handler: HandlerObject | dict[str, Any], magic: MagicFilter) -> Any:
    """
    Check flags via magic filter

    :param handler: handler object or data
    :param magic: instance of the magic
    :return: the result of magic filter check
    """
    flags = extract_flags(handler)
    # The flags are wrapped in an AttrDict before being handed to the filter.
    return magic.resolve(AttrDict(flags))