|
12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058105910601061106210631064106510661067106810691070107110721073107410751076107710781079108010811082108310841085108610871088108910901091109210931094109510961097109810991100110111021103110411051106110711081109111011111112111311141115111611171118111911201121112211231124112511261127112811291130113111321133113411351136113711381139114011411142114311441145114611471148114911501151115211531154115511561157115811591160116111621163116411651166116711681169117011711172117311741175117611771178117911801181118211831184118511861187118811891190119111921193119411951196 |
- from __future__ import annotations
-
- import asyncio
- import email.utils
- import functools
- import http
- import inspect
- import logging
- import socket
- import warnings
- from types import TracebackType
- from typing import (
- Any,
- Awaitable,
- Callable,
- Generator,
- Iterable,
- List,
- Optional,
- Sequence,
- Set,
- Tuple,
- Type,
- Union,
- cast,
- )
-
- from ..datastructures import Headers, HeadersLike, MultipleValuesError
- from ..exceptions import (
- AbortHandshake,
- InvalidHandshake,
- InvalidHeader,
- InvalidMessage,
- InvalidOrigin,
- InvalidUpgrade,
- NegotiationError,
- )
- from ..extensions import Extension, ServerExtensionFactory
- from ..extensions.permessage_deflate import enable_server_permessage_deflate
- from ..headers import (
- build_extension,
- parse_extension,
- parse_subprotocol,
- validate_subprotocols,
- )
- from ..http import USER_AGENT
- from ..protocol import State
- from ..typing import ExtensionHeader, LoggerLike, Origin, Subprotocol
- from .compatibility import asyncio_timeout, loop_if_py_lt_38
- from .handshake import build_response, check_request
- from .http import read_request
- from .protocol import WebSocketCommonProtocol
-
-
# Public API of this module.
__all__ = [
    "serve",
    "unix_serve",
    "WebSocketServerProtocol",
    "WebSocketServer",
]


# Extra response headers: either a static collection of headers, or a callable
# receiving the request path and the request headers and returning headers.
HeadersLikeOrCallable = Union[HeadersLike, Callable[[str, Headers], HeadersLike]]

# A plain HTTP response as a (status, headers, body) 3-tuple. Returning one from
# process_request aborts the WebSocket handshake and sends it instead.
HTTPResponse = Tuple[http.HTTPStatus, HeadersLike, bytes]
-
-
class WebSocketServerProtocol(WebSocketCommonProtocol):
    """
    WebSocket server connection.

    :class:`WebSocketServerProtocol` provides :meth:`recv` and :meth:`send`
    coroutines for receiving and sending messages.

    It supports asynchronous iteration to receive messages::

        async for message in websocket:
            await process(message)

    The iterator exits normally when the connection is closed with close code
    1000 (OK) or 1001 (going away) or without a close code. It raises
    a :exc:`~websockets.exceptions.ConnectionClosedError` when the connection
    is closed with any other code.

    You may customize the opening handshake in a subclass by
    overriding :meth:`process_request` or :meth:`select_subprotocol`.

    Args:
        ws_server: WebSocket server that created this connection.

    See :func:`serve` for the documentation of ``ws_handler``, ``logger``, ``origins``,
    ``extensions``, ``subprotocols``, ``extra_headers``, and ``server_header``.

    See :class:`~websockets.legacy.protocol.WebSocketCommonProtocol` for the
    documentation of ``ping_interval``, ``ping_timeout``, ``close_timeout``,
    ``max_size``, ``max_queue``, ``read_limit``, and ``write_limit``.

    """

    is_client = False
    side = "server"

    def __init__(
        self,
        ws_handler: Union[
            Callable[[WebSocketServerProtocol], Awaitable[Any]],
            Callable[[WebSocketServerProtocol, str], Awaitable[Any]],  # deprecated
        ],
        ws_server: WebSocketServer,
        *,
        logger: Optional[LoggerLike] = None,
        origins: Optional[Sequence[Optional[Origin]]] = None,
        extensions: Optional[Sequence[ServerExtensionFactory]] = None,
        subprotocols: Optional[Sequence[Subprotocol]] = None,
        extra_headers: Optional[HeadersLikeOrCallable] = None,
        server_header: Optional[str] = USER_AGENT,
        process_request: Optional[
            Callable[[str, Headers], Awaitable[Optional[HTTPResponse]]]
        ] = None,
        # The callable may return None to continue without a subprotocol,
        # exactly like select_subprotocol() below.
        select_subprotocol: Optional[
            Callable[
                [Sequence[Subprotocol], Sequence[Subprotocol]],
                Optional[Subprotocol],
            ]
        ] = None,
        open_timeout: Optional[float] = 10,
        **kwargs: Any,
    ) -> None:
        if logger is None:
            logger = logging.getLogger("websockets.server")
        super().__init__(logger=logger, **kwargs)
        # For backwards compatibility with 6.0 or earlier.
        if origins is not None and "" in origins:
            warnings.warn("use None instead of '' in origins", DeprecationWarning)
            origins = [None if origin == "" else origin for origin in origins]
        # For backwards compatibility with 10.0 or earlier. Done here in
        # addition to serve to trigger the deprecation warning on direct
        # use of WebSocketServerProtocol.
        self.ws_handler = remove_path_argument(ws_handler)
        self.ws_server = ws_server
        self.origins = origins
        self.available_extensions = extensions
        self.available_subprotocols = subprotocols
        self.extra_headers = extra_headers
        self.server_header = server_header
        self._process_request = process_request
        self._select_subprotocol = select_subprotocol
        self.open_timeout = open_timeout

    def connection_made(self, transport: asyncio.BaseTransport) -> None:
        """
        Register connection and initialize a task to handle it.

        """
        super().connection_made(transport)
        # Register the connection with the server before creating the handler
        # task. Registering at the beginning of the handler coroutine would
        # create a race condition between the creation of the task, which
        # schedules its execution, and the moment the handler starts running.
        self.ws_server.register(self)
        self.handler_task = self.loop.create_task(self.handler())

    async def handler(self) -> None:
        """
        Handle the lifecycle of a WebSocket connection.

        Since this method doesn't have a caller able to handle exceptions, it
        attempts to log relevant ones and guarantees that the TCP connection is
        closed before exiting.

        """
        try:
            try:
                async with asyncio_timeout(self.open_timeout):
                    await self.handshake(
                        origins=self.origins,
                        available_extensions=self.available_extensions,
                        available_subprotocols=self.available_subprotocols,
                        extra_headers=self.extra_headers,
                    )
            # Remove this branch when dropping support for Python < 3.8
            # because CancelledError no longer inherits Exception.
            except asyncio.CancelledError:  # pragma: no cover
                raise
            except asyncio.TimeoutError:  # pragma: no cover
                raise
            except ConnectionError:
                raise
            except Exception as exc:
                # The handshake failed: translate the exception into a plain
                # HTTP error response, send it, and close the transport.
                if isinstance(exc, AbortHandshake):
                    status, headers, body = exc.status, exc.headers, exc.body
                elif isinstance(exc, InvalidOrigin):
                    if self.debug:
                        self.logger.debug("! invalid origin", exc_info=True)
                    status, headers, body = (
                        http.HTTPStatus.FORBIDDEN,
                        Headers(),
                        f"Failed to open a WebSocket connection: {exc}.\n".encode(),
                    )
                elif isinstance(exc, InvalidUpgrade):
                    if self.debug:
                        self.logger.debug("! invalid upgrade", exc_info=True)
                    status, headers, body = (
                        http.HTTPStatus.UPGRADE_REQUIRED,
                        Headers([("Upgrade", "websocket")]),
                        (
                            f"Failed to open a WebSocket connection: {exc}.\n"
                            f"\n"
                            f"You cannot access a WebSocket server directly "
                            f"with a browser. You need a WebSocket client.\n"
                        ).encode(),
                    )
                elif isinstance(exc, InvalidHandshake):
                    if self.debug:
                        self.logger.debug("! invalid handshake", exc_info=True)
                    status, headers, body = (
                        http.HTTPStatus.BAD_REQUEST,
                        Headers(),
                        f"Failed to open a WebSocket connection: {exc}.\n".encode(),
                    )
                else:
                    # Unexpected error: don't leak details to the client.
                    self.logger.error("opening handshake failed", exc_info=True)
                    status, headers, body = (
                        http.HTTPStatus.INTERNAL_SERVER_ERROR,
                        Headers(),
                        (
                            b"Failed to open a WebSocket connection.\n"
                            b"See server log for more information.\n"
                        ),
                    )

                headers.setdefault("Date", email.utils.formatdate(usegmt=True))
                if self.server_header is not None:
                    headers.setdefault("Server", self.server_header)

                headers.setdefault("Content-Length", str(len(body)))
                headers.setdefault("Content-Type", "text/plain")
                headers.setdefault("Connection", "close")

                self.write_http_response(status, headers, body)
                self.logger.info(
                    "connection failed (%d %s)", status.value, status.phrase
                )
                await self.close_transport()
                return

            try:
                await self.ws_handler(self)
            except Exception:
                self.logger.error("connection handler failed", exc_info=True)
                if not self.closed:
                    self.fail_connection(1011)
                raise

            try:
                await self.close()
            except ConnectionError:
                raise
            except Exception:
                self.logger.error("closing handshake failed", exc_info=True)
                raise

        except Exception:
            # Last-ditch attempt to avoid leaking connections on errors.
            try:
                self.transport.close()
            except Exception:  # pragma: no cover
                pass

        finally:
            # Unregister the connection with the server when the handler task
            # terminates. Registration is tied to the lifecycle of the handler
            # task because the server waits for tasks attached to registered
            # connections before terminating.
            self.ws_server.unregister(self)
            self.logger.info("connection closed")

    async def read_http_request(self) -> Tuple[str, Headers]:
        """
        Read request line and headers from the HTTP request.

        If the request contains a body, it may be read from ``self.reader``
        after this coroutine returns.

        The path and headers are also stored in ``self.path`` and
        ``self.request_headers``.

        Raises:
            InvalidMessage: if the HTTP message is malformed or isn't an
                HTTP/1.1 GET request.

        """
        try:
            path, headers = await read_request(self.reader)
        except asyncio.CancelledError:  # pragma: no cover
            raise
        except Exception as exc:
            raise InvalidMessage("did not receive a valid HTTP request") from exc

        if self.debug:
            self.logger.debug("< GET %s HTTP/1.1", path)
            for key, value in headers.raw_items():
                self.logger.debug("< %s: %s", key, value)

        self.path = path
        self.request_headers = headers

        return path, headers

    def write_http_response(
        self, status: http.HTTPStatus, headers: Headers, body: Optional[bytes] = None
    ) -> None:
        """
        Write status line and headers to the HTTP response.

        This method can also write a response body. The headers are stored in
        ``self.response_headers``.

        """
        self.response_headers = headers

        if self.debug:
            self.logger.debug("> HTTP/1.1 %d %s", status.value, status.phrase)
            for key, value in headers.raw_items():
                self.logger.debug("> %s: %s", key, value)
            if body is not None:
                self.logger.debug("> [body] (%d bytes)", len(body))

        # Since the status line and headers only contain ASCII characters,
        # we can keep this simple.
        response = f"HTTP/1.1 {status.value} {status.phrase}\r\n"
        response += str(headers)

        self.transport.write(response.encode())

        if body is not None:
            self.transport.write(body)

    async def process_request(
        self, path: str, request_headers: Headers
    ) -> Optional[HTTPResponse]:
        """
        Intercept the HTTP request and return an HTTP response if appropriate.

        You may override this method in a :class:`WebSocketServerProtocol`
        subclass, for example:

        * to return an HTTP 200 OK response on a given path; then a load
          balancer can use this path for a health check;
        * to authenticate the request and return an HTTP 401 Unauthorized or an
          HTTP 403 Forbidden when authentication fails.

        You may also override this method with the ``process_request``
        argument of :func:`serve` and :class:`WebSocketServerProtocol`. This
        is equivalent, except ``process_request`` won't have access to the
        protocol instance, so it can't store information for later use.

        :meth:`process_request` is expected to complete quickly. If it may run
        for a long time, then it should await :meth:`wait_closed` and exit if
        :meth:`wait_closed` completes, or else it could prevent the server
        from shutting down.

        Args:
            path: request path, including optional query string.
            request_headers: request headers.

        Returns:
            Optional[Tuple[http.HTTPStatus, HeadersLike, bytes]]: :obj:`None`
            to continue the WebSocket handshake normally.

            An HTTP response, represented by a 3-uple of the response status,
            headers, and body, to abort the WebSocket handshake and return
            that HTTP response instead.

        """
        if self._process_request is not None:
            response = self._process_request(path, request_headers)
            if isinstance(response, Awaitable):
                return await response
            else:
                # For backwards compatibility with 7.0.
                warnings.warn(
                    "declare process_request as a coroutine", DeprecationWarning
                )
                return response
        return None

    @staticmethod
    def process_origin(
        headers: Headers, origins: Optional[Sequence[Optional[Origin]]] = None
    ) -> Optional[Origin]:
        """
        Handle the Origin HTTP request header.

        Args:
            headers: request headers.
            origins: optional list of acceptable origins; :obj:`None` accepts
                any origin, including a missing one.

        Returns:
            Optional[Origin]: value of the Origin header, or :obj:`None` if the
            request doesn't contain one.

        Raises:
            InvalidHeader: if the request contains more than one Origin header.
            InvalidOrigin: if the origin isn't acceptable.

        """
        # "The user agent MUST NOT include more than one Origin header field"
        # per https://www.rfc-editor.org/rfc/rfc6454.html#section-7.3.
        try:
            origin = cast(Optional[Origin], headers.get("Origin"))
        except MultipleValuesError as exc:
            raise InvalidHeader("Origin", "more than one Origin header found") from exc
        if origins is not None:
            if origin not in origins:
                raise InvalidOrigin(origin)
        return origin

    @staticmethod
    def process_extensions(
        headers: Headers,
        available_extensions: Optional[Sequence[ServerExtensionFactory]],
    ) -> Tuple[Optional[str], List[Extension]]:
        """
        Handle the Sec-WebSocket-Extensions HTTP request header.

        Accept or reject each extension proposed in the client request.
        Negotiate parameters for accepted extensions.

        Return the Sec-WebSocket-Extensions HTTP response header and the list
        of accepted extensions.

        :rfc:`6455` leaves the rules up to the specification of each
        extension.

        To provide this level of flexibility, for each extension proposed by
        the client, we check for a match with each extension available in the
        server configuration. If no match is found, the extension is ignored.

        If several variants of the same extension are proposed by the client,
        it may be accepted several times, which won't make sense in general.
        Extensions must implement their own requirements. For this purpose,
        the list of previously accepted extensions is provided.

        This process doesn't allow the server to reorder extensions. It can
        only select a subset of the extensions proposed by the client.

        Other requirements, for example related to mandatory extensions or the
        order of extensions, may be implemented by overriding this method.

        Args:
            headers: request headers.
            available_extensions: optional list of supported extensions.

        Returns:
            Tuple[Optional[str], List[Extension]]: value of the
            Sec-WebSocket-Extensions response header, or :obj:`None` when no
            extension was accepted, and the list of accepted extensions.

        Raises:
            InvalidHandshake: to abort the handshake with an HTTP 400 error.

        """
        response_header_value: Optional[str] = None

        extension_headers: List[ExtensionHeader] = []
        accepted_extensions: List[Extension] = []

        header_values = headers.get_all("Sec-WebSocket-Extensions")

        if header_values and available_extensions:
            # Flatten the extensions of every header occurrence in one pass;
            # sum(lists, []) would copy the accumulated list at each step.
            parsed_header_values: List[ExtensionHeader] = [
                extension_header
                for header_value in header_values
                for extension_header in parse_extension(header_value)
            ]

            for name, request_params in parsed_header_values:
                for ext_factory in available_extensions:
                    # Skip non-matching extensions based on their name.
                    if ext_factory.name != name:
                        continue

                    # Skip non-matching extensions based on their params.
                    try:
                        response_params, extension = ext_factory.process_request_params(
                            request_params, accepted_extensions
                        )
                    except NegotiationError:
                        continue

                    # Add matching extension to the final list.
                    extension_headers.append((name, response_params))
                    accepted_extensions.append(extension)

                    # Break out of the loop once we have a match.
                    break

                # If we didn't break from the loop, no extension in our list
                # matched what the client sent. The extension is declined.

        # Serialize extension header.
        if extension_headers:
            response_header_value = build_extension(extension_headers)

        return response_header_value, accepted_extensions

    # Not @staticmethod because it calls self.select_subprotocol()
    def process_subprotocol(
        self, headers: Headers, available_subprotocols: Optional[Sequence[Subprotocol]]
    ) -> Optional[Subprotocol]:
        """
        Handle the Sec-WebSocket-Protocol HTTP request header.

        Return Sec-WebSocket-Protocol HTTP response header, which is the same
        as the selected subprotocol.

        Args:
            headers: request headers.
            available_subprotocols: optional list of supported subprotocols.

        Returns:
            Optional[Subprotocol]: selected subprotocol, or :obj:`None` when
            the client or the server doesn't offer any subprotocol, or when
            :meth:`select_subprotocol` doesn't pick one.

        Raises:
            InvalidHandshake: to abort the handshake with an HTTP 400 error.

        """
        subprotocol: Optional[Subprotocol] = None

        header_values = headers.get_all("Sec-WebSocket-Protocol")

        if header_values and available_subprotocols:
            # Flatten the subprotocols of every header occurrence in one pass;
            # sum(lists, []) would copy the accumulated list at each step.
            parsed_header_values: List[Subprotocol] = [
                parsed_subprotocol
                for header_value in header_values
                for parsed_subprotocol in parse_subprotocol(header_value)
            ]

            subprotocol = self.select_subprotocol(
                parsed_header_values, available_subprotocols
            )

        return subprotocol

    def select_subprotocol(
        self,
        client_subprotocols: Sequence[Subprotocol],
        server_subprotocols: Sequence[Subprotocol],
    ) -> Optional[Subprotocol]:
        """
        Pick a subprotocol among those supported by the client and the server.

        If several subprotocols are available, select the preferred subprotocol
        by giving equal weight to the preferences of the client and the server.

        If no subprotocol is available, proceed without a subprotocol.

        You may provide a ``select_subprotocol`` argument to :func:`serve` or
        :class:`WebSocketServerProtocol` to override this logic. For example,
        you could reject the handshake if the client doesn't support a
        particular subprotocol, rather than accept the handshake without that
        subprotocol.

        Args:
            client_subprotocols: list of subprotocols offered by the client.
            server_subprotocols: list of subprotocols available on the server.

        Returns:
            Optional[Subprotocol]: Selected subprotocol, if a common subprotocol
            was found.

            :obj:`None` to continue without a subprotocol.

        """
        if self._select_subprotocol is not None:
            return self._select_subprotocol(client_subprotocols, server_subprotocols)

        subprotocols = set(client_subprotocols) & set(server_subprotocols)
        if not subprotocols:
            return None
        # Lowest combined rank wins. min() returns the same element as
        # sorted(...)[0] (first minimum in iteration order) without sorting.
        return min(
            subprotocols,
            key=lambda p: client_subprotocols.index(p) + server_subprotocols.index(p),
        )

    async def handshake(
        self,
        origins: Optional[Sequence[Optional[Origin]]] = None,
        available_extensions: Optional[Sequence[ServerExtensionFactory]] = None,
        available_subprotocols: Optional[Sequence[Subprotocol]] = None,
        extra_headers: Optional[HeadersLikeOrCallable] = None,
    ) -> str:
        """
        Perform the server side of the opening handshake.

        Args:
            origins: list of acceptable values of the Origin HTTP header;
                include :obj:`None` if the lack of an origin is acceptable.
            available_extensions: list of supported extensions, in order in
                which they should be tried.
            available_subprotocols: list of supported subprotocols, in order
                of decreasing preference.
            extra_headers: arbitrary HTTP headers to add to the response when
                the handshake succeeds; may also be a callable receiving the
                request path and headers and returning the headers to add.

        Returns:
            str: path of the URI of the request.

        Raises:
            InvalidHandshake: if the handshake fails.
            BrokenPipeError: if the connection is closed while
                :meth:`process_request` is running.

        """
        path, request_headers = await self.read_http_request()

        # Hook for customizing request handling, for example checking
        # authentication or treating some paths as plain HTTP endpoints.
        early_response_awaitable = self.process_request(path, request_headers)
        if isinstance(early_response_awaitable, Awaitable):
            early_response = await early_response_awaitable
        else:
            # For backwards compatibility with 7.0.
            warnings.warn("declare process_request as a coroutine", DeprecationWarning)
            early_response = early_response_awaitable

        # The connection may drop while process_request is running.
        if self.state is State.CLOSED:
            # This subclass of ConnectionError is silently ignored in handler().
            raise BrokenPipeError("connection closed during opening handshake")

        # Change the response to a 503 error if the server is shutting down.
        if not self.ws_server.is_serving():
            early_response = (
                http.HTTPStatus.SERVICE_UNAVAILABLE,
                [],
                b"Server is shutting down.\n",
            )

        if early_response is not None:
            raise AbortHandshake(*early_response)

        key = check_request(request_headers)

        self.origin = self.process_origin(request_headers, origins)

        extensions_header, self.extensions = self.process_extensions(
            request_headers, available_extensions
        )

        protocol_header = self.subprotocol = self.process_subprotocol(
            request_headers, available_subprotocols
        )

        response_headers = Headers()

        build_response(response_headers, key)

        if extensions_header is not None:
            response_headers["Sec-WebSocket-Extensions"] = extensions_header

        if protocol_header is not None:
            response_headers["Sec-WebSocket-Protocol"] = protocol_header

        if callable(extra_headers):
            extra_headers = extra_headers(path, self.request_headers)
        if extra_headers is not None:
            response_headers.update(extra_headers)

        response_headers.setdefault("Date", email.utils.formatdate(usegmt=True))
        if self.server_header is not None:
            response_headers.setdefault("Server", self.server_header)

        self.write_http_response(http.HTTPStatus.SWITCHING_PROTOCOLS, response_headers)

        self.logger.info("connection open")

        self.connection_open()

        return path
-
-
class WebSocketServer:
    """
    WebSocket server returned by :func:`serve`.

    Its interface mirrors :class:`~asyncio.Server`, in particular the
    :meth:`~asyncio.Server.close` and :meth:`~asyncio.Server.wait_closed`
    methods.

    In addition, it tracks the WebSocket connections it accepted so that it can
    terminate them properly during shutdown.

    Args:
        logger: Logger for this server.
            It defaults to ``logging.getLogger("websockets.server")``.
            See the :doc:`logging guide <../../topics/logging>` for details.

    """

    def __init__(self, logger: Optional[LoggerLike] = None):
        self.logger = (
            logging.getLogger("websockets.server") if logger is None else logger
        )

        # Connections currently registered with this server.
        self.websockets: Set[WebSocketServerProtocol] = set()

        # Task running _close(); created by the first call to close().
        self.close_task: Optional[asyncio.Task[None]] = None

        # Resolved by _close() once the server and its connections are closed.
        # It's only declared here; wrap() creates it because it needs the loop.
        self.closed_waiter: asyncio.Future[None]

    def wrap(self, server: asyncio.base_events.Server) -> None:
        """
        Attach to an existing :class:`~asyncio.Server`.

        :meth:`~asyncio.loop.create_server` offers no way to plug in a custom
        ``Server`` class. The simplest workaround that avoids private
        :mod:`asyncio` APIs is:

        - create a :class:`WebSocketServer`;
        - hand a reference to it to the protocol factory;
        - call :meth:`~asyncio.loop.create_server` with that factory;
        - attach the resulting :class:`~asyncio.Server` with this method.

        """
        self.server = server

        for listener in server.sockets:
            family = listener.family
            address = listener.getsockname()
            if family == socket.AF_INET:
                label = "%s:%d" % address
            elif family == socket.AF_INET6:
                label = "[%s]:%d" % address[:2]
            elif family == socket.AF_UNIX:
                label = address
            # Running websockets over something other than IP or Unix
            # sockets is unlikely; fall back to a generic label anyway.
            else:  # pragma: no cover
                label = str(address)
            self.logger.info("server listening on %s", label)

        # Created here rather than in __init__ because it requires the event
        # loop. Move it back to __init__ when dropping Python < 3.10.
        self.closed_waiter = server.get_loop().create_future()

    def register(self, protocol: WebSocketServerProtocol) -> None:
        """
        Start tracking a connection accepted by this server.

        """
        self.websockets.add(protocol)

    def unregister(self, protocol: WebSocketServerProtocol) -> None:
        """
        Stop tracking a connection; it must currently be registered.

        """
        self.websockets.remove(protocol)

    def close(self, close_connections: bool = True) -> None:
        """
        Close the server.

        This schedules the shutdown in a background task and returns
        immediately; await :meth:`wait_closed` to wait for completion.
        The shutdown:

        * closes the underlying :class:`~asyncio.Server`;
        * when ``close_connections`` is :obj:`True`, which is the default,
          closes existing connections:

          * opening WebSocket connections, whose TCP connection was accepted
            but whose opening handshake wasn't complete yet, are rejected
            with an HTTP 503 (service unavailable) error;
          * open WebSocket connections are closed with close code 1001
            (going away);

        * waits until all connection handlers terminate.

        Calling :meth:`close` more than once has no further effect.

        """
        if self.close_task is not None:
            return
        self.close_task = self.get_loop().create_task(self._close(close_connections))

    async def _close(self, close_connections: bool) -> None:
        """
        Carry out the shutdown requested by :meth:`close`.

        It closes the underlying :class:`~asyncio.Server` so that no new
        connections are accepted. Then it closes open connections with close
        code 1001, if requested, and waits for all connection handlers.

        """
        self.logger.info("server closing")

        loop = self.get_loop()

        # Stop accepting new connections, then wait until that's effective.
        self.server.close()
        await self.server.wait_closed()

        # Yield once so that connections accepted just before closing reach
        # connection_made() and call register().
        # See https://bugs.python.org/issue34852 for details.
        await asyncio.sleep(0, **loop_if_py_lt_38(loop))

        if close_connections:
            # OPEN connections get close code 1001. Connections still in the
            # opening handshake are skipped: now that the server is closed,
            # handshake() rejects them with an HTTP 503 error.
            closing = []
            for websocket in self.websockets:
                if websocket.state is State.CONNECTING:
                    continue
                closing.append(asyncio.create_task(websocket.close(1001)))
            # asyncio.wait rejects an empty collection of tasks.
            if closing:
                await asyncio.wait(closing, **loop_if_py_lt_38(loop))

        # Wait for every connection handler to return. The same guard against
        # an empty collection applies.
        handler_tasks = [websocket.handler_task for websocket in self.websockets]
        if handler_tasks:
            await asyncio.wait(handler_tasks, **loop_if_py_lt_38(loop))

        # Release everyone blocked in wait_closed().
        self.closed_waiter.set_result(None)

        self.logger.info("server closed")

    async def wait_closed(self) -> None:
        """
        Wait until the server is closed.

        When this returns, every TCP connection is closed and every connection
        handler has returned.

        For a prompt shutdown, connection handlers should always be awaiting at
        least one of:

        * :meth:`~WebSocketServerProtocol.recv`, which raises
          :exc:`~websockets.exceptions.ConnectionClosedOK` once the connection
          is closed;
        * :meth:`~WebSocketServerProtocol.wait_closed`, which returns once the
          connection is closed.

        The handler then learns about the shutdown immediately and can clean up
        and exit.

        """
        # Shield the shared future so cancelling one waiter doesn't cancel it.
        await asyncio.shield(self.closed_waiter)

    def get_loop(self) -> asyncio.AbstractEventLoop:
        """
        Delegate to :meth:`asyncio.Server.get_loop`.

        """
        return self.server.get_loop()

    def is_serving(self) -> bool:
        """
        Delegate to :meth:`asyncio.Server.is_serving`.

        """
        return self.server.is_serving()

    async def start_serving(self) -> None:  # pragma: no cover
        """
        Delegate to :meth:`asyncio.Server.start_serving`.

        Typical use::

            server = await serve(..., start_serving=False)
            # perform additional setup here...
            # ... then start the server
            await server.start_serving()

        """
        await self.server.start_serving()

    async def serve_forever(self) -> None:  # pragma: no cover
        """
        Delegate to :meth:`asyncio.Server.serve_forever`.

        Typical use::

            server = await serve(...)
            # this coroutine doesn't return
            # canceling it stops the server
            await server.serve_forever()

        This is an alternative to using :func:`serve` as an asynchronous
        context manager: shutdown is triggered by canceling
        :meth:`serve_forever` rather than by exiting a :func:`serve` context.

        """
        await self.server.serve_forever()

    @property
    def sockets(self) -> Iterable[socket.socket]:
        """
        Delegate to :attr:`asyncio.Server.sockets`.

        """
        return self.server.sockets

    async def __aenter__(self) -> WebSocketServer:  # pragma: no cover
        return self

    async def __aexit__(
        self,
        exc_type: Optional[Type[BaseException]],
        exc_value: Optional[BaseException],
        traceback: Optional[TracebackType],
    ) -> None:  # pragma: no cover
        # Start the shutdown, then block until it completes.
        self.close()
        await self.wait_closed()
-
-
class Serve:
    """
    Start a WebSocket server listening on ``host`` and ``port``.

    Whenever a client connects, the server creates a
    :class:`WebSocketServerProtocol`, performs the opening handshake, and
    delegates to the connection handler, ``ws_handler``.

    The handler receives the :class:`WebSocketServerProtocol` and uses it to
    send and receive messages.

    Once the handler completes, either normally or with an exception, the
    server performs the closing handshake and closes the connection.

    Awaiting :func:`serve` yields a :class:`WebSocketServer`. To shut it
    down, call its :meth:`~WebSocketServer.close` method (it is not a
    coroutine), then await :meth:`~WebSocketServer.wait_closed`::

        stop = asyncio.Future()  # set this future to exit the server

        server = await serve(...)
        await stop
        server.close()
        await server.wait_closed()

    :func:`serve` can be used as an asynchronous context manager. Then, the
    server is shut down automatically when exiting the context::

        stop = asyncio.Future()  # set this future to exit the server

        async with serve(...):
            await stop

    Args:
        ws_handler: Connection handler. It receives the WebSocket connection,
            which is a :class:`WebSocketServerProtocol`, in argument.
        host: Network interfaces the server binds to.
            See :meth:`~asyncio.loop.create_server` for details.
        port: TCP port the server listens on.
            See :meth:`~asyncio.loop.create_server` for details.
        create_protocol: Factory for the :class:`asyncio.Protocol` managing
            the connection. It defaults to :class:`WebSocketServerProtocol`.
            Set it to a wrapper or a subclass to customize connection handling.
        logger: Logger for this server.
            It defaults to ``logging.getLogger("websockets.server")``.
            See the :doc:`logging guide <../../topics/logging>` for details.
        compression: The "permessage-deflate" extension is enabled by default.
            Set ``compression`` to :obj:`None` to disable it. See the
            :doc:`compression guide <../../topics/compression>` for details.
        origins: Acceptable values of the ``Origin`` header, for defending
            against Cross-Site WebSocket Hijacking attacks. Include :obj:`None`
            in the list if the lack of an origin is acceptable.
        extensions: List of supported extensions, in order in which they
            should be negotiated and run.
        subprotocols: List of supported subprotocols, in order of decreasing
            preference.
        extra_headers (Union[HeadersLike, Callable[[str, Headers], HeadersLike]]):
            Arbitrary HTTP headers to add to the response. This can be
            a :data:`~websockets.datastructures.HeadersLike` or a callable
            taking the request path and headers in arguments and returning
            a :data:`~websockets.datastructures.HeadersLike`.
        server_header: Value of the ``Server`` response header.
            It defaults to ``"Python/x.y.z websockets/X.Y"``.
            Setting it to :obj:`None` removes the header.
        process_request (Optional[Callable[[str, Headers], \
            Awaitable[Optional[Tuple[http.HTTPStatus, HeadersLike, bytes]]]]]):
            Intercept HTTP request before the opening handshake.
            See :meth:`~WebSocketServerProtocol.process_request` for details.
        select_subprotocol: Select a subprotocol supported by the client.
            See :meth:`~WebSocketServerProtocol.select_subprotocol` for details.
        open_timeout: Timeout for opening connections in seconds.
            :obj:`None` disables the timeout.

    See :class:`~websockets.legacy.protocol.WebSocketCommonProtocol` for the
    documentation of ``ping_interval``, ``ping_timeout``, ``close_timeout``,
    ``max_size``, ``max_queue``, ``read_limit``, and ``write_limit``.

    Any other keyword arguments are passed to the event loop's
    :meth:`~asyncio.loop.create_server` method.

    For example:

    * You can set ``ssl`` to a :class:`~ssl.SSLContext` to enable TLS.

    * You can set ``sock`` to a :obj:`~socket.socket` that you created
      outside of websockets.

    Returns:
        WebSocketServer: WebSocket server.

    Raises:
        ValueError: If ``compression`` is neither ``"deflate"`` nor
            :obj:`None`.

    """

    def __init__(
        self,
        ws_handler: Union[
            Callable[[WebSocketServerProtocol], Awaitable[Any]],
            Callable[[WebSocketServerProtocol, str], Awaitable[Any]],  # deprecated
        ],
        host: Optional[Union[str, Sequence[str]]] = None,
        port: Optional[int] = None,
        *,
        create_protocol: Optional[Callable[..., WebSocketServerProtocol]] = None,
        logger: Optional[LoggerLike] = None,
        compression: Optional[str] = "deflate",
        origins: Optional[Sequence[Optional[Origin]]] = None,
        extensions: Optional[Sequence[ServerExtensionFactory]] = None,
        subprotocols: Optional[Sequence[Subprotocol]] = None,
        extra_headers: Optional[HeadersLikeOrCallable] = None,
        server_header: Optional[str] = USER_AGENT,
        process_request: Optional[
            Callable[[str, Headers], Awaitable[Optional[HTTPResponse]]]
        ] = None,
        select_subprotocol: Optional[
            Callable[[Sequence[Subprotocol], Sequence[Subprotocol]], Subprotocol]
        ] = None,
        open_timeout: Optional[float] = 10,
        ping_interval: Optional[float] = 20,
        ping_timeout: Optional[float] = 20,
        close_timeout: Optional[float] = None,
        max_size: Optional[int] = 2**20,
        max_queue: Optional[int] = 2**5,
        read_limit: int = 2**16,
        write_limit: int = 2**16,
        **kwargs: Any,
    ) -> None:
        # Backwards compatibility: close_timeout used to be called timeout.
        timeout: Optional[float] = kwargs.pop("timeout", None)
        if timeout is None:
            timeout = 10
        else:
            warnings.warn("rename timeout to close_timeout", DeprecationWarning)
        # If both are specified, timeout is ignored.
        if close_timeout is None:
            close_timeout = timeout

        # Backwards compatibility: create_protocol used to be called klass.
        klass: Optional[Type[WebSocketServerProtocol]] = kwargs.pop("klass", None)
        if klass is None:
            klass = WebSocketServerProtocol
        else:
            warnings.warn("rename klass to create_protocol", DeprecationWarning)
        # If both are specified, klass is ignored.
        if create_protocol is None:
            create_protocol = klass

        # Backwards compatibility: recv() used to return None on closed connections
        legacy_recv: bool = kwargs.pop("legacy_recv", False)

        # Backwards compatibility: the loop parameter used to be supported.
        _loop: Optional[asyncio.AbstractEventLoop] = kwargs.pop("loop", None)
        if _loop is None:
            # NOTE(review): get_event_loop() rather than get_running_loop(),
            # presumably so serve() may still be constructed outside a running
            # loop (legacy run_until_complete usage) — confirm before changing.
            loop = asyncio.get_event_loop()
        else:
            loop = _loop
            warnings.warn("remove loop argument", DeprecationWarning)

        ws_server = WebSocketServer(logger=logger)

        # Forwarded to every protocol instance: whether an ``ssl`` context
        # was supplied for create_server().
        secure = kwargs.get("ssl") is not None

        if compression == "deflate":
            extensions = enable_server_permessage_deflate(extensions)
        elif compression is not None:
            raise ValueError(f"unsupported compression: {compression}")

        if subprotocols is not None:
            validate_subprotocols(subprotocols)

        # One protocol instance is built from this factory per connection.
        factory = functools.partial(
            create_protocol,
            # For backwards compatibility with 10.0 or earlier. Done here in
            # addition to WebSocketServerProtocol to trigger the deprecation
            # warning once per serve() call rather than once per connection.
            remove_path_argument(ws_handler),
            ws_server,
            host=host,
            port=port,
            secure=secure,
            open_timeout=open_timeout,
            ping_interval=ping_interval,
            ping_timeout=ping_timeout,
            close_timeout=close_timeout,
            max_size=max_size,
            max_queue=max_queue,
            read_limit=read_limit,
            write_limit=write_limit,
            loop=_loop,
            legacy_recv=legacy_recv,
            origins=origins,
            extensions=extensions,
            subprotocols=subprotocols,
            extra_headers=extra_headers,
            server_header=server_header,
            process_request=process_request,
            select_subprotocol=select_subprotocol,
            logger=logger,
        )

        # ``unix`` is set by unix_serve(); it switches to a Unix socket at ``path``.
        if kwargs.pop("unix", False):
            path: Optional[str] = kwargs.pop("path", None)
            # unix_serve(path) must not specify host and port parameters.
            # Raise explicitly instead of asserting so the check survives -O;
            # otherwise host and port would be silently ignored.
            if host is not None or port is not None:
                raise TypeError("unix_serve() doesn't accept host or port arguments")
            create_server = functools.partial(
                loop.create_unix_server, factory, path, **kwargs
            )
        else:
            create_server = functools.partial(
                loop.create_server, factory, host, port, **kwargs
            )

        # This is a coroutine function; it runs when the Serve is awaited.
        self._create_server = create_server
        self.ws_server = ws_server

    # async with serve(...)

    async def __aenter__(self) -> WebSocketServer:
        # Entering the context creates the server exactly like ``await serve(...)``.
        return await self

    async def __aexit__(
        self,
        exc_type: Optional[Type[BaseException]],
        exc_value: Optional[BaseException],
        traceback: Optional[TracebackType],
    ) -> None:
        # close() is synchronous; wait_closed() is awaited before leaving.
        self.ws_server.close()
        await self.ws_server.wait_closed()

    # await serve(...)

    def __await__(self) -> Generator[Any, None, WebSocketServer]:
        # Create a suitable iterator by calling __await__ on a coroutine.
        return self.__await_impl__().__await__()

    async def __await_impl__(self) -> WebSocketServer:
        # Create the asyncio server, attach it to the WebSocketServer wrapper,
        # and hand the wrapper back to the caller.
        server = await self._create_server()
        self.ws_server.wrap(server)
        return self.ws_server

    # yield from serve(...) - remove when dropping Python < 3.10

    __iter__ = __await__
-
-
# Public, function-style name for Serve: calling serve(...) returns a Serve
# instance, which can be awaited (``await serve(...)``) or used as an
# asynchronous context manager (``async with serve(...)``).
serve = Serve
-
-
def unix_serve(
    ws_handler: Union[
        Callable[[WebSocketServerProtocol], Awaitable[Any]],
        Callable[[WebSocketServerProtocol, str], Awaitable[Any]],  # deprecated
    ],
    path: Optional[str] = None,
    **kwargs: Any,
) -> Serve:
    """
    Start a WebSocket server listening on a Unix socket.

    This function is identical to :func:`serve`, except the ``host`` and
    ``port`` arguments are replaced by ``path``. It is only available on Unix.

    Unrecognized keyword arguments are passed to the event loop's
    :meth:`~asyncio.loop.create_unix_server` method.

    It's useful for deploying a server behind a reverse proxy such as nginx.

    Args:
        path: File system path to the Unix socket.

    Returns:
        Serve: the same awaitable / asynchronous context manager that
        :func:`serve` returns, configured to listen on ``path``. ``host`` and
        ``port`` must not be passed.

    """
    # ``unix=True`` makes Serve bind with loop.create_unix_server() on
    # ``path`` instead of loop.create_server() on host and port.
    return serve(ws_handler, path=path, unix=True, **kwargs)
-
-
def remove_path_argument(
    ws_handler: Union[
        Callable[[WebSocketServerProtocol], Awaitable[Any]],
        Callable[[WebSocketServerProtocol, str], Awaitable[Any]],
    ]
) -> Callable[[WebSocketServerProtocol], Awaitable[Any]]:
    """
    Adapt a legacy ``(websocket, path)`` connection handler to ``(websocket)``.

    If ``ws_handler`` can't be called with one positional argument but can be
    called with two, return a coroutine function that calls it with
    ``websocket`` and ``websocket.path``. In every other case, including when
    its signature can't be introspected, return ``ws_handler`` unchanged.

    Args:
        ws_handler: Connection handler taking ``(websocket)`` or the
            deprecated ``(websocket, path)``.

    Returns:
        A connection handler taking only ``(websocket)``.

    """
    try:
        signature = inspect.signature(ws_handler)
    except (TypeError, ValueError):
        # TypeError: not callable; ValueError: no signature available (e.g.
        # some C-implemented callables). We can't tell which form ws_handler
        # expects, so leave it alone rather than failing inside serve().
        return cast(
            "Callable[[WebSocketServerProtocol], Awaitable[Any]]", ws_handler
        )

    def accepts(*args: Any) -> bool:
        # Whether ws_handler's signature can bind these positional arguments.
        try:
            signature.bind(*args)
        except TypeError:
            return False
        return True

    if accepts(None) or not accepts(None, ""):
        # ws_handler takes one argument (nothing to adapt), or accepts neither
        # one nor two arguments; leave it alone.
        return cast(
            "Callable[[WebSocketServerProtocol], Awaitable[Any]]", ws_handler
        )

    # ws_handler accepts two arguments; activate backwards compatibility.

    # Enable deprecation warning and announce deprecation in 11.0.
    # warnings.warn("remove second argument of ws_handler", DeprecationWarning)

    # String type in cast(): only type checkers evaluate it.
    legacy_handler = cast(
        "Callable[[WebSocketServerProtocol, str], Awaitable[Any]]", ws_handler
    )

    async def _ws_handler(websocket: WebSocketServerProtocol) -> Any:
        return await legacy_handler(websocket, websocket.path)

    return _ws_handler
|