123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456 |
- """
- Common verification code.
- """
-
- from __future__ import annotations
-
- import ipaddress
- import re
-
- from typing import Protocol, Sequence, Union, runtime_checkable
-
- import attr
-
- from .exceptions import (
- CertificateError,
- DNSMismatch,
- IPAddressMismatch,
- Mismatch,
- SRVMismatch,
- URIMismatch,
- VerificationError,
- )
-
-
- try:
- import idna
- except ImportError:
- idna = None # type: ignore[assignment]
-
-
@attr.s(slots=True)
class ServiceMatch:
    """
    A successful pairing of a service ID with a certificate pattern.

    One instance is produced for every pattern that a service ID's
    ``verify`` method accepts.
    """

    #: The service ID that matched.
    service_id: ServiceID = attr.ib()
    #: The certificate pattern it was matched against.
    cert_pattern: CertificatePattern = attr.ib()
-
-
def verify_service_identity(
    cert_patterns: Sequence[CertificatePattern],
    obligatory_ids: Sequence[ServiceID],
    optional_ids: Sequence[ServiceID],
) -> list[ServiceMatch]:
    """
    Verify whether *cert_patterns* are valid for *obligatory_ids* and
    *optional_ids*.

    *obligatory_ids* must be both present and match. *optional_ids* must match
    if a pattern of the respective type is present.

    :return: Every service ID / certificate pattern pair that matched,
        obligatory IDs first.
    :raises VerificationError: With one ``error_on_mismatch`` instance per
        failed ID, if any obligatory ID is unmatched or an optional ID is
        unmatched although the certificate has a pattern of its type.
    """
    matches = [
        *_find_matches(cert_patterns, obligatory_ids),
        *_find_matches(cert_patterns, optional_ids),
    ]
    matched_ids = [match.service_id for match in matches]

    # Obligatory IDs must always be matched by some certificate pattern.
    errors = [
        sid.error_on_mismatch(mismatched_id=sid)
        for sid in obligatory_ids
        if sid not in matched_ids
    ]

    # An unmatched optional ID is only an error if the certificate contains a
    # pattern of the same type. Example: the user passes an SRV-ID for
    # "_mail.domain.com" but the certificate contains an SRV-Pattern for
    # "_xmpp.domain.com".
    errors.extend(
        sid.error_on_mismatch(mismatched_id=sid)
        for sid in optional_ids
        if sid not in matched_ids
        and _contains_instance_of(cert_patterns, sid.pattern_class)
    )

    if errors:
        raise VerificationError(errors=errors)

    return matches
-
-
def _find_matches(
    cert_patterns: Sequence[CertificatePattern],
    service_ids: Sequence[ServiceID],
) -> list[ServiceMatch]:
    """
    Pair every service ID with every certificate pattern it verifies against.

    :param cert_patterns: Patterns extracted from a certificate.
    :param service_ids: Service IDs such as :class:`DNS_ID`.

    :return: One :class:`ServiceMatch` per successful ``sid.verify(pattern)``,
        ordered by service ID first and certificate pattern second.
    """
    return [
        ServiceMatch(cert_pattern=pattern, service_id=sid)
        for sid in service_ids
        for pattern in cert_patterns
        if sid.verify(pattern)
    ]
-
-
def _contains_instance_of(seq: Sequence[object], cl: type) -> bool:
    """
    Return `True` if at least one element of *seq* is an instance of *cl*.
    """
    for element in seq:
        if isinstance(element, cl):
            return True
    return False
-
-
def _is_ip_address(pattern: str | bytes) -> bool:
    """
    Check whether *pattern* could be/match an IP address.

    :param pattern: A pattern for a host name. `bytes` are decoded as ASCII
        first; undecodable input is never considered an IP address.

    :return: `True` if *pattern* is a plain integer, or parses as an IPv4/IPv6
        address once every ``*`` is replaced by ``1``; else `False`.
    """
    if isinstance(pattern, bytes):
        try:
            text = pattern.decode("ascii")
        except UnicodeError:
            return False
    else:
        text = pattern

    # Anything int() accepts is treated as a possible (integer) address.
    try:
        int(text)
    except ValueError:
        pass
    else:
        return True

    # Substitute a digit for wildcards so e.g. "*.0.0.1" is still recognized.
    try:
        ipaddress.ip_address(text.replace("*", "1"))
    except ValueError:
        return False
    return True
-
-
@attr.s(slots=True)
class DNSPattern:
    """
    A DNS host-name pattern (possibly containing a wildcard) as extracted
    from a certificate.
    """

    #: The pattern. When built via :meth:`from_bytes` it is stripped and
    #: ASCII-lower-cased.
    pattern: bytes = attr.ib()

    # NOTE(review): not referenced by DNSPattern itself in this module;
    # possibly used elsewhere -- confirm before removing.
    _RE_LEGAL_CHARS = re.compile(rb"^[a-z0-9\-_.]+$")

    @classmethod
    def from_bytes(cls, pattern: bytes) -> DNSPattern:
        """
        Parse and validate a DNS pattern taken from a certificate.

        :raises TypeError: If *pattern* is not `bytes`.
        :raises CertificateError: If *pattern* is empty, looks like an IP
            address, contains a NUL byte, or uses a wildcard incorrectly.
        """
        if not isinstance(pattern, bytes):
            raise TypeError("The DNS pattern must be a bytes string.")

        pattern = pattern.strip()
        is_invalid = (
            not pattern or _is_ip_address(pattern) or b"\0" in pattern
        )
        if is_invalid:
            raise CertificateError(f"Invalid DNS pattern {pattern!r}.")

        pattern = pattern.translate(_TRANS_TO_LOWER)
        # Wildcard placement rules only apply when a wildcard is present.
        if b"*" in pattern:
            _validate_pattern(pattern)
        return cls(pattern=pattern)
-
-
@attr.s(slots=True)
class IPAddressPattern:
    """
    An IP address pattern as extracted from a certificate.
    """

    #: The address the certificate is valid for.
    pattern: ipaddress.IPv4Address | ipaddress.IPv6Address = attr.ib()

    @classmethod
    def from_bytes(cls, bs: bytes) -> IPAddressPattern:
        """
        Build a pattern from the raw value of an IP address entry.

        *bs* is handed to :func:`ipaddress.ip_address`, which interprets
        4-byte and 16-byte values as packed IPv4 and IPv6 addresses.

        :raises CertificateError: If *bs* cannot be parsed as an IP address.
        """
        try:
            address = ipaddress.ip_address(bs)
        except ValueError:
            raise CertificateError(
                f"Invalid IP address pattern {bs!r}."
            ) from None
        return cls(pattern=address)
-
-
@attr.s(slots=True)
class URIPattern:
    """
    An URI pattern as extracted from certificates.
    """

    #: The pattern for the protocol part (the text before the colon).
    protocol_pattern: bytes = attr.ib()
    #: The pattern for the DNS part (the text after the colon).
    dns_pattern: DNSPattern = attr.ib()

    @classmethod
    def from_bytes(cls, pattern: bytes) -> URIPattern:
        """
        Parse a ``protocol:hostname`` URI pattern taken from a certificate.

        The pattern is stripped and ASCII-lower-cased first. The part after
        the colon is parsed with :meth:`DNSPattern.from_bytes`.

        :raises TypeError: If *pattern* is not `bytes`.
        :raises CertificateError: If *pattern* does not contain exactly one
            colon, contains a wildcard, looks like an IP address, or has an
            invalid host part.
        """
        if not isinstance(pattern, bytes):
            raise TypeError("The URI pattern must be a bytes string.")

        pattern = pattern.strip().translate(_TRANS_TO_LOWER)

        # Exactly one colon is required. With more (e.g. an explicit port as
        # in b"sip:host:5060") the two-way unpacking below would otherwise
        # crash with a bare ValueError instead of a CertificateError.
        if (
            pattern.count(b":") != 1
            or b"*" in pattern
            or _is_ip_address(pattern)
        ):
            raise CertificateError(f"Invalid URI pattern {pattern!r}.")

        protocol_pattern, hostname = pattern.split(b":")

        return cls(
            protocol_pattern=protocol_pattern,
            dns_pattern=DNSPattern.from_bytes(hostname),
        )
-
-
@attr.s(slots=True)
class SRVPattern:
    """
    An SRV pattern as extracted from certificates.
    """

    #: The pattern for the name part (without the leading underscore).
    name_pattern: bytes = attr.ib()
    #: The pattern for the DNS part.
    dns_pattern: DNSPattern = attr.ib()

    @classmethod
    def from_bytes(cls, pattern: bytes) -> SRVPattern:
        """
        Parse an SRV pattern such as ``b"_xmpp.example.com"``.

        The pattern is stripped and ASCII-lower-cased first. The part after
        the first dot is parsed with :meth:`DNSPattern.from_bytes`.

        :raises TypeError: If *pattern* is not `bytes`.
        :raises CertificateError: If *pattern* is empty, does not start with
            ``_``, has no dot, contains a wildcard, looks like an IP address,
            or has an invalid host part.
        """
        if not isinstance(pattern, bytes):
            raise TypeError("The SRV pattern must be a bytes string.")

        pattern = pattern.strip().translate(_TRANS_TO_LOWER)

        # ``startswith`` rather than ``pattern[0]`` so that an empty (or
        # whitespace-only) pattern is reported as a CertificateError instead
        # of crashing with an IndexError.
        if (
            not pattern.startswith(b"_")
            or b"." not in pattern
            or b"*" in pattern
            or _is_ip_address(pattern)
        ):
            raise CertificateError(f"Invalid SRV pattern {pattern!r}.")

        name, hostname = pattern.split(b".", 1)
        return cls(
            name_pattern=name[1:], dns_pattern=DNSPattern.from_bytes(hostname)
        )
-
-
CertificatePattern = Union[
    SRVPattern,
    URIPattern,
    DNSPattern,
    IPAddressPattern,
]
"""
A :class:`Union` of every pattern type that can be extracted from a
certificate.
"""
-
-
@runtime_checkable
class ServiceID(Protocol):
    """
    Structural interface implemented by the service ID classes
    (:class:`DNS_ID`, :class:`IPAddress_ID`, :class:`URI_ID`,
    :class:`SRV_ID`).
    """

    @property
    def pattern_class(self) -> type[CertificatePattern]:
        """The certificate pattern type this ID can be verified against."""

    @property
    def error_on_mismatch(self) -> type[Mismatch]:
        """The mismatch type instantiated (with ``mismatched_id=``) on failure."""

    def verify(self, pattern: CertificatePattern) -> bool:
        """Return `True` if *pattern* matches this service ID."""
-
-
@attr.s(init=False, slots=True)
class DNS_ID:
    """
    A DNS service ID, aka hostname.
    """

    #: The normalized host name: IDNA-encoded if needed, ASCII-lower-cased.
    hostname: bytes = attr.ib()

    # Characters that are legal in a normalized hostname.
    _RE_LEGAL_CHARS = re.compile(rb"^[a-z0-9\-_.]+$")
    pattern_class = DNSPattern
    error_on_mismatch = DNSMismatch

    def __init__(self, hostname: str):
        """
        :param hostname: The host name to verify against. Non-ASCII names are
            converted with ``idna.encode``.

        :raises TypeError: If *hostname* is not a `str`.
        :raises ValueError: If *hostname* is empty, looks like an IP address,
            or contains characters outside ``[a-z0-9-_.]`` once normalized.
        :raises ImportError: If *hostname* is non-ASCII and the ``idna``
            package is not installed.
        """
        if not isinstance(hostname, str):
            raise TypeError("DNS-ID must be a text string.")

        hostname = hostname.strip()
        if not hostname or _is_ip_address(hostname):
            raise ValueError("Invalid DNS-ID.")

        if hostname.isascii():
            ascii_id = hostname.encode("ascii")
        elif idna is not None:
            ascii_id = idna.encode(hostname)
        else:
            raise ImportError("idna library is required for non-ASCII IDs.")

        self.hostname = ascii_id.translate(_TRANS_TO_LOWER)
        if not self._RE_LEGAL_CHARS.match(self.hostname):
            raise ValueError("Invalid DNS-ID.")

    def verify(self, pattern: CertificatePattern) -> bool:
        """
        Match against a :class:`DNSPattern`; any other pattern type fails.

        https://tools.ietf.org/search/rfc6125#section-6.4
        """
        if not isinstance(pattern, self.pattern_class):
            return False
        return _hostname_matches(pattern.pattern, self.hostname)
-
-
@attr.s(slots=True)
class IPAddress_ID:
    """
    An IP address service ID.
    """

    #: The address to verify; the constructor argument is converted with
    #: :func:`ipaddress.ip_address`.
    ip: ipaddress.IPv4Address | ipaddress.IPv6Address = attr.ib(
        converter=ipaddress.ip_address
    )

    pattern_class = IPAddressPattern
    error_on_mismatch = IPAddressMismatch

    def verify(self, pattern: CertificatePattern) -> bool:
        """
        Match against an :class:`IPAddressPattern` by address equality; any
        other pattern type fails.

        https://tools.ietf.org/search/rfc2818#section-3.1
        """
        if not isinstance(pattern, self.pattern_class):
            return False
        return self.ip == pattern.pattern
-
-
@attr.s(init=False, slots=True)
class URI_ID:
    """
    An URI service ID.
    """

    #: The ASCII-lower-cased protocol part (before the colon).
    protocol: bytes = attr.ib()
    #: The host part as a DNS-ID.
    dns_id: DNS_ID = attr.ib()

    pattern_class = URIPattern
    error_on_mismatch = URIMismatch

    def __init__(self, uri: str):
        """
        :param uri: A ``protocol:hostname`` string such as
            ``"sip:example.com"``. Slashes around the host part are stripped,
            so ``"https://example.com/"`` is accepted as well.

        :raises TypeError: If *uri* is not a `str`.
        :raises ValueError: If *uri* does not contain exactly one colon,
            looks like an IP address, or has an invalid protocol or host part.
        """
        if not isinstance(uri, str):
            raise TypeError("URI-ID must be a text string.")

        uri = uri.strip()
        # Exactly one colon is required. More (e.g. an explicit port such as
        # "sip:example.com:5060") would otherwise crash the two-way unpacking
        # below with a confusing "too many values to unpack" message.
        if uri.count(":") != 1 or _is_ip_address(uri):
            raise ValueError("Invalid URI-ID.")

        prot, hostname = uri.split(":")

        self.protocol = prot.encode("ascii").translate(_TRANS_TO_LOWER)
        self.dns_id = DNS_ID(hostname.strip("/"))

    def verify(self, pattern: CertificatePattern) -> bool:
        """
        Match against a :class:`URIPattern`: the protocols must be equal and
        the host part must match the pattern's DNS part.

        https://tools.ietf.org/search/rfc6125#section-6.5.2
        """
        if isinstance(pattern, self.pattern_class):
            return (
                pattern.protocol_pattern == self.protocol
                and self.dns_id.verify(pattern.dns_pattern)
            )

        return False
-
-
@attr.s(init=False, slots=True)
class SRV_ID:
    """
    An SRV service ID.
    """

    #: The service name without its leading underscore, ASCII-lower-cased.
    name: bytes = attr.ib()
    #: The domain part as a DNS-ID.
    dns_id: DNS_ID = attr.ib()

    pattern_class = SRVPattern
    error_on_mismatch = SRVMismatch

    def __init__(self, srv: str):
        """
        :param srv: An SRV-ID such as ``"_xmpp.example.com"``.

        :raises TypeError: If *srv* is not a `str`.
        :raises ValueError: If *srv* has no dot, looks like an IP address,
            does not start with ``_``, or has an invalid name or domain part.
        """
        if not isinstance(srv, str):
            raise TypeError("SRV-ID must be a text string.")

        srv = srv.strip()
        if "." not in srv or _is_ip_address(srv) or not srv.startswith("_"):
            raise ValueError("Invalid SRV-ID.")

        service, _, hostname = srv.partition(".")
        self.name = service[1:].encode("ascii").translate(_TRANS_TO_LOWER)
        self.dns_id = DNS_ID(hostname)

    def verify(self, pattern: CertificatePattern) -> bool:
        """
        Match against an :class:`SRVPattern`: the service names must be equal
        and the domain must match the pattern's DNS part.

        https://tools.ietf.org/search/rfc6125#section-6.5.1
        """
        if not isinstance(pattern, self.pattern_class):
            return False
        if self.name != pattern.name_pattern:
            return False
        return self.dns_id.verify(pattern.dns_pattern)
-
-
def _hostname_matches(cert_pattern: bytes, actual_hostname: bytes) -> bool:
    """
    Check whether a certificate DNS pattern matches a host name.

    Without a wildcard the comparison is exact. With a wildcard, only the
    left-most label may differ: the remaining labels must be identical, the
    wildcard label must be exactly ``*`` (or equal the host's label), and
    IDNA ("xn--") labels are never matched by a wildcard.

    Within this module both arguments come from :class:`DNSPattern` and
    :class:`DNS_ID`, which ASCII-lower-case them.

    :return: `True` if *cert_pattern* matches *actual_hostname*, else `False`.
    """
    if b"*" not in cert_pattern:
        return cert_pattern == actual_hostname

    # A wildcard covers exactly one label, so a dot-less name (e.g. the valid
    # DNS-ID "localhost") can never match. Checking first also keeps the
    # two-way unpacking below from crashing with a ValueError.
    if b"." not in cert_pattern or b"." not in actual_hostname:
        return False

    cert_head, cert_tail = cert_pattern.split(b".", 1)
    actual_head, actual_tail = actual_hostname.split(b".", 1)
    if cert_tail != actual_tail:
        return False
    # No patterns for IDNA
    if actual_head.startswith(b"xn--"):
        return False

    return cert_head == b"*" or cert_head == actual_head
-
-
def _validate_pattern(cert_pattern: bytes) -> None:
    """
    Ensure a wildcard DNS pattern from a certificate is well-formed.

    Rules are checked in this order: at most one wildcard, at least three
    labels, the wildcard sits in the left-most label, and no label is empty.

    :raises CertificateError: For the first rule *cert_pattern* violates.
    """
    if cert_pattern.count(b"*") > 1:
        raise CertificateError(
            f"Certificate's DNS-ID {cert_pattern!r} contains too many "
            "wildcards."
        )

    labels = cert_pattern.split(b".")
    if len(labels) < 3:
        raise CertificateError(
            f"Certificate's DNS-ID {cert_pattern!r} has too few host "
            "components for wildcard usage."
        )

    # At most one wildcard remains possible here, so checking the left-most
    # label is enough.
    if b"*" not in labels[0]:
        raise CertificateError(
            f"Certificate's DNS-ID {cert_pattern!r} has a wildcard outside "
            "the left-most part."
        )

    if b"" in labels:
        raise CertificateError(
            f"Certificate's DNS-ID {cert_pattern!r} contains empty parts."
        )
-
-
# Translation table that maps only the ASCII letters A-Z to a-z. Using an
# explicit table ensures no locale-dependent case mapping can interfere.
_TRANS_TO_LOWER = bytes.maketrans(
    bytes(range(ord("A"), ord("Z") + 1)),
    bytes(range(ord("a"), ord("z") + 1)),
)
|