Multipart parser

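BáiZé provides a high-performance “bring your own I/O” multipart parser: the parser never reads from a socket or file itself. You feed it raw bytes with `receive_data()` and pull parse events with `next_event()`, so the same parser works in both synchronous and asynchronous code.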

Example

Synchronous example

import json
from itertools import chain
from typing import Any, Dict, Iterator, List, Mapping, Optional, Tuple, Union
from urllib.parse import parse_qsl

from baize import multipart
from baize.datastructures import (
    URL,
    Address,
    FormData,
    Headers,
    QueryParams,
    UploadFile,
)
from baize.exceptions import HTTPException
from baize.requests import MoreInfoFromHeaderMixin
from baize.typing import Environ
from baize.utils import cached_property


class HTTPConnection(Mapping[str, Any], MoreInfoFromHeaderMixin):
    """
    A base class for incoming HTTP connections.

    It is a valid Mapping type that allows you to directly
    access the values in any WSGI `environ` dictionary.
    """

    def __init__(self, environ: Environ) -> None:
        self._environ = environ
        # Set to True once the request body stream has started being read.
        self._stream_consumed = False

    def __getitem__(self, key: str) -> Any:
        return self._environ[key]

    def __iter__(self) -> Iterator[str]:
        return iter(self._environ)

    def __len__(self) -> int:
        return len(self._environ)

    def __eq__(self, other: Any) -> bool:
        # Two connections are equal when they wrap equal `environ` mappings.
        if not isinstance(other, self.__class__):
            return NotImplemented
        return self._environ == other._environ

    @cached_property
    def client(self) -> Address:
        """
        Client's IP and Port.

        Note that this depends on the `REMOTE_ADDR` and `REMOTE_PORT` values
        given by the WSGI Server, and is not necessarily accurate.
        If either value is missing or empty, both host and port are `None`.
        """
        if self.get("REMOTE_ADDR") and self.get("REMOTE_PORT"):
            return Address(self["REMOTE_ADDR"], int(self["REMOTE_PORT"]))
        return Address(host=None, port=None)

    @cached_property
    def url(self) -> URL:
        """
        The full URL of this request.
        """
        return URL(environ=self._environ)

    @cached_property
    def path_params(self) -> Dict[str, Any]:
        """
        The path parameters parsed by the framework.

        Read from the `PATH_PARAMS` key of `environ`; empty dict if absent.
        """
        return self.get("PATH_PARAMS", {})

    @cached_property
    def query_params(self) -> QueryParams:
        """
        Query parameter. It is a multi-value mapping.
        """
        # PEP 3333 allows `QUERY_STRING` to be empty *or absent*, so fall back
        # to an empty string instead of raising KeyError.
        return QueryParams(self.get("QUERY_STRING", ""))

    @cached_property
    def headers(self) -> Headers:
        """
        A read-only case-independent mapping.

        Note that in its internal storage, all keys are in lower case.
        """
        # `HTTP_*` keys carry request headers (prefix stripped below);
        # `CONTENT_TYPE` and `CONTENT_LENGTH` are stored without that prefix.
        return Headers(
            (key.lower().replace("_", "-"), value)
            for key, value in chain(
                (
                    (key[5:], value)
                    for key, value in self._environ.items()
                    if key.startswith("HTTP_")
                ),
                (
                    (key, value)
                    for key, value in self._environ.items()
                    if key in ("CONTENT_TYPE", "CONTENT_LENGTH")
                ),
            )
        )


class Request(HTTPConnection):
    """
    A synchronous (WSGI) HTTP request with helpers for reading the body.
    """

    @cached_property
    def method(self) -> str:
        """
        HTTP method. Uppercase string.
        """
        return self["REQUEST_METHOD"]

    def stream(self, chunk_size: int = 4096) -> Iterator[bytes]:
        """
        Streaming read request body. e.g. `for chunk in request.stream(): ...`

        If you access `.stream()` then the byte chunks are provided
        without storing the entire body to memory. Any subsequent
        calls to `.body`, `.form`, or `.json` will raise an error.

        :param chunk_size: maximum number of bytes requested per read.
        :raises RuntimeError: if the stream was already consumed and the
            body was not cached.
        """
        # If `.body` has already been computed, replay it from the cache.
        if "body" in self.__dict__:
            yield self.body
            return

        if self._stream_consumed:
            raise RuntimeError("Stream consumed")

        self._stream_consumed = True
        body = self._environ["wsgi.input"]
        while True:
            chunk = body.read(chunk_size)
            if not chunk:
                return
            yield chunk

    @cached_property
    def body(self) -> bytes:
        """
        Read all the contents of the request body into the memory and return it.
        """
        return b"".join(self.stream())

    @cached_property
    def json(self) -> Any:
        """
        Call `self.body` and use `json.loads` parse it.

        If `content_type` is not equal to `application/json`,
        an HTTPException exception will be thrown.
        """
        if self.content_type == "application/json":
            return json.loads(
                self.body.decode(self.content_type.options.get("charset", "utf8"))
            )

        # 415 Unsupported Media Type; `Accept` tells the client what is allowed.
        raise HTTPException(415, {"Accept": "application/json"})

    @cached_property
    def form(self) -> FormData:
        """
        Parse the data in the form format and return it as a multi-value mapping.

        If `content_type` is equal to `multipart/form-data`, it will directly
        perform streaming analysis, and subsequent calls to `self.body`
        or `self.json` will raise errors.

        If `content_type` is not equal to `multipart/form-data` or
        `application/x-www-form-urlencoded`, an HTTPException exception will be thrown.
        """
        if self.content_type == "multipart/form-data":
            charset = self.content_type.options.get("charset", "utf8")
            parser = multipart.MultipartDecoder(
                self.content_type.options["boundary"].encode("latin-1"), charset
            )
            # State of the part currently being parsed.
            field_name = ""
            data = bytearray()  # accumulates the value of a plain field
            file: Optional[UploadFile] = None  # set while inside a file part

            items: List[Tuple[str, Union[str, UploadFile]]] = []

            for chunk in self.stream():
                parser.receive_data(chunk)
                # Drain every event the parser can produce from this chunk.
                while True:
                    event = parser.next_event()
                    if isinstance(event, (multipart.Epilogue, multipart.NeedData)):
                        break
                    elif isinstance(event, multipart.Field):
                        field_name = event.name
                    elif isinstance(event, multipart.File):
                        field_name = event.name
                        file = UploadFile(
                            event.filename, event.headers.get("content-type", "")
                        )
                    elif isinstance(event, multipart.Data):
                        if file is None:
                            data.extend(event.data)
                        else:
                            file.write(event.data)

                        # Last data event of the part: emit the finished item.
                        if not event.more_data:
                            if file is None:
                                items.append(
                                    (field_name, multipart.safe_decode(data, charset))
                                )
                                data.clear()
                            else:
                                # Rewind so consumers read from the start.
                                file.seek(0)
                                items.append((field_name, file))
                                file = None

            return FormData(items)
        if self.content_type == "application/x-www-form-urlencoded":
            body = self.body.decode(
                encoding=self.content_type.options.get("charset", "latin-1")
            )
            return FormData(parse_qsl(body, keep_blank_values=True))

        # 415 Unsupported Media Type; `Accept` tells the client what is allowed.
        raise HTTPException(
            415, {"Accept": "multipart/form-data, application/x-www-form-urlencoded"}
        )

    def close(self) -> None:
        """
        Close all temporary files in the `self.form`.

        This can always be called, regardless of whether you use form or not.
        """
        # Only touch the form if it was already parsed and cached.
        if "form" in self.__dict__:
            self.form.close()

Asynchronous example

import asyncio
import json
from enum import Enum
from typing import (
    Any,
    AsyncIterator,
    Dict,
    Iterator,
    List,
    Mapping,
    Optional,
    Tuple,
    Union,
)
from urllib.parse import parse_qsl

from baize import multipart
from baize.datastructures import (
    URL,
    Address,
    FormData,
    Headers,
    QueryParams,
    UploadFile,
)
from baize.exceptions import HTTPException
from baize.requests import MoreInfoFromHeaderMixin
from baize.typing import Message, Receive, Scope, Send
from baize.utils import cached_property

from .helper import empty_receive, empty_send


class ClientDisconnect(Exception):
    """Raised when the HTTP client disconnects while the request is being read."""


class HTTPConnection(Mapping[str, Any], MoreInfoFromHeaderMixin):
    """
    A base class for incoming HTTP connections.

    It is a valid Mapping type that allows you to directly
    access the values in any ASGI `scope` dictionary.
    """

    def __init__(
        self, scope: Scope, receive: Receive = empty_receive, send: Send = empty_send
    ) -> None:
        self._scope = scope
        self._send = send
        self._receive = receive

    def __getitem__(self, key: str) -> Any:
        return self._scope[key]

    def __iter__(self) -> Iterator[str]:
        return iter(self._scope)

    def __len__(self) -> int:
        return len(self._scope)

    def __eq__(self, other: Any) -> bool:
        # Equal only when the scope and both channel callables all compare equal.
        if not isinstance(other, self.__class__):
            return NotImplemented
        return (
            self._scope == other._scope
            and self._send == other._send
            and self._receive == other._receive
        )

    @cached_property
    def client(self) -> Address:
        """
        Client's IP and Port.

        Note that this depends on the "client" value given by
        the ASGI Server, and is not necessarily accurate.
        If the value is missing or falsy, both host and port are `None`.
        """
        host, port = self.get("client") or (None, None)
        return Address(host=host, port=port)

    @cached_property
    def url(self) -> URL:
        """
        The full URL of this request.
        """
        return URL(scope=self._scope)

    @cached_property
    def path_params(self) -> Dict[str, Any]:
        """
        The path parameters parsed by the framework.

        Read from the `path_params` key of `scope`; empty dict if absent.
        """
        return self.get("path_params", {})

    @cached_property
    def query_params(self) -> QueryParams:
        """
        Query parameter. It is a multi-value mapping.
        """
        # The ASGI websocket scope makes `query_string` optional ("if missing
        # or None default is empty"), and this class is also the base of
        # `WebSocket`, so fall back to an empty byte string.
        return QueryParams(self.get("query_string") or b"")

    @cached_property
    def headers(self) -> Headers:
        """
        A read-only case-independent mapping.

        Note that in its internal storage, all keys are in lower case.
        """
        # Header names and values in `scope` are bytes; decode them as latin-1.
        return Headers(
            (key.decode("latin-1"), value.decode("latin-1"))
            for key, value in self._scope["headers"]
        )


class Request(HTTPConnection):
    """
    An asynchronous (ASGI) HTTP request with helpers for reading the body.
    """

    def __init__(
        self, scope: Scope, receive: Receive = empty_receive, send: Send = empty_send
    ) -> None:
        assert scope["type"] == "http"
        super().__init__(scope, receive, send)
        self._stream_consumed = False
        self._is_disconnected = False

    @cached_property
    def method(self) -> str:
        """
        HTTP method. Uppercase string.
        """
        return self._scope["method"]

    async def stream(self) -> AsyncIterator[bytes]:
        """
        Streaming read request body. e.g. `async for chunk in request.stream(): ...`

        If you access `.stream()` then the byte chunks are provided
        without storing the entire body to memory. Any subsequent
        calls to `.body`, `.form`, or `.json` will raise an error.

        The final item yielded is always an empty byte string.

        :raises RuntimeError: if the stream was already consumed and the
            body was not cached.
        :raises ClientDisconnect: if an `http.disconnect` message arrives
            before the body is complete.
        """
        # If `.body` has already been fully read, replay it from the cache.
        if "body" in self.__dict__ and self.__dict__["body"].done():
            yield await self.body
            yield b""
            return

        if self._stream_consumed:
            raise RuntimeError("Stream consumed")

        self._stream_consumed = True
        while True:
            message = await self._receive()
            if message["type"] == "http.request":
                body = message.get("body", b"")
                if body:
                    yield body
                if not message.get("more_body", False):
                    break
            elif message["type"] == "http.disconnect":
                self._is_disconnected = True
                raise ClientDisconnect()
        yield b""

    @cached_property
    async def body(self) -> bytes:
        """
        Read all the contents of the request body into the memory and return it.
        """
        return b"".join([chunk async for chunk in self.stream()])

    @cached_property
    async def json(self) -> Any:
        """
        Call `await self.body` and use `json.loads` parse it.

        If `content_type` is not equal to `application/json`,
        an HTTPException exception will be thrown.
        """
        if self.content_type == "application/json":
            data = await self.body
            return json.loads(
                data.decode(self.content_type.options.get("charset", "utf8"))
            )

        # 415 Unsupported Media Type; `Accept` tells the client what is allowed.
        raise HTTPException(415, {"Accept": "application/json"})

    @cached_property
    async def form(self) -> FormData:
        """
        Parse the data in the form format and return it as a multi-value mapping.

        If `content_type` is equal to `multipart/form-data`, it will directly
        perform streaming analysis, and subsequent calls to `self.body`
        or `self.json` will raise errors.

        If `content_type` is not equal to `multipart/form-data` or
        `application/x-www-form-urlencoded`, an HTTPException exception will be thrown.
        """
        if self.content_type == "multipart/form-data":
            charset = self.content_type.options.get("charset", "utf8")
            parser = multipart.MultipartDecoder(
                self.content_type.options["boundary"].encode("latin-1"), charset
            )
            # State of the part currently being parsed.
            field_name = ""
            data = bytearray()  # accumulates the value of a plain field
            file: Optional[UploadFile] = None  # set while inside a file part

            items: List[Tuple[str, Union[str, UploadFile]]] = []

            async for chunk in self.stream():
                parser.receive_data(chunk)
                # Drain every event the parser can produce from this chunk.
                while True:
                    event = parser.next_event()
                    if isinstance(event, (multipart.Epilogue, multipart.NeedData)):
                        break
                    elif isinstance(event, multipart.Field):
                        field_name = event.name
                    elif isinstance(event, multipart.File):
                        field_name = event.name
                        file = UploadFile(
                            event.filename, event.headers.get("content-type", "")
                        )
                    elif isinstance(event, multipart.Data):
                        if file is None:
                            data.extend(event.data)
                        else:
                            await file.awrite(event.data)

                        # Last data event of the part: emit the finished item.
                        if not event.more_data:
                            if file is None:
                                items.append(
                                    (field_name, multipart.safe_decode(data, charset))
                                )
                                data.clear()
                            else:
                                # Rewind so consumers read from the start.
                                await file.aseek(0)
                                items.append((field_name, file))
                                file = None

            return FormData(items)
        if self.content_type == "application/x-www-form-urlencoded":
            body = (await self.body).decode(
                encoding=self.content_type.options.get("charset", "latin-1")
            )
            return FormData(parse_qsl(body, keep_blank_values=True))

        # 415 Unsupported Media Type; `Accept` tells the client what is allowed.
        raise HTTPException(
            415, {"Accept": "multipart/form-data, application/x-www-form-urlencoded"}
        )

    async def close(self) -> None:
        """
        Close all temporary files in the `self.form`.

        This can always be called, regardless of whether you use form or not.
        """
        # Only touch the form if parsing was started and has finished.
        if "form" in self.__dict__ and self.__dict__["form"].done():
            await (await self.form).aclose()

    async def is_disconnected(self) -> bool:
        """
        The method used to determine whether the connection is interrupted.
        """
        if not self._is_disconnected:
            try:
                # Poll the receive channel with a near-zero timeout.
                # NOTE(review): any message returned here is consumed; a
                # non-disconnect message is not replayed to `stream()`.
                message = await asyncio.wait_for(self._receive(), timeout=0.0000001)
                self._is_disconnected = message.get("type") == "http.disconnect"
            except asyncio.TimeoutError:
                # Nothing pending: treat the connection as still alive.
                pass
        return self._is_disconnected


class WebSocketDisconnect(Exception):
    """
    Raised when a `websocket.disconnect` message is received.

    The close code is available as the `code` attribute.
    """

    def __init__(self, code: int = 1000) -> None:
        self.code: int = code


class WebSocketState(Enum):
    """
    Lifecycle state of one side (client or application) of a WebSocket.
    """

    # Handshake not yet completed.
    CONNECTING = 0
    # Handshake completed; frames may be exchanged.
    CONNECTED = 1
    # A disconnect/close message has been seen.
    DISCONNECTED = 2


class WebSocket(HTTPConnection):
    """
    An ASGI WebSocket connection that tracks state on both sides.

    `client_state` follows messages received from the client and
    `application_state` follows messages sent by the application.
    """

    def __init__(self, scope: Scope, receive: Receive, send: Send) -> None:
        assert scope["type"] == "websocket"
        super().__init__(scope, receive, send)
        self.client_state = WebSocketState.CONNECTING
        self.application_state = WebSocketState.CONNECTING

    async def receive(self) -> Message:
        """
        Receive ASGI websocket messages, ensuring valid state transitions.

        :raises RuntimeError: if called after a disconnect message was received.
        """
        if self.client_state == WebSocketState.CONNECTING:
            # The first message must be the connect handshake.
            message = await self._receive()
            message_type = message["type"]
            assert message_type == "websocket.connect"
            self.client_state = WebSocketState.CONNECTED
            return message
        elif self.client_state == WebSocketState.CONNECTED:
            message = await self._receive()
            message_type = message["type"]
            assert message_type in {"websocket.receive", "websocket.disconnect"}
            if message_type == "websocket.disconnect":
                self.client_state = WebSocketState.DISCONNECTED
            return message
        else:
            raise RuntimeError(
                'Cannot call "receive" once a disconnect message has been received.'
            )

    async def send(self, message: Message) -> None:
        """
        Send ASGI websocket messages, ensuring valid state transitions.

        :raises RuntimeError: if called after a close message was sent.
        """
        if self.application_state == WebSocketState.CONNECTING:
            # Before the handshake completes only accept/close are valid.
            message_type = message["type"]
            assert message_type in {"websocket.accept", "websocket.close"}
            if message_type == "websocket.close":
                self.application_state = WebSocketState.DISCONNECTED
            else:
                self.application_state = WebSocketState.CONNECTED
            await self._send(message)
        elif self.application_state == WebSocketState.CONNECTED:
            message_type = message["type"]
            assert message_type in {"websocket.send", "websocket.close"}
            if message_type == "websocket.close":
                self.application_state = WebSocketState.DISCONNECTED
            await self._send(message)
        else:
            raise RuntimeError('Cannot call "send" once a close message has been sent.')

    async def accept(self, subprotocol: Optional[str] = None) -> None:
        """
        Accept websocket connection.

        :param subprotocol: value sent as `subprotocol` in the accept
            message; `None` by default.
        """
        if self.client_state == WebSocketState.CONNECTING:
            # If we haven't yet seen the 'connect' message, then wait for it first.
            await self.receive()
        await self.send({"type": "websocket.accept", "subprotocol": subprotocol})

    def _raise_on_disconnect(self, message: Message) -> None:
        """
        Raise `WebSocketDisconnect` if `message` is a disconnect message.
        """
        if message["type"] == "websocket.disconnect":
            # Fall back to 1005 ("no status received") if the server omitted
            # the code, so callers get WebSocketDisconnect, not KeyError.
            raise WebSocketDisconnect(message.get("code", 1005))

    async def receive_text(self) -> str:
        """
        Receive a WebSocket text frame and return.

        :raises WebSocketDisconnect: if a disconnect message is received.
        """
        assert self.application_state == WebSocketState.CONNECTED
        message = await self.receive()
        self._raise_on_disconnect(message)
        return message["text"]

    async def receive_bytes(self) -> bytes:
        """
        Receive a WebSocket binary frame and return.

        :raises WebSocketDisconnect: if a disconnect message is received.
        """
        assert self.application_state == WebSocketState.CONNECTED
        message = await self.receive()
        self._raise_on_disconnect(message)
        return message["bytes"]

    async def iter_text(self) -> AsyncIterator[str]:
        """
        Keep receiving text frames until the WebSocket connection is disconnected.
        """
        try:
            while True:
                yield await self.receive_text()
        except WebSocketDisconnect:
            # A disconnect simply ends the iteration.
            pass

    async def iter_bytes(self) -> AsyncIterator[bytes]:
        """
        Keep receiving binary frames until the WebSocket connection is disconnected.
        """
        try:
            while True:
                yield await self.receive_bytes()
        except WebSocketDisconnect:
            # A disconnect simply ends the iteration.
            pass

    async def send_text(self, data: str) -> None:
        """
        Send a WebSocket text frame.
        """
        await self.send({"type": "websocket.send", "text": data})

    async def send_bytes(self, data: bytes) -> None:
        """
        Send a WebSocket binary frame.
        """
        await self.send({"type": "websocket.send", "bytes": data})

    async def close(self, code: int = 1000) -> None:
        """
        Close WebSocket connection. It can be called multiple times.

        :param code: close code sent in the `websocket.close` message.
        """
        # Skip sending once the application side is already closed.
        if self.application_state != WebSocketState.DISCONNECTED:
            await self.send({"type": "websocket.close", "code": code})