123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326 |
- import logging
- import time
- import traceback
- from urllib.parse import unquote
-
- from autobahn.twisted.websocket import (
- ConnectionDeny,
- WebSocketServerFactory,
- WebSocketServerProtocol,
- )
- from twisted.internet import defer
-
- from .utils import parse_x_forwarded_for
-
# Module-level logger shared by the WebSocket protocol and factory classes below.
logger = logging.getLogger(__name__)
-
-
class WebSocketProtocol(WebSocketServerProtocol):
    """
    Protocol which supports WebSockets and forwards incoming messages to
    the websocket channels.

    One instance serves one WebSocket connection. During the opening
    handshake it builds an ASGI ``websocket`` scope and asks the server to
    create an application instance. Incoming frames are then pushed onto that
    application's queue as ``websocket.receive`` / ``websocket.disconnect``
    events. Application replies are applied through ``handle_reply``
    (presumably invoked by the server; the caller is not visible here).
    """

    application_type = "websocket"

    # If we should send no more messages (e.g. we error-closed the socket)
    muted = False

    def onConnect(self, request):
        """
        Handle the opening handshake for *request*.

        Sanitises the request headers, works out client/server addresses
        (optionally from proxy headers), and asks the server to create an
        application instance for the resulting scope.

        Returns a Deferred that is fired later: ``serverAccept`` calls it
        back to accept the handshake. ``serverReject`` and
        ``handle_exception`` errback it with ``ConnectionDeny`` to refuse
        the connection with an HTTP error.

        Any exception raised while building the scope is logged with its
        traceback and then re-raised.
        """
        self.server = self.factory.server_class
        self.server.protocol_connected(self)
        self.request = request
        self.protocol_to_accept = None
        self.socket_opened = time.time()
        self.last_ping = time.time()
        # Assign defaults before anything below can raise. This keeps
        # onClose, the log helpers and __repr__ from failing with
        # AttributeError if the handshake dies part-way through.
        self.client_addr = None
        self.server_addr = None
        try:
            # Sanitize and decode headers
            self.clean_headers = []
            for name, value in request.headers.items():
                name = name.encode("ascii")
                # Prevent CVE-2015-0219: drop headers whose names contain
                # underscores, since they can be confused with hyphenated
                # names once normalised.
                if b"_" in name:
                    continue
                self.clean_headers.append((name.lower(), value.encode("latin1")))
            # Get client address if possible (only transports whose peer has
            # host/port, e.g. TCP, can supply one).
            peer = self.transport.getPeer()
            host = self.transport.getHost()
            if hasattr(peer, "host") and hasattr(peer, "port"):
                self.client_addr = [str(peer.host), peer.port]
                self.server_addr = [str(host.host), host.port]

            if self.server.proxy_forwarded_address_header:
                self.client_addr, self.client_scheme = parse_x_forwarded_for(
                    dict(self.clean_headers),
                    self.server.proxy_forwarded_address_header,
                    self.server.proxy_forwarded_port_header,
                    self.server.proxy_forwarded_proto_header,
                    self.client_addr,
                )
            # Decode websocket subprotocol options
            subprotocols = []
            for header, value in self.clean_headers:
                if header == b"sec-websocket-protocol":
                    subprotocols = [
                        x.strip() for x in unquote(value.decode("ascii")).split(",")
                    ]
            # Make new application instance with scope
            self.path = request.path.encode("ascii")
            self.application_deferred = defer.maybeDeferred(
                self.server.create_application,
                self,
                {
                    "type": "websocket",
                    "path": unquote(self.path.decode("ascii")),
                    "raw_path": self.path,
                    "headers": self.clean_headers,
                    "query_string": self._raw_query_string,  # Passed by HTTP protocol
                    "client": self.client_addr,
                    "server": self.server_addr,
                    "subprotocols": subprotocols,
                },
            )
            if self.application_deferred is not None:
                self.application_deferred.addCallback(self.applicationCreateWorked)
                self.application_deferred.addErrback(self.applicationCreateFailed)
        except Exception:
            # Exceptions here are not displayed right, just 500.
            # Turn them into an ERROR log.
            logger.error(traceback.format_exc())
            raise

        # Make a deferred and return it - we'll either call it or err it later on
        self.handshake_deferred = defer.Deferred()
        return self.handshake_deferred

    def applicationCreateWorked(self, application_queue):
        """
        Called when the background thread has successfully made the application
        instance.

        Stores the application's queue, sends it the ``websocket.connect``
        event, and logs the "connecting" action.
        """
        # Store the application's queue
        self.application_queue = application_queue
        # Send over the connect message
        self.application_queue.put_nowait({"type": "websocket.connect"})
        self._log_action("connecting")

    def applicationCreateFailed(self, failure):
        """
        Called when application creation fails.

        Logs *failure* and returns it so the errback chain keeps propagating it.
        """
        logger.error(failure)
        return failure

    ### Twisted event handling

    def onOpen(self):
        """Log that the handshake completed and the socket is now open."""
        # Send news that this channel is open
        logger.debug("WebSocket %s open and established", self.client_addr)
        self._log_action("connected")

    def onMessage(self, payload, isBinary):
        """
        Forward one incoming frame to the application queue.

        Binary frames are sent as ``bytes``. Text frames are decoded as UTF-8
        and sent as ``text``. Frames are dropped while the protocol is muted.
        Receiving a frame also resets the idle timer that ``check_timeouts``
        uses to decide when to ping.
        """
        # If we're muted, do nothing.
        if self.muted:
            logger.debug("Muting incoming frame on %s", self.client_addr)
            return
        logger.debug("WebSocket incoming frame on %s", self.client_addr)
        self.last_ping = time.time()
        if isBinary:
            self.application_queue.put_nowait(
                {"type": "websocket.receive", "bytes": payload}
            )
        else:
            self.application_queue.put_nowait(
                {"type": "websocket.receive", "text": payload.decode("utf8")}
            )

    def onClose(self, wasClean, code, reason):
        """
        Called when Twisted closes the socket.

        Tells the server the protocol disconnected. If the application exists
        and the protocol is not muted, sends it ``websocket.disconnect`` with
        the close *code*. Finally logs the "disconnected" action.
        """
        self.server.protocol_disconnected(self)
        logger.debug("WebSocket closed for %s", self.client_addr)
        if not self.muted and hasattr(self, "application_queue"):
            self.application_queue.put_nowait(
                {"type": "websocket.disconnect", "code": code}
            )
        self._log_action("disconnected")

    ### Internal event handling

    def handle_reply(self, message):
        """
        Apply one ASGI message from the application to this connection.

        Supported types are ``websocket.accept``, ``websocket.close`` (which
        rejects the handshake if the socket is not yet accepted) and
        ``websocket.send``. Other types are ignored.

        Raises:
            ValueError: if *message* has no ``type``; if ``websocket.send``
                arrives before the socket is accepted; or if ``websocket.send``
                carries both a non-None ``bytes`` and a non-None ``text``.
        """
        if "type" not in message:
            raise ValueError("Message has no type defined")
        if message["type"] == "websocket.accept":
            self.serverAccept(message.get("subprotocol", None))
        elif message["type"] == "websocket.close":
            if self.state == self.STATE_CONNECTING:
                self.serverReject()
            else:
                self.serverClose(code=message.get("code", None))
        elif message["type"] == "websocket.send":
            if self.state == self.STATE_CONNECTING:
                raise ValueError("Socket has not been accepted, so cannot send over it")
            # Compare against None, not truthiness: an empty text ("") or
            # binary (b"") payload is a valid frame and must still be sent.
            bytes_data = message.get("bytes", None)
            text_data = message.get("text", None)
            if bytes_data is not None and text_data is not None:
                raise ValueError(
                    "Got invalid WebSocket reply message on %s - contains both bytes and text keys"
                    % (message,)
                )
            if bytes_data is not None:
                self.serverSend(bytes_data, True)
            elif text_data is not None:
                self.serverSend(text_data, False)

    def handle_exception(self, exception):
        """
        Called by the server when our application tracebacks.

        During the handshake the connection is refused with HTTP 500.
        Afterwards a close frame with code 1011 is sent instead.
        """
        if hasattr(self, "handshake_deferred"):
            # If the handshake is still ongoing, we need to emit a HTTP error
            # code rather than a WebSocket one.
            self.handshake_deferred.errback(
                ConnectionDeny(code=500, reason="Internal server error")
            )
            # Drop the fired Deferred (as serverAccept/serverReject do) so a
            # later reply, timeout or second exception cannot fire it again.
            del self.handshake_deferred
        else:
            self.sendCloseFrame(code=1011)

    def serverAccept(self, subprotocol=None):
        """
        Called when we get a message saying to accept the connection.

        Fires the handshake Deferred with the chosen *subprotocol*.
        """
        self.handshake_deferred.callback(subprotocol)
        del self.handshake_deferred
        logger.debug("WebSocket %s accepted by application", self.client_addr)

    def serverReject(self):
        """
        Called when we get a message saying to reject the connection.

        Refuses the handshake with HTTP 403, marks the protocol as
        disconnected and logs the "rejected" action.
        """
        self.handshake_deferred.errback(
            ConnectionDeny(code=403, reason="Access denied")
        )
        del self.handshake_deferred
        self.server.protocol_disconnected(self)
        logger.debug("WebSocket %s rejected by application", self.client_addr)
        self._log_action("rejected")

    def serverSend(self, content, binary=False):
        """
        Server-side channel message to send a message.

        Accepts the handshake implicitly if still connecting. Text *content*
        is UTF-8 encoded before sending. Binary content is sent as-is.
        """
        if self.state == self.STATE_CONNECTING:
            self.serverAccept()
        logger.debug("Sent WebSocket packet to client for %s", self.client_addr)
        if binary:
            self.sendMessage(content, binary)
        else:
            self.sendMessage(content.encode("utf8"), binary)

    def serverClose(self, code=None):
        """
        Server-side channel message to close the socket.

        *code* defaults to 1000 (normal closure) when None.
        """
        code = 1000 if code is None else code
        self.sendClose(code=code)

    ### Utils

    def _log_action(self, action):
        """
        Report a lifecycle *action* via ``server.log_action`` with the request
        path and a "host:port" client string (None when the address is unknown).
        """
        self.server.log_action(
            "websocket",
            action,
            {
                "path": self.request.path,
                "client": "%s:%s" % tuple(self.client_addr)
                if self.client_addr
                else None,
            },
        )

    def duration(self):
        """
        Returns the time since the socket was opened
        """
        return time.time() - self.socket_opened

    def check_timeouts(self):
        """
        Called periodically to see if we should timeout something.

        - While the handshake is pending, the connection is rejected once it
          exceeds ``websocket_connect_timeout`` or the overall
          ``websocket_timeout``.
        - Once open, the socket is closed when it exceeds ``websocket_timeout``
          (a negative value disables this). Otherwise an auto-ping is sent if
          no frame has arrived and no ping has been sent for ``ping_interval``
          seconds.
        """
        duration = self.duration()
        lifetime_exceeded = (
            self.server.websocket_timeout >= 0
            and duration > self.server.websocket_timeout
        )
        if self.state == self.STATE_CONNECTING:
            # A close frame cannot be sent before the handshake completes.
            # NOTE(review): autobahn's sendClose refuses STATE_CONNECTING;
            # confirm for the installed version. So an expired connecting
            # socket is refused through the handshake instead. The hasattr
            # guard avoids re-firing a Deferred that was already consumed.
            if (
                duration > self.server.websocket_connect_timeout
                or lifetime_exceeded
            ) and hasattr(self, "handshake_deferred"):
                self.serverReject()
        elif self.state == self.STATE_OPEN:
            if lifetime_exceeded:
                self.serverClose()
            elif (time.time() - self.last_ping) > self.server.ping_interval:
                self._sendAutoPing()
                self.last_ping = time.time()

    def __hash__(self):
        return hash(id(self))

    def __eq__(self, other):
        return id(self) == id(other)

    def __repr__(self):
        # getattr: repr must not fail on a protocol whose onConnect never
        # got as far as assigning these attributes.
        client_addr = getattr(self, "client_addr", None)
        path = getattr(self, "path", None)
        return f"<WebSocketProtocol client={client_addr!r} path={path!r}>"
-
-
class WebSocketFactory(WebSocketServerFactory):
    """
    WebSocket server factory that carries a reference to the owning server.

    The server is kept on the factory as ``server_class`` so that every
    protocol built here can reach it through ``self.factory.server_class``.
    """

    protocol = WebSocketProtocol

    def __init__(self, server_class, *args, **kwargs):
        # Remember the owning server before autobahn's initialiser runs;
        # all remaining arguments are passed through untouched.
        self.server_class = server_class
        WebSocketServerFactory.__init__(self, *args, **kwargs)

    def buildProtocol(self, addr):
        """
        Build a protocol instance for *addr* and point its ``factory`` here.

        Failures are logged with their full traceback and then re-raised.
        """
        try:
            new_protocol = super().buildProtocol(addr)
            new_protocol.factory = self
            return new_protocol
        except Exception:
            logger.error("Cannot build protocol: %s", traceback.format_exc())
            raise
|