Source code for aiointercept.core

import asyncio
import difflib
import inspect
import json as json_module
import logging
import pprint
import socket
import threading
import typing
import warnings
from collections.abc import Awaitable, Callable, Mapping, Sequence
from contextlib import suppress
from functools import wraps
from re import Pattern
from typing import Any, TypedDict, cast
from unittest import mock
from urllib.parse import parse_qs, urlencode

import aiohttp
from aiohttp import ClientRequest, ClientResponse, hdrs, web
from aiohttp.abc import ResolveResult
from aiohttp.connector import SSLContext, TCPConnector
from aiohttp.test_utils import TestServer
from yarl import URL

from .compat import merge_params, normalize_url

# Bounds applied to both sides of an assertion diff. ``difflib.ndiff`` costs
# O(n²) time and memory, so its inputs are clipped *before* diffing: every
# rendered line is cut to a maximum width (a raw body typically arrives as one
# very long line) and only a limited number of lines is compared. This keeps
# failures fast and readable even when the compared values are large.

# Maximum characters kept per rendered line before it is marked as truncated.
_MAX_DIFF_LINE_LEN: typing.Final = 200
# Maximum number of lines per side handed to ``difflib.ndiff``.
_MAX_DIFF_LINES: typing.Final = 60


def _pformat(value: Any) -> list[str]:
    """Render *value* as a list of display lines suitable for :func:`difflib.ndiff`.

    * ``str`` values are used verbatim.
    * ``bytes`` are decoded with ``errors="replace"`` so binary-ish bodies stay
      human-readable instead of raising.
    * Anything else goes through :func:`pprint.pformat` (width 88, sorted dict
      keys) so nested containers line up between expected and actual.

    Empty text yields a single empty line. Any line longer than
    :data:`_MAX_DIFF_LINE_LEN` is cut to that width and tagged with a marker, so
    a huge single-line body cannot drive ``ndiff`` into quadratic blow-up.
    """
    if isinstance(value, str):
        rendered = value
    elif isinstance(value, bytes):
        rendered = value.decode(errors="replace")
    else:
        rendered = pprint.pformat(value, width=88, sort_dicts=True)

    display_lines: list[str] = []
    for raw_line in rendered.splitlines() or [""]:
        if len(raw_line) > _MAX_DIFF_LINE_LEN:
            display_lines.append(raw_line[:_MAX_DIFF_LINE_LEN] + " …(line truncated)")
        else:
            display_lines.append(raw_line)
    return display_lines


def _oneline(value: Any, limit: int = 120) -> str:
    """Collapse *value* into a single whitespace-normalised line for summaries.

    Every run of whitespace (including the line breaks of a multi-line
    rendering) becomes one space. A result longer than *limit* characters is
    cut so that, with a trailing ``...``, it is exactly *limit* long.
    """
    flattened = " ".join(token for line in _pformat(value) for token in line.split())
    if len(flattened) > limit:
        return flattened[: limit - 3] + "..."
    return flattened


def _diff(label: str, expected: Any, actual: Any) -> str:
    """Build an assertion message diffing *expected* vs *actual*.

    The first line is a compact ``expected ... got ...`` summary so it reads well
    in pytest's one-line failure list; the detailed :func:`difflib.ndiff` follows
    (``-`` marks expected, ``+`` marks actual, ``?`` points at differing chars).

    When both values render to exactly the same lines (e.g. ``"a"`` vs ``b"a"``,
    or two long lines that only differ past the per-line truncation point) a
    diff would merely echo identical lines, so a short explanatory note is
    emitted instead. Renderings longer than :data:`_MAX_DIFF_LINES` lines are
    clipped before diffing and the message says so.
    """
    summary = f"{label}: expected {_oneline(expected)}, got {_oneline(actual)}"
    expected_lines = _pformat(expected)
    actual_lines = _pformat(actual)

    if expected_lines == actual_lines:
        # ndiff never returns an empty result (equal lines come back with a
        # "  " prefix), so identical renderings must be detected up front.
        if any(line.endswith("…(line truncated)") for line in expected_lines):
            note = (
                "(values render identically up to the per-line truncation point; "
                "they differ beyond it or in type)"
            )
        else:
            note = "(values render identically; check types)"
        return f"{summary}\n{note}"

    clipped = max(len(expected_lines), len(actual_lines)) > _MAX_DIFF_LINES
    diff = difflib.ndiff(expected_lines[:_MAX_DIFF_LINES], actual_lines[:_MAX_DIFF_LINES])
    # ndiff's "?" hint lines carry their own trailing newline; strip it so the
    # join below does not introduce blank lines.
    body = "\n".join(line.rstrip("\n") for line in diff)
    if clipped:
        body += "\n... (diff truncated; values too large to show in full)"
    return f"{summary}\n{body}"


class AiointerceptRequestKwargs(TypedDict):
    """Parsed request data exposed as :attr:`AiointerceptRequest.kwargs`.

    Built by the dispatcher for every captured request; these are the keyword
    arguments a callback receives.
    """

    # The request's header mapping as received by the test server.
    headers: Mapping[str, str]
    # Query string as key -> every value for that key, so ``?a=1&a=2`` keeps both.
    query: Mapping[str, Sequence[str]]
    # Body decoded as JSON; None when there is no body or it does not parse.
    json: Any | None
    # Raw request body; None when the request had no readable body.
    data: bytes | None


class AiointerceptRequest(web.Request):
    """A subclass of :class:`aiohttp.web.Request` captured by the mock server.

    Instances are the live request objects the test server received, augmented
    with the fields below so callbacks and assertions can inspect the body that
    was already consumed off the wire. They are stored in
    :attr:`aiointercept.requests` and passed to callbacks registered via
    :meth:`aiointercept.add`.

    Attributes:
        captured_body: The raw request body as bytes, read before dispatch, or
            ``None`` when the request had no readable body.
        kwargs: A mapping with the parsed ``headers``, ``query`` (``dict`` of
            key → list of values), ``data`` (raw body or ``None`` if no body),
            and ``json`` (decoded body, or ``None``) — the keyword arguments a
            callback receives.
        canonical_url: The normalized request URL with the original scheme
            restored (e.g. ``https://`` even though the test server received the
            request over plain HTTP). This is the URL passed to callbacks and
            used as the :attr:`aiointercept.requests` key.
    """

    # aiohttp's BaseRequest.__setattr__ emits a DeprecationWarning for any
    # attribute name missing from ``ATTRS``; registering the extra fields keeps
    # upgrade() warning-free. getattr() guards aiohttp builds without ATTRS.
    ATTRS = getattr(web.Request, "ATTRS", frozenset()) | frozenset(("captured_body", "kwargs", "canonical_url"))

    captured_body: bytes | None
    kwargs: AiointerceptRequestKwargs
    canonical_url: URL

    @classmethod
    def upgrade(
        cls,
        request: web.Request,
        captured_body: bytes | None,
        kwargs: "AiointerceptRequestKwargs",
        canonical_url: URL,
    ) -> "AiointerceptRequest":
        """Re-class *request* in place as an :class:`AiointerceptRequest`.

        The same object is mutated and returned; no copy is made.

        Args:
            request: The request received by the test server.
            captured_body: Body bytes read before dispatch, or ``None``.
            kwargs: Parsed request data for callbacks.
            canonical_url: Normalized URL with the original scheme restored.

        Returns:
            *request* itself, now an instance of ``cls``.
        """
        # Use object.__setattr__: until the swap completes, *request* is still a
        # plain web.Request whose __setattr__ would warn about ``__class__``.
        object.__setattr__(request, "__class__", cls)
        upgraded = cast("AiointerceptRequest", request)
        # After the swap, self.ATTRS resolves to the extended set above.
        upgraded.captured_body = captured_body
        upgraded.kwargs = kwargs
        upgraded.canonical_url = canonical_url
        return upgraded
logger = logging.getLogger(__name__)

# Request header names (lower-case) that the passthrough proxy does not forward
# upstream when relaying an unmatched request to the real network.
_PROXY_REQ_DROP = frozenset(("host", "transfer-encoding", "x-aiointercept-orig-scheme"))
# content-length must be dropped along with content-encoding: the proxy reads
# the body decompressed, so the upstream value (compressed size) would truncate
# the relayed response — aiohttp keeps an explicit Content-Length header as-is.
_PROXY_RESP_DROP = frozenset(("transfer-encoding", "content-encoding", "content-length"))

# Module-level state for class-level patches shared across concurrent instances.
# Only the first entering instance installs the patches; the last exiting one removes them.
_patch_lock = threading.Lock()
# Number of registered instances; the patches are installed while it is > 0.
_patch_refcount: int = 0
# Original TCPConnector methods, saved while patched (None otherwise).
_real_ssl_context: Any = None
_real_resolve_host: Any = None
_real_get: Any = None
# Registered instances in registration order (outermost first).
_active_instances: "list[aiointercept]" = []
# Lock-free snapshot of _active_instances (innermost first), rebuilt on
# install/remove so the per-connection hot paths (_resolution_target,
# _shared_ssl_context) never take _patch_lock.
_active_snapshot: "tuple[aiointercept, ...]" = ()


def _refresh_snapshot() -> None:
    """Rebuild :data:`_active_snapshot` from :data:`_active_instances`, innermost first.

    Callers must hold ``_patch_lock``."""
    global _active_snapshot
    _active_snapshot = tuple(reversed(_active_instances))


def _make_resolve_result(host: str, inst: "aiointercept") -> "ResolveResult":
    """Build a resolver entry that sends *host* to ``inst``'s test server.

    ``hostname`` keeps the requested name while ``host``/``port`` are the mock
    server's address. The family is fixed to ``AF_INET`` — NOTE(review): this
    assumes the test server is bound to an IPv4 address."""
    return ResolveResult(
        hostname=host,
        host=inst.server_host,
        port=inst.server_port,
        family=socket.AF_INET,
        proto=0,
        flags=0,
    )


def _resolution_target(host: str) -> "aiointercept | None":
    """Return the mock instance whose server should serve ``host``, or ``None``
    if the host should be resolved for real (passthrough, or no active mock).

    Instances are consulted innermost first. An instance claims ``host`` if it
    registered that host or has any URL pattern registered (presumably because
    patterns match full URLs, which are not known at resolution time)."""
    instances = _active_snapshot
    for inst in instances:
        if host in inst._host_list or inst._patterns_list:
            return inst
    for inst in instances:
        if host in inst._passthrough_hosts or inst.passthrough_unmatched:
            return None
    # No instance claims this host and none allow passthrough — redirect to the
    # innermost instance's server so the client gets a clear connection error.
    return instances[0] if instances else None


async def _shared_resolve_host(
    connector_self: "TCPConnector",
    host: str,
    port: int,
    traces: Any = None,
) -> "list[ResolveResult]":
    """Patched ``TCPConnector._resolve_host``: point mocked hosts at the mock
    server, delegate everything else to the saved original."""
    # Single funnel for every connection's name resolution: intercepting here
    # covers ThreadedResolver, AsyncResolver, and any custom resolver, and it
    # sits *before* the connector's DNS cache — so a resolution cached before
    # the mock started can never leak a mocked host to the real network, while
    # passthrough hosts keep normal DNS caching.
    inst = _resolution_target(host)
    if inst is not None:
        return [_make_resolve_result(host, inst)]
    return cast("list[ResolveResult]", await _real_resolve_host(connector_self, host, port, traces))


async def _shared_get(connector_self: "TCPConnector", key: Any, traces: Any) -> Any:
    """Patched ``TCPConnector._get`` (pooled-connection lookup).

    Returns the pooled connection from the saved original, or ``None`` so the
    connector opens a new one. NOTE(review): relies on the returned object
    exposing ``.transport`` and ``.close()`` — confirm against the supported
    aiohttp versions."""
    # A pooled keep-alive connection that does not point at the mock's server
    # bypasses DNS resolution entirely, so it would leak past the mock. For an
    # intercepted host, reuse a pooled connection only when it is already
    # connected to that mock's server; otherwise close it so aiohttp opens a
    # fresh connection, which then routes through patched resolution.
    inst = _resolution_target(key.host)
    if inst is None:
        return await _real_get(connector_self, key, traces)
    conn = await _real_get(connector_self, key, traces)
    if conn is None:
        return None
    transport = conn.transport
    peer = transport.get_extra_info("peername") if transport is not None else None
    if peer is not None and peer[0] == inst.server_host and peer[1] == inst.server_port:
        return conn
    conn.close()
    return None


def _shared_ssl_context(connector_self: "TCPConnector", req: "ClientRequest") -> "SSLContext | None":
    """Patched ``TCPConnector._get_ssl_context``.

    For requests routed to a mock server (which speaks plain HTTP) no SSL
    context is used; an ``https`` request is tagged with the
    ``X-Aiointercept-Orig-Scheme`` header so the server can restore the scheme.
    Other requests get the original SSL context."""
    instances = _active_snapshot
    host = req.url.raw_host
    url_str = str(req.url)
    # Exact host or matching pattern: this request is intercepted.
    for inst in instances:
        if host in inst._host_list or inst._match_pattern(url_str):
            if req.url.scheme == "https":
                req.headers["X-Aiointercept-Orig-Scheme"] = "https"
            return None
    # Instances with patterns also route passthrough traffic through their
    # server (see _resolution_target), so that traffic must be plain HTTP too.
    for inst in instances:
        if inst._patterns_list and (inst.passthrough_unmatched or host in inst._passthrough_hosts):
            if req.url.scheme == "https":
                req.headers["X-Aiointercept-Orig-Scheme"] = "https"
            return None
    return cast("SSLContext | None", _real_ssl_context(connector_self, req))


def _install_patches(inst: "aiointercept") -> None:
    """Register ``inst`` and, on the first active instance, patch host
    resolution, SSL context, and connection reuse at the class level."""
    global _patch_refcount, _real_ssl_context, _real_resolve_host, _real_get
    with _patch_lock:
        _active_instances.append(inst)
        _refresh_snapshot()
        if _patch_refcount == 0:
            # Save the originals before replacing them so they can be restored.
            _real_ssl_context = TCPConnector._get_ssl_context  # pyright: ignore[reportPrivateUsage]
            _real_resolve_host = TCPConnector._resolve_host  # pyright: ignore[reportPrivateUsage]
            _real_get = TCPConnector._get  # pyright: ignore[reportPrivateUsage]
            TCPConnector._get_ssl_context = _shared_ssl_context  # type: ignore[assignment]
            TCPConnector._resolve_host = _shared_resolve_host  # type: ignore[assignment]
            TCPConnector._get = _shared_get  # type: ignore[assignment]
        _patch_refcount += 1


def _remove_patches(inst: "aiointercept") -> None:
    """Deregister ``inst`` and, once the last active instance is gone, restore
    the original host resolution, SSL context, and connection reuse.

    Idempotent: a no-op if ``inst`` was never registered (e.g. ``start()``
    failed before patching)."""
    global _patch_refcount, _real_ssl_context, _real_resolve_host, _real_get
    with _patch_lock:
        if inst not in _active_instances:
            return
        _active_instances.remove(inst)
        _refresh_snapshot()
        _patch_refcount -= 1
        if _patch_refcount == 0:
            TCPConnector._get_ssl_context = _real_ssl_context  # type: ignore[method-assign]
            TCPConnector._resolve_host = _real_resolve_host  # type: ignore[method-assign]
            TCPConnector._get = _real_get  # type: ignore[method-assign]
            _real_ssl_context = None
            _real_resolve_host = None
            _real_get = None
class CallbackResult:
    """Value a user callback returns to describe the mocked response.

    Args:
        method: HTTP method (default GET; not used by the server handler).
        status: HTTP response status code.
        body: Raw response body as str or bytes.
        content_type: ``Content-Type`` header value. Set to ``None`` when
            *headers* already carries a ``Content-Type`` entry, to avoid
            colliding with it.
        payload: Response body as a dict; serialized to JSON automatically.
        headers: Additional response headers.
        response_class: Ignored (present for aioresponses API compatibility).
        reason: HTTP reason phrase.
    """

    def __init__(
        self,
        method: str = hdrs.METH_GET,
        status: int = 200,
        body: str | bytes = "",
        content_type: str = "application/json",
        payload: Any = None,
        headers: Mapping[str, str] | None = None,
        response_class: type[ClientResponse] | None = None,
        reason: str | None = None,
    ):
        self.method = method
        self.status = status
        self.body = body
        self.payload = payload
        # aiohttp.web.Response raises ValueError when given both a content_type
        # and a Content-Type header, so the default is discarded whenever the
        # caller's headers already name one (header names compared lower-cased).
        header_names = () if headers is None else headers
        if any(name.lower() == "content-type" for name in header_names):
            self.content_type: str | None = None
        else:
            self.content_type = content_type
        self.headers = headers
        self.response_class = response_class
        self.reason = reason
class _CloseConnection:
    """Sentinel: handler should close the transport, surfacing ClientConnectionError on the client."""

    def __repr__(self) -> str:
        # A readable repr keeps handler dumps/diagnostics legible instead of
        # showing an anonymous ``<... object at 0x...>``.
        return "<aiointercept CLOSE_CONNECTION>"


# The single shared sentinel instance; MockResponse compares against it by identity.
_CLOSE_CONNECTION = _CloseConnection()

# What a MockResponse can wrap: an async request handler, or the close-connection sentinel.
handler_type = Callable[[web.Request], Awaitable[web.StreamResponse]] | _CloseConnection
class MockResponse:
    """A registered mock handler.

    Returned by :meth:`~aiointercept.aiointercept.add` and the shorthand
    methods (:meth:`~aiointercept.aiointercept.get`,
    :meth:`~aiointercept.aiointercept.post`, etc.).

    Attributes:
        call_count: Number of times this mock has been matched and served.
    """

    def __init__(self, handler: handler_type) -> None:
        self._handler = handler
        self.call_count: int = 0

    async def __call__(self, request: web.Request) -> web.StreamResponse:
        """Serve *request*, counting the call before delegating to the handler.

        For the close-connection sentinel the client transport is closed and a
        502 response is returned in place of a real handler result.
        """
        self.call_count += 1
        if self._handler is not _CLOSE_CONNECTION:
            handler_fn = cast("Callable[[web.Request], Awaitable[web.StreamResponse]]", self._handler)
            return await handler_fn(request)
        client_transport = request.transport
        if client_transport:
            client_transport.close()
        return web.Response(status=502, text="Handler registered to raise ClientConnectionError.")
class aiointercept:  # noqa: N801
    """Mock :mod:`aiohttp` requests by routing them through a real test server.

    A real :class:`aiohttp.web.Application` is started on localhost and the
    client's requests are directed to it. In the default mode the client is
    pointed at :attr:`server_url`; with ``mock_external_urls=True`` the DNS
    resolver is patched at the class level so any aiohttp request to a
    registered host is transparently intercepted.

    Responses are registered with :meth:`add` or the per-method shortcuts
    (:meth:`get`, :meth:`post`, ...); what was sent can then be inspected via
    :attr:`requests` and the ``assert_*`` helpers.

    Use it as an async context manager, as a decorator, or drive the lifecycle
    explicitly with :meth:`start` / :meth:`stop`::

        async with aiointercept() as m:
            m.get(f"{m.server_url}/users", payload=[{"id": 1}])
            async with aiohttp.ClientSession() as session:
                resp = await session.get(f"{m.server_url}/users")

    Attributes:
        server_url: Base URL of the running test server, e.g.
            ``"http://127.0.0.1:54321"``. Set once :meth:`start` has run.
        server_host: Host the test server is bound to.
        server_port: Port the test server is listening on.
        requests: Mapping of ``(METHOD, normalized URL)`` to the list of
            :class:`AiointerceptRequest` objects received, in arrival order.
        ordered_requests: Every received request as a ``(key, request)`` pair,
            in global arrival order.
    """

    def __init__(
        self,
        mock_external_urls: bool = False,
        passthrough: Sequence[str] | None = None,
        passthrough_unmatched: bool = False,
        param: str | None = None,
        **kwargs: Any,
    ) -> None:
        """Create a mock.

        Args:
            mock_external_urls: When ``True``, patch the DNS resolver at the
                process level so requests to any registered host are
                intercepted, even those made by third-party libraries. When
                ``False`` (default), only requests sent to :attr:`server_url`
                are intercepted and no global state is modified.
            passthrough: Hosts whose requests bypass the mock and reach the real
                network. Requires ``mock_external_urls=True``.
            passthrough_unmatched: When ``True``, proxy every unmatched request
                to the real network instead of failing it. Requires
                ``mock_external_urls=True``.

                .. warning::
                    Proxied requests carry their original headers — including
                    any ``Authorization`` or ``Cookie`` values — so a typo'd URL
                    in a test can silently send real credentials to the real
                    network.
            param: Keyword-argument name under which the mock is injected when
                the instance is used as a decorator. Defaults to appending it as
                the last positional argument.

        Raises:
            ValueError: ``passthrough_unmatched`` without ``mock_external_urls``.
        """
        if kwargs:
            warnings.warn(
                "Passing extra parameters to aiointercept via kwargs is deprecated "
                "and will be removed in a future release.",
                DeprecationWarning,
                stacklevel=2,
            )
        if passthrough_unmatched and not mock_external_urls:
            raise ValueError("passthrough_unmatched=True requires mock_external_urls=True")

        self._passthrough_urls = passthrough or []
        self._mock_external_urls = mock_external_urls
        # Entries may be full URLs or bare host names; when URL(...) yields no
        # host (e.g. a name without a scheme) the entry is kept verbatim.
        self._passthrough_hosts: list[str] = (
            [URL(entry).host or entry for entry in self._passthrough_urls] if mock_external_urls else []
        )
        self.param = param
        self.passthrough_unmatched = passthrough_unmatched

        self._host_list: set[str] = set()
        self._https_hosts: set[str] = set()
        self._patterns_list: list[Pattern[str]] = []
        # Exact handlers keyed by (url, method): a single handler, or a list of
        # handlers consumed in order for a finite repeat.
        self.handlers: dict[tuple[str, str], MockResponse | list[MockResponse]] = {}
        # Pattern handlers keyed by (compiled pattern, method), same value shape.
        self.patterns_handler: dict[tuple[Pattern[str], str], MockResponse | list[MockResponse]] = {}
        # Recorded requests: {(METHOD, URL): [AiointerceptRequest, ...]}.
        self.requests: dict[tuple[str, URL], list[AiointerceptRequest]] = {}
        # The same requests as (key, request) pairs in global arrival order.
        self.ordered_requests: list[tuple[tuple[str, URL], AiointerceptRequest]] = []

        self.server: TestServer | None = None
        self._bypass_session: aiohttp.ClientSession | None = None
        # The intercept server runs on its own daemon thread + event loop so
        # that callers whose loop gets blocked (e.g. Starlette TestClient
        # holding the loop during a sync request) cannot deadlock the server.
        self._server_thread: threading.Thread | None = None
        self._server_loop: asyncio.AbstractEventLoop | None = None
        # The loop the caller entered the context manager on. Async user
        # callbacks are scheduled back onto this loop so they observe the
        # same loop-bound primitives (asyncio.Event, asyncio.Queue, ...) as
        # the rest of the test.
        self._caller_loop: asyncio.AbstractEventLoop | None = None

    async def __aenter__(self) -> "aiointercept":
        """Start the mock (see :meth:`start`) and return it for ``async with ... as m``."""
        await self.start()
        return self
[docs] async def start(self) -> None: """Start the test server and install patches (if any). Called automatically by ``__aenter__``. Call it directly from a framework's setup hook (e.g. ``asyncSetUp``) when not using the context manager. After this returns, :attr:`server_url` is available. Pair every ``start()`` with a :meth:`stop`. """ self._caller_loop = asyncio.get_running_loop() self.requests.clear() self.ordered_requests.clear() await self._start_server_thread() assert self._server_loop is not None assert self.server is not None assert isinstance(self.server.host, str) assert isinstance(self.server.port, int) self.server_host = self.server.host self.server_port = self.server.port self.server_url = f"http://{self.server_host}:{self.server.port}" try: if self._mock_external_urls: _install_patches(self) except BaseException: _remove_patches(self) await self._stop_server_thread() raise
async def __aexit__( self, exc_type: type[BaseException] | None, exc_val: BaseException | None, exc_tb: Any, ) -> None: await self.stop()
[docs] async def stop(self) -> None: """Stop the test server and remove any DNS/SSL patches. Called automatically by ``__aexit__``. When DNS patching is active, the patches are only fully removed once the last concurrently-running instance stops. """ try: if self._mock_external_urls: _remove_patches(self) finally: await self._stop_server_thread() self._caller_loop = None self.handlers.clear() self.patterns_handler.clear() self._host_list.clear() self._https_hosts.clear() self._patterns_list.clear()
    async def _start_server_thread(self) -> None:
        """Start the TestServer on a dedicated daemon thread with its own event loop.

        Waits (up to 30 seconds, blocking the caller's thread) until the server
        thread reports success or failure; a startup exception from the thread
        is re-raised here. The thread sets ``self._server_loop`` and
        ``self.server``."""
        ready = threading.Event()
        # Filled by the server thread if startup fails.
        startup_error: list[BaseException] = []

        def _run_server_thread() -> None:
            # Thread body: owns a brand-new event loop for the server's lifetime.
            loop = asyncio.new_event_loop()
            asyncio.set_event_loop(loop)
            self._server_loop = loop

            async def _start_server() -> None:
                app = web.Application()
                # Catch-all route: every method and path is handled by _dispatch.
                app.router.add_route("*", "/{tail:.*}", self._dispatch)
                self.server = TestServer(app)
                await self.server.start_server()

            try:
                loop.run_until_complete(_start_server())
            except BaseException as e:
                # Report the failure to the waiting caller, then discard the loop.
                startup_error.append(e)
                ready.set()
                loop.close()
                self._server_loop = None
                return
            ready.set()
            try:
                # Runs until _stop_server_thread schedules loop.stop().
                loop.run_forever()
            finally:
                # Drain anything still scheduled before closing the loop.
                try:
                    pending = asyncio.all_tasks(loop)
                    for task in pending:
                        task.cancel()
                    if pending:
                        loop.run_until_complete(asyncio.gather(*pending, return_exceptions=True))
                finally:
                    loop.close()

        self._server_thread = threading.Thread(target=_run_server_thread, name="aiointercept-server", daemon=True)
        self._server_thread.start()
        # ready.wait runs on the caller's loop thread; the server thread sets
        # the event quickly so a brief blocking wait here is acceptable. The
        # timeout keeps a wedged server thread from blocking the caller's loop
        # forever.
        if not ready.wait(timeout=30):  # pragma: no cover - requires a wedged thread
            raise RuntimeError("aiointercept server thread failed to start within 30 seconds")
        if startup_error:
            self._server_thread.join()
            self._server_thread = None
            raise startup_error[0]

    async def _stop_server_thread(self) -> None:
        """Close the bypass session and TestServer on the server loop, stop that
        loop, join the server thread, and reset the related attributes.

        Safe to call when the server loop is not running (teardown is skipped).
        NOTE(review): ``join()`` blocks the caller's thread until the server
        thread exits."""
        server_loop = self._server_loop
        if server_loop is not None and server_loop.is_running():

            async def _teardown() -> None:
                # Runs on the server loop, which owns these aiohttp objects.
                if self._bypass_session is not None:
                    await self._bypass_session.close()
                    self._bypass_session = None
                if self.server is not None:
                    await self.server.close()
                    self.server = None

            try:
                await asyncio.wrap_future(asyncio.run_coroutine_threadsafe(_teardown(), server_loop))
            finally:
                # Stop the loop even if teardown failed, so the thread can exit.
                server_loop.call_soon_threadsafe(server_loop.stop)
        if self._server_thread is not None:
            self._server_thread.join()
            self._server_thread = None
        self._server_loop = None
        self._bypass_session = None
        self.server = None

    # Decorator support
    def __call__(self, f: Callable[..., Awaitable[Any]]) -> Callable[..., Awaitable[Any]]:
        """Wrap coroutine function *f* so each call runs inside ``async with self``.

        The mock is passed as keyword ``self.param`` when that is set. Otherwise
        it is inserted right after ``args[0]`` when ``args[0]`` has an attribute
        named like *f* (a heuristic for method calls where ``args[0]`` is the
        test-case instance), or appended as the last positional argument."""

        @wraps(f)
        async def wrapper(*args: Any, **kwargs: Any) -> Any:
            async with self as m:
                if self.param:
                    kwargs[self.param] = m
                elif args and hasattr(args[0], f.__name__):
                    args = (args[0], m, *args[1:])
                else:
                    args = (*args, m)
                return await f(*args, **kwargs)

        return wrapper

    def _make_bypass_session(self) -> aiohttp.ClientSession:
        """Create a ClientSession that reaches the real network, for the passthrough proxy.

        Its connector calls the original TCPConnector methods captured from the
        module-level ``_real_*`` variables at call time. Those are ``None``
        unless patches are installed — NOTE(review): callers are expected to
        create the session only while patching is active."""
        _orig_resolve_host = _real_resolve_host
        _orig_ssl_ctx = _real_ssl_context
        _orig_get = _real_get

        class _BypassConnector(aiohttp.TCPConnector):
            """Connector wired to the saved originals so proxied requests resolve,
            pool, and encrypt against the real network."""

            async def _resolve_host(self, host: str, port: int, traces: Any = None) -> list[ResolveResult]:
                return await _orig_resolve_host(self, host, port, traces)  # type: ignore[no-any-return]

            def _get_ssl_context(self, req: ClientRequest) -> SSLContext | None:
                return _orig_ssl_ctx(self, req)  # type: ignore[no-any-return] # pyright: ignore[reportPrivateUsage]

            async def _get(self, key: Any, traces: Any) -> Any:
                return await _orig_get(self, key, traces)

        # DummyCookieJar: the proxy must not accumulate Set-Cookie state across
        # proxied requests — cookie handling belongs to each test's own client.
        return aiohttp.ClientSession(connector=_BypassConnector(), cookie_jar=aiohttp.DummyCookieJar())

    def _match_pattern(self, url: str) -> bool:
        """Return True if any registered pattern matches *url* (``Pattern.match``,
        i.e. anchored at the start of the string)."""
        return any(p.match(url) for p in self._patterns_list)

    def _diagnose_unmatched(self, method: str, url_str: str) -> str:
        """Explain why ``method url_str`` found no handler, for error messages.

        Checked in order: an exhausted handler list for this exact key; the
        closest registered exact handler (shown as an ndiff); the registered
        patterns; otherwise, that nothing is registered."""
        request_line = f"{method} {url_str}"
        existing = self.handlers.get((url_str, method))
        if isinstance(existing, list) and not existing:
            return (
                f"aiointercept: no handler for {request_line} — "
                "all registered handlers for this method+url have already been consumed "
                "(finite repeat exhausted)"
            )
        candidates = sorted({f"{m} {u}" for (u, m) in self.handlers})
        # cutoff=0.0 means the best candidate is always returned when any exist,
        # so the patterns branch below is reached only with no exact handlers.
        close = difflib.get_close_matches(request_line, candidates, n=1, cutoff=0.0) if candidates else []
        if close:
            diff = "\n".join(line.rstrip() for line in difflib.ndiff([request_line], [close[0]]))
            return f"aiointercept: no matching handler — closest registered (diff):\n{diff}"
        if self.patterns_handler:
            patterns = sorted({p.pattern for (p, _) in self.patterns_handler})
            return f"aiointercept: no handler for {request_line} — no exact handlers; registered patterns: {patterns}"
        return f"aiointercept: no handler for {request_line} — no handlers registered."
async def _dispatch(self, request: web.Request) -> web.StreamResponse:
    """Record one intercepted request and produce its response.

    Resolution order: exact handlers (normalized URL string + method), then
    regex pattern handlers, then — when both ``_mock_external_urls`` and
    ``passthrough_unmatched`` are set — a real network round-trip. A request
    that is still unmatched is logged and its connection closed.
    """
    # ``add`` registers methods upper-cased and ``requests`` is keyed
    # upper-cased; normalize once so recording, exact lookup, pattern lookup
    # and diagnostics all agree on the method spelling.
    method = request.method.upper()
    url = normalize_url(request.url)
    req_host = request.headers.get("Host", "")
    if request.headers.get("X-Aiointercept-Orig-Scheme") == "https":
        self._https_hosts.add(req_host)
    if req_host in self._https_hosts:
        # The test server is reached over plain HTTP; restore the scheme the
        # client originally used.
        url = url.with_scheme("https")
    key = (method, url)

    # Read body eagerly before the handler runs, because aiohttp sets
    # PayloadAccessError on the stream once the response cycle completes.
    captured_body = None
    json = None
    if request.can_read_body:
        captured_body = await request.read()
        # Best effort: a non-JSON body simply leaves ``json`` as None.
        with suppress(Exception):
            json = json_module.loads(captured_body) if captured_body else None

    # this kwargs will be removed, should be deprecated in the future
    request_kwargs: AiointerceptRequestKwargs = {
        "headers": request.headers,
        # Use getall so duplicate keys (?a=1&a=2) aren't collapsed to one value.
        "query": {k: request.query.getall(k) for k in dict.fromkeys(request.query)},
        "json": json,
        "data": captured_body,
    }
    aiointercept_request = AiointerceptRequest.upgrade(request, captured_body, request_kwargs, url)
    # Record only once the body has been read, so a failed read does not leave
    # an empty history entry behind for this key.
    self.requests.setdefault(key, []).append(aiointercept_request)
    self.ordered_requests.append((key, aiointercept_request))

    url_str = str(url)
    handler = self._take_exact_handler(url_str, method)
    original_host = request.headers.get("Host", request.url.host)
    if handler is None:
        path_qs = request.path_qs
        # Patterns see the original Host and the un-normalized path/query,
        # tried under both the https and the http spelling.
        handler = self._take_pattern_handler(
            method,
            (f"https://{original_host}{path_qs}", f"http://{original_host}{path_qs}"),
        )
    if handler is not None:
        return await handler(request)

    if self._mock_external_urls and self.passthrough_unmatched:
        return await self._passthrough(request, method, original_host, captured_body)

    logger.warning(self._diagnose_unmatched(method, url_str))
    # this should raise ClientConnectionError on the other side
    if request.transport:
        request.transport.close()
    # Fallback in case transport.close() didn't take effect — the client
    # should normally see ClientConnectionError before reading this body.
    return web.Response(status=502, text="No handler registered for this request.")

def _take_exact_handler(self, url_str: str, method: str) -> "MockResponse | None":
    """Return (and for finite queues, consume) the exact handler for *url_str*/*method*."""
    registered = self.handlers.get((url_str, method))
    if isinstance(registered, list):
        # Finite-repeat queue: pop the oldest entry. An exhausted queue stays
        # registered as an empty list, which _diagnose_unmatched reports.
        return registered.pop(0) if registered else None
    return registered

def _take_pattern_handler(self, method: str, candidate_urls: Sequence[str]) -> "MockResponse | None":
    """Return the first pattern handler for *method* matching any of *candidate_urls*.

    Finite-repeat pattern queues are consumed; an exhausted pattern entry is
    removed from ``patterns_handler`` entirely.
    """
    for (pattern, pattern_method), registered in self.patterns_handler.items():
        if pattern_method != method or not any(pattern.match(u) for u in candidate_urls):
            continue
        if not isinstance(registered, list):
            return registered
        remaining = registered[1:]
        # Safe during iteration: we return right after mutating the dict.
        if remaining:
            self.patterns_handler[pattern, method] = remaining
        else:
            del self.patterns_handler[pattern, method]
        return registered[0]
    return None

async def _passthrough(
    self,
    request: web.Request,
    method: str,
    original_host: str | None,
    captured_body: bytes | None,
) -> web.Response:
    """Forward an unmatched request to the real host and relay its response."""
    scheme = request.headers.get("X-Aiointercept-Orig-Scheme") or ("https" if request.secure else "http")
    real_url = f"{scheme}://{original_host}{request.path_qs}"
    session = self._bypass_session
    if session is None:
        # Created lazily on first passthrough use. _dispatch runs on
        # the server loop, which is the loop the session must bind
        # to, and there is no await between check and assignment.
        session = self._bypass_session = self._make_bypass_session()
    async with session.request(
        method=method,
        url=real_url,
        headers={k: v for k, v in request.headers.items() if k.lower() not in _PROXY_REQ_DROP},
        data=captured_body or None,
        # Relay redirects untouched so the *client's* redirect
        # handling (allow_redirects, history) behaves as it would
        # against the real network.
        allow_redirects=False,
        ssl=True,
    ) as real_resp:
        body = await real_resp.read()
        return web.Response(
            status=real_resp.status,
            headers={k: v for k, v in real_resp.headers.items() if k.lower() not in _PROXY_RESP_DROP},
            body=body,
        )
def add(
    self,
    url: URL | str | Pattern[str],
    method: str = hdrs.METH_GET,
    status: int = 200,
    body: str | bytes = b"",
    json: Any = None,
    payload: Any = None,
    headers: Mapping[str, str] | None = None,
    repeat: bool | int = False,
    content_type: str | None = None,
    callback: Callable[..., CallbackResult | Awaitable[CallbackResult]] | None = None,
    reason: str | None = None,
    exception: Exception | bool | None = None,
) -> "MockResponse":
    """Register a mock handler for *url* and *method*.

    Args:
        url: Target URL as str, :class:`~yarl.URL`, or compiled :class:`re.Pattern`.
        method: HTTP method (case-insensitive, default ``GET``).
        status: Response status code.
        body: Raw response body (str is UTF-8 encoded; default empty bytes).
        json: Response body as a JSON-serialisable object (overrides *body*).
        payload: Alias for *json*.
        headers: Additional response headers. A snapshot is taken at
            registration; later changes to the passed mapping do not affect
            the mock.
        repeat: ``True`` to respond indefinitely; integer N to respond N times;
            ``False`` or ``0`` to respond once (default).
        content_type: Override the ``Content-Type`` response header.
        callback: Callable ``(url, *, headers, query, json) → CallbackResult``
            (sync or async). Takes precedence over *body* / *json* / *status*.
        reason: HTTP reason phrase.
        exception: Any truthy value registers a handler that closes the
            connection, causing :class:`~aiohttp.ClientConnectionError` on the
            client. Passing a specific exception instance logs a warning; pass
            ``exception=True`` to suppress it.

    Raises:
        RuntimeError: The interceptor has not been started.
        ValueError: *url* has no host, *repeat* is below 1, or a finite handler
            is added on top of an existing ``repeat=True`` handler.
    """
    if exception and exception is not True:
        logger.warning(
            "aiointercept only raise ClientConnectionError, pass exception=True instead of an specific exception"
        )
    method = method.upper()
    if isinstance(url, str):
        url = URL(url)
    if self.server is None:
        raise RuntimeError("Server not started — use `async with aiointercept() as m:` first.")

    # Validate *repeat* before touching interceptor state so a rejected call
    # does not leave its host or pattern half-registered. ``None`` = forever.
    count: int | None
    if repeat is True:
        count = None
    else:
        count = 1 if repeat is False or repeat == 0 else repeat
        if count < 1:
            raise ValueError("repeat must be at least 1")

    if isinstance(url, Pattern):
        self._patterns_list.append(url)
    if isinstance(url, URL):
        host = url.host
        if not host:
            raise ValueError(f"Cannot extract host from {url!r}")
        # Map this host → our test server
        self._host_list.add(host)
        # Record HTTPS-ness eagerly so it survives clear() + re-add. The
        # SSL-context hook also populates _https_hosts, but only on fresh
        # TCP connections; keep-alive connections that outlive a clear()
        # would otherwise lose their scheme tagging.
        if url.scheme == "https":
            self._https_hosts.add(host)

    if json is not None:
        body = json_module.dumps(json).encode()
    elif payload is not None:
        body = json_module.dumps(payload).encode()
    elif isinstance(body, str):
        body = body.encode()

    # Snapshot headers as (name, value) pairs: keeps duplicates, is immune to
    # later mutation by the caller, and makes aiohttp build its own multidict
    # per response. A caller-supplied CIMultiDict would otherwise be used
    # as-is and get Content-Type written into it by the first response,
    # making every later response fail with "passing both Content-Type ...".
    header_items: list[tuple[str, str]] = list((headers or {}).items())
    if not content_type and not any(k.lower() == "content-type" for k, _ in header_items):
        content_type = "application/json"

    async def handler(request: web.Request) -> web.Response:
        if not callable(callback):
            return web.Response(
                status=status,
                body=body,
                headers=header_items,
                reason=reason,
                content_type=content_type,
            )
        aiointercept_request = cast("AiointerceptRequest", request)
        cb_kwargs = aiointercept_request.kwargs
        # Use the canonical URL (normalized, original scheme restored)
        # rather than request.url, which is always http:// here because
        # the test server receives the request over plain HTTP.
        cb_url = aiointercept_request.canonical_url
        if inspect.iscoroutinefunction(callback):
            # Async callbacks run on the caller's loop so that
            # loop-bound primitives (asyncio.Event, asyncio.Queue,
            # asyncio.Lock created by the test) keep working.
            caller_loop = self._caller_loop
            if (
                caller_loop is not None
                and not caller_loop.is_closed()
                and caller_loop is not asyncio.get_running_loop()
            ):
                result = await asyncio.wrap_future(
                    asyncio.run_coroutine_threadsafe(callback(cb_url, **cb_kwargs), caller_loop)
                )
            else:
                result = await callback(cb_url, **cb_kwargs)
        else:
            result = callback(cb_url, **cb_kwargs)
        cb_body = result.body
        if result.payload is not None:
            cb_body = json_module.dumps(result.payload).encode()
        return web.Response(
            status=result.status,
            body=cb_body,
            headers=result.headers or {},
            reason=result.reason,
            content_type=result.content_type,
        )

    raw: handler_type = handler if not exception else _CLOSE_CONNECTION
    mock_response = MockResponse(raw)

    if count is None:
        # repeat=True replaces whatever was registered for this key.
        if isinstance(url, Pattern):
            self.patterns_handler[url, method] = mock_response
        else:
            self.handlers[str(normalize_url(url)), method] = mock_response
        return mock_response

    queue: list[MockResponse] = [mock_response] * count
    if isinstance(url, Pattern):
        if (url, method) not in self.patterns_handler:
            self.patterns_handler[url, method] = queue
            return mock_response
        existing_pattern = self.patterns_handler[(url, method)]
        if not isinstance(existing_pattern, list):
            raise ValueError(
                f"Existing handler for pattern {url} {method} has "
                "repeat=True, cannot add more handlers to it."
            )
        existing_pattern += queue
        return mock_response

    handler_url = str(normalize_url(url))
    if (handler_url, method) not in self.handlers:
        self.handlers[handler_url, method] = queue
        return mock_response
    existing = self.handlers[(handler_url, method)]
    if not isinstance(existing, list):
        raise ValueError(
            f"Existing handler for {handler_url} {method} has repeat=True, cannot add more handlers to it."
        )
    existing += queue
    return mock_response
def get(self, url: "URL | str | Pattern[str]", **kwargs: Any) -> "MockResponse":
    """Shortcut for :meth:`add` with the GET method; other keyword arguments pass through unchanged."""
    verb = hdrs.METH_GET
    return self.add(url, method=verb, **kwargs)

def post(self, url: "URL | str | Pattern[str]", **kwargs: Any) -> "MockResponse":
    """Shortcut for :meth:`add` with the POST method; other keyword arguments pass through unchanged."""
    verb = hdrs.METH_POST
    return self.add(url, method=verb, **kwargs)

def put(self, url: "URL | str | Pattern[str]", **kwargs: Any) -> "MockResponse":
    """Shortcut for :meth:`add` with the PUT method; other keyword arguments pass through unchanged."""
    verb = hdrs.METH_PUT
    return self.add(url, method=verb, **kwargs)

def patch(self, url: "URL | str | Pattern[str]", **kwargs: Any) -> "MockResponse":
    """Shortcut for :meth:`add` with the PATCH method; other keyword arguments pass through unchanged."""
    verb = hdrs.METH_PATCH
    return self.add(url, method=verb, **kwargs)

def delete(self, url: "URL | str | Pattern[str]", **kwargs: Any) -> "MockResponse":
    """Shortcut for :meth:`add` with the DELETE method; other keyword arguments pass through unchanged."""
    verb = hdrs.METH_DELETE
    return self.add(url, method=verb, **kwargs)

def head(self, url: "URL | str | Pattern[str]", **kwargs: Any) -> "MockResponse":
    """Shortcut for :meth:`add` with the HEAD method; other keyword arguments pass through unchanged."""
    verb = hdrs.METH_HEAD
    return self.add(url, method=verb, **kwargs)

def options(self, url: "URL | str | Pattern[str]", **kwargs: Any) -> "MockResponse":
    """Shortcut for :meth:`add` with the OPTIONS method; other keyword arguments pass through unchanged."""
    verb = hdrs.METH_OPTIONS
    return self.add(url, method=verb, **kwargs)
def clear(self) -> None:
    """Forget every recorded request and every registered handler.

    Each container is emptied in place rather than rebound, so the objects
    held on the instance keep their identity.
    """
    for container in (
        self.requests,
        self.ordered_requests,
        self.handlers,
        self.patterns_handler,
        self._host_list,
        self._patterns_list,
        self._https_hosts,
    ):
        container.clear()
def assert_called(self) -> None:
    """Fail unless at least one request has been intercepted."""
    if self.requests:
        return
    raise AssertionError("Expected at least one call, got none.")

def assert_not_called(self) -> None:
    """Fail if any request has been intercepted."""
    if not self.requests:
        return
    total = len(self.ordered_requests)
    raise AssertionError(f"Expected no calls, got {total}")

def assert_called_once(self) -> None:
    """Fail unless exactly one request was intercepted, counting every URL."""
    total = len(self.ordered_requests)
    if total == 1:
        return
    raise AssertionError(f"Expected exactly 1 call, got {total}.")
@property def call_count(self) -> int: """Total number of requests intercepted across all URLs.""" return len(self.ordered_requests) @property def last_request(self) -> "AiointerceptRequest | None": """The most recently intercepted request, or ``None`` if no requests have been made.""" if not self.ordered_requests: return None return self.ordered_requests[-1][1] def _no_calls_message(self, method: str, url: URL) -> str: """Build a 'No calls to ...' message, suggesting similar recorded calls.""" msg = f"No calls to {method.upper()} {url}" recorded = [f"{m} {u}" for (m, u) in self.requests] if not recorded: return f"{msg}\nNo requests were recorded." target = f"{method.upper()} {url}" similar = difflib.get_close_matches(target, recorded, n=3, cutoff=0) heading = "Did you mean one of" listing = "\n".join(f" - {s}" for s in similar) return f"{msg}\n{heading}:\n{listing}"
def assert_any_call(
    self,
    url: URL | str,
    method: str = hdrs.METH_GET,
    params: Mapping[str, str] | None = None,
) -> None:
    """Fail unless *url* (with *params* merged in) was requested at least once via *method*."""
    target = normalize_url(merge_params(url, params))
    if (method.upper(), target) in self.requests:
        return
    raise AssertionError(self._no_calls_message(method, target))
def assert_called_with(
    self,
    url: URL | str,
    method: str = hdrs.METH_GET,
    params: typing.Mapping[str, str] | None = None,
    data: str | bytes | typing.Mapping[str, Any] | None = None,
    json: Any = None,
    headers: typing.Mapping[str, str] | None = None,
    strict_headers: bool = False,
    **kwargs: Any,
) -> None:
    """Assert that the most recent call to *url* matched the given arguments.

    Args:
        url: Expected URL (str or :class:`~yarl.URL`).
        method: Expected HTTP method (default ``GET``).
        params: Query string params merged into *url* before lookup.
        data: Expected request body — bytes, str, or a dict (form-encoded via
            ``application/x-www-form-urlencoded``). A request that carried no
            readable body is compared as an empty body.
        json: Expected request body as a JSON-serialisable object.
        headers: Expected request headers. By default only the headers listed
            here are checked; auto-added aiohttp headers are ignored. Set
            *strict_headers* to compare the full header map.
        strict_headers: When ``True``, the complete set of actual request
            headers must match *headers* exactly. Use :data:`unittest.mock.ANY`
            as a value to accept any value for a specific key (e.g.
            ``Content-Length``).
        kwargs: Ignored (present for aioresponses API compatibility).

    Raises:
        AssertionError: No matching call was recorded, or a checked field
            differs. Raised explicitly, so checks still run under ``python -O``.
    """
    if kwargs:
        warnings.warn(
            "Passing extra parameters to assert_called_with via kwargs is "
            "deprecated and will be removed in a future release.",
            DeprecationWarning,
            stacklevel=2,
        )
    url = normalize_url(merge_params(url, params))
    key = (method.upper(), url)
    if key not in self.requests:
        raise AssertionError(self._no_calls_message(method, url))
    request: AiointerceptRequest = self.requests[key][-1]
    # Requests without a readable body (e.g. a plain GET) record ``None``;
    # compare them as an empty body so mismatches surface as AssertionError
    # instead of AttributeError on ``.decode``.
    actual_body: bytes = request.captured_body if request.captured_body is not None else b""

    if json is not None and json is not mock.ANY:
        # aiohttp sends json= as JSON-encoded bytes with application/json
        try:
            actual_json = json_module.loads(actual_body.decode(errors="replace"))
        except ValueError as exc:
            raise AssertionError(_diff("Expected JSON body, got non-JSON body", json, actual_body)) from exc
        if actual_json != json:
            raise AssertionError(_diff("JSON body mismatch", json, actual_json))
    elif data is not None and data is not mock.ANY:
        if isinstance(data, (str, bytes)):
            expected_body = data.encode() if isinstance(data, str) else data
            if actual_body != expected_body:
                raise AssertionError(_diff("Body mismatch", expected_body, actual_body))
        else:
            actual_ct = request.headers.get("Content-Type", "")
            if actual_ct and "application/x-www-form-urlencoded" not in actual_ct:
                raise AssertionError(
                    f"data=dict assertion requires Content-Type: "
                    f"application/x-www-form-urlencoded, got {actual_ct!r}. "
                    f"Use json= for JSON bodies."
                )
            # keep_blank_values: otherwise ``a=`` is silently dropped on both
            # sides and a missing empty field would compare equal.
            actual_qs = parse_qs(actual_body.decode(errors="replace"), keep_blank_values=True)
            expected_qs = parse_qs(urlencode(list(data.items())), keep_blank_values=True)
            if actual_qs != expected_qs:
                raise AssertionError(_diff("Form-encoded body mismatch", expected_qs, actual_qs))

    if strict_headers:
        # Drop the interceptor's internal scheme marker, matching its name
        # case-insensitively as HTTP header names are.
        actual_headers = {
            k: v for k, v in dict(request.headers).items() if k.lower() != "x-aiointercept-orig-scheme"
        }
        expected_headers = headers or {}
        # Expected on the left so mock.ANY (whole map or per value) compares equal.
        if expected_headers != actual_headers:
            raise AssertionError(_diff("Headers mismatch", dict(expected_headers), dict(actual_headers)))
    elif headers and headers is not mock.ANY:
        for name, expected_value in headers.items():
            actual_value = request.headers.get(name)
            if actual_value != expected_value:
                raise AssertionError(f"Header {name!r}: expected {expected_value!r}, got {actual_value!r}")
def assert_called_once_with(
    self,
    url: URL | str,
    method: str = hdrs.METH_GET,
    params: typing.Mapping[str, str] | None = None,
    data: str | bytes | typing.Mapping[str, Any] | None = None,
    json: Any = None,
    headers: typing.Mapping[str, str] | None = None,
    strict_headers: bool = False,
    **kwargs: Any,
) -> None:
    """Assert that exactly one request was made and it matched the given arguments.

    The count is taken across all URLs (see :meth:`assert_called_once`); the
    match itself is checked by :meth:`assert_called_with`. Extra *kwargs* are
    ignored apart from a DeprecationWarning.
    """
    self.assert_called_once()
    if kwargs:
        # Warn here instead of forwarding **kwargs: assert_called_with warns
        # with stacklevel=2, which would attribute the warning to this method
        # rather than to the caller's test code.
        warnings.warn(
            "Passing extra parameters to assert_called_with via kwargs is "
            "deprecated and will be removed in a future release.",
            DeprecationWarning,
            stacklevel=2,
        )
    self.assert_called_with(
        url,
        method=method,
        params=params,
        data=data,
        json=json,
        headers=headers,
        strict_headers=strict_headers,
    )