123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280 |
- from __future__ import annotations
-
- import io
- import socket
- import ssl
- import typing
-
- from ..exceptions import ProxySchemeUnsupported
-
- if typing.TYPE_CHECKING:
- from typing_extensions import Literal
-
- from .ssl_ import _TYPE_PEER_CERT_RET, _TYPE_PEER_CERT_RET_DICT
-
-
# Maximum number of ciphertext bytes requested from the wrapped socket in a
# single recv() call while the TLS engine is waiting for more input.
SSL_BLOCKSIZE = 16 * 1024

# Result type of whatever callable SSLTransport._ssl_io_loop drives.
_ReturnValue = typing.TypeVar("_ReturnValue")

# Writable buffer types accepted by SSLTransport.recv_into.
_WriteBuffer = typing.Union[bytearray, memoryview]

# Lets SSLTransport.__enter__ return the caller's own (sub)class type.
_SelfT = typing.TypeVar("_SelfT", bound="SSLTransport")
-
-
class SSLTransport:
    """
    The SSLTransport wraps an existing socket and establishes an SSL connection.

    Contrary to Python's implementation of SSLSocket, it allows you to chain
    multiple TLS connections together. It's particularly useful if you need to
    implement TLS within TLS.

    The class supports most of the socket API operations.

    Internally the TLS engine is an object created by ``SSLContext.wrap_bio``
    and attached to two ``ssl.MemoryBIO`` buffers. Ciphertext the engine
    produces is drained from ``outgoing`` and sent over the wrapped socket.
    Ciphertext read from the wrapped socket is fed into ``incoming``.
    """

    @staticmethod
    def _validate_ssl_context_for_tls_in_tls(ssl_context: ssl.SSLContext) -> None:
        """
        Raise ProxySchemeUnsupported if ``ssl_context`` can't be used for TLS in TLS.

        The only requirement is that the context provides the ``wrap_bio``
        method, which this class uses to attach the TLS engine to in-memory
        buffers instead of a real socket.
        """
        if not hasattr(ssl_context, "wrap_bio"):
            raise ProxySchemeUnsupported(
                "TLS in TLS requires SSLContext.wrap_bio() which isn't "
                "available on non-native SSLContext"
            )

    def __init__(
        self,
        socket: socket.socket,
        ssl_context: ssl.SSLContext,
        server_hostname: str | None = None,
        suppress_ragged_eofs: bool = True,
    ) -> None:
        """
        Create an SSLTransport around ``socket`` using ``ssl_context``.

        The TLS handshake is driven to completion before this returns. It
        writes to and reads from ``socket`` as often as the handshake requires.

        :param socket: Connected socket (or socket-like object) to run TLS over.
        :param ssl_context: Context whose ``wrap_bio`` creates the TLS engine.
        :param server_hostname: Forwarded unchanged to ``ssl_context.wrap_bio``.
        :param suppress_ragged_eofs: If true, an ``SSL_ERROR_EOF`` error during
            a read is reported as an empty read instead of being raised.
        """
        self.incoming = ssl.MemoryBIO()
        self.outgoing = ssl.MemoryBIO()

        self.suppress_ragged_eofs = suppress_ragged_eofs
        self.socket = socket

        self.sslobj = ssl_context.wrap_bio(
            self.incoming, self.outgoing, server_hostname=server_hostname
        )

        # Perform initial handshake.
        self._ssl_io_loop(self.sslobj.do_handshake)

    def __enter__(self: _SelfT) -> _SelfT:
        return self

    def __exit__(self, *_: typing.Any) -> None:
        self.close()

    def fileno(self) -> int:
        """Return the file descriptor of the wrapped socket."""
        return self.socket.fileno()

    def read(self, len: int = 1024, buffer: typing.Any | None = None) -> int | bytes:
        """
        Read up to ``len`` bytes of decrypted application data.

        If ``buffer`` is given, the data is read into it and the byte count is
        returned. Otherwise the data is returned as ``bytes``. A suppressed
        ragged EOF yields ``0`` or ``b""`` respectively.
        """
        return self._wrap_ssl_read(len, buffer)

    def recv(self, buflen: int = 1024, flags: int = 0) -> int | bytes:
        """
        Return up to ``buflen`` bytes of decrypted data.

        Returns ``b""`` on a suppressed ragged EOF. Only ``flags=0`` is
        supported.
        """
        if flags != 0:
            raise ValueError("non-zero flags not allowed in calls to recv")
        return self._wrap_ssl_read(buflen)

    def recv_into(
        self,
        buffer: _WriteBuffer,
        nbytes: int | None = None,
        flags: int = 0,
    ) -> None | int | bytes:
        """
        Read decrypted data into ``buffer`` and return the number of bytes read.

        At most ``nbytes`` bytes are read; the default is ``len(buffer)``.
        Only ``flags=0`` is supported.
        """
        if flags != 0:
            raise ValueError("non-zero flags not allowed in calls to recv_into")
        if nbytes is None:
            nbytes = len(buffer)
        return self.read(nbytes, buffer)

    def sendall(self, data: bytes, flags: int = 0) -> None:
        """
        Encrypt and send all of ``data``.

        Calls ``send`` repeatedly until the TLS engine has accepted every byte.
        Only ``flags=0`` is supported.
        """
        if flags != 0:
            raise ValueError("non-zero flags not allowed in calls to sendall")
        count = 0
        # Cast to an unsigned-byte view so slicing and len() work in bytes,
        # whatever the item format of the object passed in.
        with memoryview(data) as view, view.cast("B") as byte_view:
            amount = len(byte_view)
            while count < amount:
                v = self.send(byte_view[count:])
                count += v

    def send(self, data: bytes, flags: int = 0) -> int:
        """
        Encrypt ``data`` and flush the resulting ciphertext to the wrapped socket.

        Returns the number of plaintext bytes the TLS engine accepted.
        Only ``flags=0`` is supported.
        """
        if flags != 0:
            raise ValueError("non-zero flags not allowed in calls to send")
        return self._ssl_io_loop(self.sslobj.write, data)

    def makefile(
        self,
        mode: str,
        buffering: int | None = None,
        *,
        encoding: str | None = None,
        errors: str | None = None,
        newline: str | None = None,
    ) -> typing.BinaryIO | typing.TextIO | socket.SocketIO:
        """
        Python's httpclient uses makefile and buffered io when reading HTTP
        messages and we need to support it.

        This is unfortunately a copy and paste of socket.py makefile with small
        changes to point to the socket directly.
        """
        if not set(mode) <= {"r", "w", "b"}:
            raise ValueError(f"invalid mode {mode!r} (only r, w, b allowed)")

        writing = "w" in mode
        reading = "r" in mode or not writing
        assert reading or writing
        binary = "b" in mode
        rawmode = ""
        if reading:
            rawmode += "r"
        if writing:
            rawmode += "w"
        raw = socket.SocketIO(self, rawmode)  # type: ignore[arg-type]
        # The reference count lives on the wrapped socket; _decref_socketios
        # below forwards the matching decrement to it.
        self.socket._io_refs += 1  # type: ignore[attr-defined]
        if buffering is None:
            buffering = -1
        if buffering < 0:
            buffering = io.DEFAULT_BUFFER_SIZE
        if buffering == 0:
            if not binary:
                raise ValueError("unbuffered streams must be binary")
            return raw
        buffer: typing.BinaryIO
        if reading and writing:
            buffer = io.BufferedRWPair(raw, raw, buffering)  # type: ignore[assignment]
        elif reading:
            buffer = io.BufferedReader(raw, buffering)
        else:
            assert writing
            buffer = io.BufferedWriter(raw, buffering)
        if binary:
            return buffer
        text = io.TextIOWrapper(buffer, encoding, errors, newline)
        text.mode = mode  # type: ignore[misc]
        return text

    def unwrap(self) -> None:
        """
        Run the TLS shutdown via ``sslobj.unwrap``.

        Records are exchanged over the wrapped socket as needed. The wrapped
        socket itself is left open.
        """
        self._ssl_io_loop(self.sslobj.unwrap)

    def close(self) -> None:
        """
        Close the wrapped socket.

        No TLS shutdown is performed here; call ``unwrap()`` first for that.
        """
        self.socket.close()

    @typing.overload
    def getpeercert(
        self, binary_form: Literal[False] = ...
    ) -> _TYPE_PEER_CERT_RET_DICT | None:
        ...

    @typing.overload
    def getpeercert(self, binary_form: Literal[True]) -> bytes | None:
        ...

    def getpeercert(self, binary_form: bool = False) -> _TYPE_PEER_CERT_RET:
        """Return the peer certificate as reported by the TLS engine."""
        return self.sslobj.getpeercert(binary_form)  # type: ignore[return-value]

    def version(self) -> str | None:
        """Return the negotiated protocol version from the TLS engine."""
        return self.sslobj.version()

    def cipher(self) -> tuple[str, str, int] | None:
        """Return the negotiated cipher from the TLS engine."""
        return self.sslobj.cipher()

    def selected_alpn_protocol(self) -> str | None:
        """Return the ALPN protocol selected by the TLS engine, if any."""
        return self.sslobj.selected_alpn_protocol()

    def selected_npn_protocol(self) -> str | None:
        """Return the NPN protocol selected by the TLS engine, if any."""
        return self.sslobj.selected_npn_protocol()

    def shared_ciphers(self) -> list[tuple[str, str, int]] | None:
        """Return the ciphers shared with the peer, as reported by the engine."""
        return self.sslobj.shared_ciphers()

    def compression(self) -> str | None:
        """Return the compression algorithm in use, as reported by the engine."""
        return self.sslobj.compression()

    def settimeout(self, value: float | None) -> None:
        """Set the timeout on the wrapped socket."""
        self.socket.settimeout(value)

    def gettimeout(self) -> float | None:
        """Return the timeout of the wrapped socket."""
        return self.socket.gettimeout()

    def _decref_socketios(self) -> None:
        """Forward the SocketIO reference-count decrement to the wrapped socket."""
        self.socket._decref_socketios()  # type: ignore[attr-defined]

    def _wrap_ssl_read(self, len: int, buffer: bytearray | None = None) -> int | bytes:
        """
        Read via the TLS engine and optionally suppress ragged EOFs.

        Returns what ``sslobj.read`` returns. If it fails with
        ``SSL_ERROR_EOF`` and ``suppress_ragged_eofs`` is set, the result is
        "no data" of the type a normal read would have produced: ``0`` when
        filling ``buffer``, ``b""`` otherwise.
        """
        try:
            return self._ssl_io_loop(self.sslobj.read, len, buffer)
        except ssl.SSLError as e:
            if e.errno == ssl.SSL_ERROR_EOF and self.suppress_ragged_eofs:
                # Mirror ssl.SSLSocket.read(). Returning int 0 from a bytes
                # read (recv()/read() without a buffer) broke the bytes contract.
                return 0 if buffer is not None else b""
            raise

    # func is sslobj.do_handshake or sslobj.unwrap
    @typing.overload
    def _ssl_io_loop(self, func: typing.Callable[[], None]) -> None:
        ...

    # func is sslobj.write, arg1 is data
    @typing.overload
    def _ssl_io_loop(self, func: typing.Callable[[bytes], int], arg1: bytes) -> int:
        ...

    # func is sslobj.read, arg1 is len, arg2 is buffer
    @typing.overload
    def _ssl_io_loop(
        self,
        func: typing.Callable[[int, bytearray | None], bytes],
        arg1: int,
        arg2: bytearray | None,
    ) -> bytes:
        ...

    def _ssl_io_loop(
        self,
        func: typing.Callable[..., _ReturnValue],
        arg1: None | bytes | int = None,
        arg2: bytearray | None = None,
    ) -> _ReturnValue:
        """
        Call ``func`` on the TLS engine, shuttling ciphertext until it succeeds.

        After every attempt, any ciphertext queued in ``outgoing`` is flushed
        to the wrapped socket. The error code decides what happens next:

        * ``SSL_ERROR_WANT_READ``: up to ``SSL_BLOCKSIZE`` bytes are read from
          the socket into ``incoming`` (or EOF is signalled if the socket
          returned nothing) and ``func`` is retried.
        * ``SSL_ERROR_WANT_WRITE``: ``func`` is retried after the flush alone.
        * Any other ``ssl.SSLError`` propagates.

        ``func`` is called with no arguments when both ``arg1`` and ``arg2``
        are None, with ``arg1`` alone when only ``arg2`` is None, and with
        both otherwise.
        """
        ret = None

        while True:
            errno = None
            try:
                if arg1 is None and arg2 is None:
                    ret = func()
                elif arg2 is None:
                    ret = func(arg1)
                else:
                    ret = func(arg1, arg2)
            except ssl.SSLError as e:
                if e.errno not in (ssl.SSL_ERROR_WANT_READ, ssl.SSL_ERROR_WANT_WRITE):
                    # WANT_READ and WANT_WRITE are expected, others are not.
                    raise
                errno = e.errno

            # Flush whatever the engine produced, even if the call must be
            # retried. The peer may need these bytes before it can answer.
            buf = self.outgoing.read()
            self.socket.sendall(buf)

            if errno is None:
                break
            if errno == ssl.SSL_ERROR_WANT_READ:
                buf = self.socket.recv(SSL_BLOCKSIZE)
                if buf:
                    self.incoming.write(buf)
                else:
                    # The underlying transport is closed. Tell the engine, so
                    # the next attempt can report EOF rather than WANT_READ.
                    self.incoming.write_eof()
        return typing.cast(_ReturnValue, ret)
|