Source code for aiogram.client.session.aiohttp

from __future__ import annotations

import asyncio
import ssl
from typing import (
    TYPE_CHECKING,
    Any,
    AsyncGenerator,
    Dict,
    Iterable,
    List,
    Optional,
    Tuple,
    Type,
    Union,
    cast,
)

import certifi
from aiohttp import BasicAuth, ClientError, ClientSession, FormData, TCPConnector
from aiohttp.hdrs import USER_AGENT
from aiohttp.http import SERVER_SOFTWARE

from aiogram.__meta__ import __version__
from aiogram.methods import TelegramMethod

from ...exceptions import TelegramNetworkError
from ...methods.base import TelegramType
from ...types import InputFile
from .base import BaseSession

if TYPE_CHECKING:
    from ..bot import Bot

_ProxyBasic = Union[str, Tuple[str, BasicAuth]]
_ProxyChain = Iterable[_ProxyBasic]
_ProxyType = Union[_ProxyChain, _ProxyBasic]


def _retrieve_basic(basic: _ProxyBasic) -> Dict[str, Any]:
    """Translate a single proxy description into keyword arguments.

    :param basic: either a proxy URL string, or a ``(url, BasicAuth)`` pair
    :return: a dict with the keys ``proxy_type``, ``host``, ``port``,
        ``username``, ``password`` and ``rdns`` (always ``True``)
    """
    # Imported lazily so the module stays importable without aiohttp-socks.
    from aiohttp_socks.utils import parse_proxy_url  # type: ignore

    auth: Optional[BasicAuth]
    if isinstance(basic, str):
        url, auth = basic, None
    else:
        url, auth = basic

    proxy_type, host, port, username, password = parse_proxy_url(url)

    # Credentials passed explicitly as BasicAuth take precedence over
    # whatever was parsed out of the URL itself.
    if isinstance(auth, BasicAuth):
        username, password = auth.login, auth.password

    return dict(
        proxy_type=proxy_type,
        host=host,
        port=port,
        username=username,
        password=password,
        rdns=True,
    )


def _prepare_connector(chain_or_plain: _ProxyType) -> Tuple[Type["TCPConnector"], Dict[str, Any]]:
    """Choose the aiohttp-socks connector class and its init kwargs.

    :param chain_or_plain: a single proxy (URL string or ``(url, BasicAuth)``
        pair) or an iterable of such proxies forming a chain
    :return: ``(ProxyConnector, kwargs)`` for a single proxy, or
        ``(ChainProxyConnector, {"proxy_infos": [...]})`` for a chain
    :raises ImportError: if ``aiohttp_socks`` is not installed
    """
    from aiohttp_socks import (  # type: ignore
        ChainProxyConnector,
        ProxyConnector,
        ProxyInfo,
    )

    # A tuple is itself iterable, so a ``(url, BasicAuth)`` pair cannot be told
    # apart from a two-hop chain by length alone. Look at the elements: only a
    # URL string followed by credentials (or None) is a single proxy. Anything
    # else of length two, e.g. ``(url, url)``, is a chain; treating it as a
    # single proxy would silently drop the second hop.
    is_single = isinstance(chain_or_plain, str)
    if not is_single and isinstance(chain_or_plain, tuple) and len(chain_or_plain) == 2:
        first, second = chain_or_plain
        is_single = isinstance(first, str) and (
            second is None or isinstance(second, BasicAuth)
        )

    if is_single:
        return ProxyConnector, _retrieve_basic(cast(_ProxyBasic, chain_or_plain))

    infos: List[ProxyInfo] = [
        ProxyInfo(**_retrieve_basic(basic)) for basic in cast(_ProxyChain, chain_or_plain)
    ]
    return ChainProxyConnector, {"proxy_infos": infos}


class AiohttpSession(BaseSession):
    """Session implementation that sends requests through an aiohttp ``ClientSession``.

    The underlying ``ClientSession`` is created lazily by :meth:`create_session`
    and is recreated whenever the proxy configuration changes.
    """

    def __init__(self, proxy: Optional[_ProxyType] = None, **kwargs: Any) -> None:
        """Configure the connector; no network resources are opened here.

        :param proxy: optional proxy URL, ``(url, BasicAuth)`` pair, or an
            iterable of those for a proxy chain
        :param kwargs: forwarded unchanged to ``BaseSession.__init__``
        :raises RuntimeError: if *proxy* is given but ``aiohttp_socks`` cannot
            be imported
        """
        super().__init__(**kwargs)

        self._session: Optional[ClientSession] = None
        self._connector_type: Type[TCPConnector] = TCPConnector
        self._connector_init: Dict[str, Any] = {
            "ssl": ssl.create_default_context(cafile=certifi.where()),
        }
        self._should_reset_connector = True  # flag determines connector state
        self._proxy: Optional[_ProxyType] = None

        if proxy is not None:
            try:
                self._setup_proxy_connector(proxy)
            except ImportError as exc:  # pragma: no cover
                raise RuntimeError(
                    "In order to use aiohttp client for proxy requests, install "
                    "https://pypi.org/project/aiohttp-socks/"
                ) from exc

    def _setup_proxy_connector(self, proxy: _ProxyType) -> None:
        """Replace the connector class and its kwargs with ones built for *proxy*."""
        self._connector_type, self._connector_init = _prepare_connector(proxy)
        self._proxy = proxy

    @property
    def proxy(self) -> Optional[_ProxyType]:
        """The proxy configuration currently in effect, or ``None``."""
        return self._proxy

    @proxy.setter
    def proxy(self, proxy: _ProxyType) -> None:
        self._setup_proxy_connector(proxy)
        # Force the next create_session() call to close the current session
        # and build a new one with the new connector.
        self._should_reset_connector = True

    async def create_session(self) -> ClientSession:
        """Return the active ``ClientSession``, creating a new one when needed.

        A new session is built when none exists, when the existing one is
        closed, or when the connector was flagged for reset.
        """
        if self._should_reset_connector:
            await self.close()

        if self._session is None or self._session.closed:
            self._session = ClientSession(
                connector=self._connector_type(**self._connector_init),
                headers={
                    USER_AGENT: f"{SERVER_SOFTWARE} aiogram/{__version__}",
                },
            )
            self._should_reset_connector = False

        return self._session

    async def close(self) -> None:
        """Close the underlying session if it is open; otherwise do nothing."""
        if self._session is not None and not self._session.closed:
            await self._session.close()

            # Wait 250 ms for the underlying SSL connections to close
            # https://docs.aiohttp.org/en/stable/client_advanced.html#graceful-shutdown
            await asyncio.sleep(0.25)

    def build_form_data(self, bot: Bot, method: TelegramMethod[TelegramType]) -> FormData:
        """Build the form body for *method*.

        Each field of ``method.model_dump(warnings=False)`` is passed through
        ``self.prepare_value``; fields whose prepared value is falsy are
        omitted. Files collected by ``prepare_value`` are appended afterwards.

        :param bot: bot instance handed to ``prepare_value`` and ``InputFile.read``
        :param method: the API method whose fields are serialized
        :return: the populated ``FormData``
        """
        form = FormData(quote_fields=False)
        files: Dict[str, InputFile] = {}
        for key, value in method.model_dump(warnings=False).items():
            value = self.prepare_value(value, bot=bot, files=files)
            if not value:
                continue
            form.add_field(key, value)
        for key, value in files.items():
            form.add_field(
                key,
                value.read(bot),
                # Fall back to the field key when the file has no name.
                filename=value.filename or key,
            )
        return form

    async def make_request(
        self, bot: Bot, method: TelegramMethod[TelegramType], timeout: Optional[int] = None
    ) -> TelegramType:
        """POST *method* to the API and return the result of the checked response.

        :param bot: bot whose token is used to build the request URL
        :param method: the API method to execute
        :param timeout: per-request timeout; ``self.timeout`` is used when ``None``
        :return: ``response.result`` of the value returned by ``self.check_response``
        :raises TelegramNetworkError: on ``asyncio.TimeoutError`` or any aiohttp
            ``ClientError`` raised while sending the request or reading the body
        """
        session = await self.create_session()

        url = self.api.api_url(token=bot.token, method=method.__api_method__)
        form = self.build_form_data(bot=bot, method=method)

        try:
            async with session.post(
                url, data=form, timeout=self.timeout if timeout is None else timeout
            ) as resp:
                raw_result = await resp.text()
        except asyncio.TimeoutError as exc:
            raise TelegramNetworkError(method=method, message="Request timeout error") from exc
        except ClientError as exc:
            raise TelegramNetworkError(
                method=method, message=f"{type(exc).__name__}: {exc}"
            ) from exc

        response = self.check_response(
            bot=bot, method=method, status_code=resp.status, content=raw_result
        )
        return cast(TelegramType, response.result)

    async def stream_content(
        self,
        url: str,
        headers: Optional[Dict[str, Any]] = None,
        timeout: int = 30,
        chunk_size: int = 65536,
        raise_for_status: bool = True,
    ) -> AsyncGenerator[bytes, None]:
        """GET *url* and yield the response body chunk by chunk.

        :param url: address to download from
        :param headers: extra request headers; an empty dict is used when ``None``
        :param timeout: timeout value passed to ``session.get``
        :param chunk_size: size argument passed to ``resp.content.iter_chunked``
        :param raise_for_status: passed through to ``session.get``
        """
        if headers is None:
            headers = {}

        session = await self.create_session()

        async with session.get(
            url, timeout=timeout, headers=headers, raise_for_status=raise_for_status
        ) as resp:
            async for chunk in resp.content.iter_chunked(chunk_size):
                yield chunk

    async def __aenter__(self) -> AiohttpSession:
        """Ensure the client session exists and return this object."""
        await self.create_session()
        return self