""" Common verification code. """ from __future__ import annotations import ipaddress import re from typing import Protocol, Sequence, Union, runtime_checkable import attr from .exceptions import ( CertificateError, DNSMismatch, IPAddressMismatch, Mismatch, SRVMismatch, URIMismatch, VerificationError, ) try: import idna except ImportError: idna = None # type: ignore[assignment] @attr.s(slots=True) class ServiceMatch: """ A match of a service id and a certificate pattern. """ service_id: ServiceID = attr.ib() cert_pattern: CertificatePattern = attr.ib() def verify_service_identity( cert_patterns: Sequence[CertificatePattern], obligatory_ids: Sequence[ServiceID], optional_ids: Sequence[ServiceID], ) -> list[ServiceMatch]: """ Verify whether *cert_patterns* are valid for *obligatory_ids* and *optional_ids*. *obligatory_ids* must be both present and match. *optional_ids* must match if a pattern of the respective type is present. """ errors = [] matches = _find_matches(cert_patterns, obligatory_ids) + _find_matches( cert_patterns, optional_ids ) matched_ids = [match.service_id for match in matches] for i in obligatory_ids: if i not in matched_ids: errors.append(i.error_on_mismatch(mismatched_id=i)) for i in optional_ids: # If an optional ID is not matched by a certificate pattern *but* there # is a pattern of the same type , it is an error and the verification # fails. Example: the user passes a SRV-ID for "_mail.domain.com" but # the certificate contains an SRV-Pattern for "_xmpp.domain.com". if i not in matched_ids and _contains_instance_of( cert_patterns, i.pattern_class ): errors.append(i.error_on_mismatch(mismatched_id=i)) if errors: raise VerificationError(errors=errors) return matches def _find_matches( cert_patterns: Sequence[CertificatePattern], service_ids: Sequence[ServiceID], ) -> list[ServiceMatch]: """ Search for matching certificate patterns and service_ids. :param service_ids: List of service IDs like DNS_ID. :type service_ids: `list` """ matches = [] for sid in service_ids: for cid in cert_patterns: if sid.verify(cid): matches.append(ServiceMatch(cert_pattern=cid, service_id=sid)) return matches def _contains_instance_of(seq: Sequence[object], cl: type) -> bool: return any(isinstance(e, cl) for e in seq) def _is_ip_address(pattern: str | bytes) -> bool: """ Check whether *pattern* could be/match an IP address. :param pattern: A pattern for a host name. :return: `True` if *pattern* could be an IP address, else `False`. """ if isinstance(pattern, bytes): try: pattern = pattern.decode("ascii") except UnicodeError: return False try: int(pattern) return True except ValueError: pass try: ipaddress.ip_address(pattern.replace("*", "1")) except ValueError: return False return True @attr.s(slots=True) class DNSPattern: """ A DNS pattern as extracted from certificates. """ #: The pattern. pattern: bytes = attr.ib() _RE_LEGAL_CHARS = re.compile(rb"^[a-z0-9\-_.]+$") @classmethod def from_bytes(cls, pattern: bytes) -> DNSPattern: if not isinstance(pattern, bytes): raise TypeError("The DNS pattern must be a bytes string.") pattern = pattern.strip() if pattern == b"" or _is_ip_address(pattern) or b"\0" in pattern: raise CertificateError(f"Invalid DNS pattern {pattern!r}.") pattern = pattern.translate(_TRANS_TO_LOWER) if b"*" in pattern: _validate_pattern(pattern) return cls(pattern=pattern) @attr.s(slots=True) class IPAddressPattern: """ An IP address pattern as extracted from certificates. """ #: The pattern. pattern: ipaddress.IPv4Address | ipaddress.IPv6Address = attr.ib() @classmethod def from_bytes(cls, bs: bytes) -> IPAddressPattern: try: return cls(pattern=ipaddress.ip_address(bs)) except ValueError: raise CertificateError( f"Invalid IP address pattern {bs!r}." ) from None @attr.s(slots=True) class URIPattern: """ An URI pattern as extracted from certificates. """ #: The pattern for the protocol part. protocol_pattern: bytes = attr.ib() #: The pattern for the DNS part. dns_pattern: DNSPattern = attr.ib() @classmethod def from_bytes(cls, pattern: bytes) -> URIPattern: if not isinstance(pattern, bytes): raise TypeError("The URI pattern must be a bytes string.") pattern = pattern.strip().translate(_TRANS_TO_LOWER) if b":" not in pattern or b"*" in pattern or _is_ip_address(pattern): raise CertificateError(f"Invalid URI pattern {pattern!r}.") protocol_pattern, hostname = pattern.split(b":") return cls( protocol_pattern=protocol_pattern, dns_pattern=DNSPattern.from_bytes(hostname), ) @attr.s(slots=True) class SRVPattern: """ An SRV pattern as extracted from certificates. """ #: The pattern for the name part. name_pattern: bytes = attr.ib() #: The pattern for the DNS part. dns_pattern: DNSPattern = attr.ib() @classmethod def from_bytes(cls, pattern: bytes) -> SRVPattern: if not isinstance(pattern, bytes): raise TypeError("The SRV pattern must be a bytes string.") pattern = pattern.strip().translate(_TRANS_TO_LOWER) if ( pattern[0] != b"_"[0] or b"." not in pattern or b"*" in pattern or _is_ip_address(pattern) ): raise CertificateError(f"Invalid SRV pattern {pattern!r}.") name, hostname = pattern.split(b".", 1) return cls( name_pattern=name[1:], dns_pattern=DNSPattern.from_bytes(hostname) ) CertificatePattern = Union[ SRVPattern, URIPattern, DNSPattern, IPAddressPattern ] """ A :class:`Union` of all possible patterns that can be extracted from a certificate. """ @runtime_checkable class ServiceID(Protocol): @property def pattern_class(self) -> type[CertificatePattern]: ... @property def error_on_mismatch(self) -> type[Mismatch]: ... def verify(self, pattern: CertificatePattern) -> bool: ... @attr.s(init=False, slots=True) class DNS_ID: """ A DNS service ID, aka hostname. """ hostname: bytes = attr.ib() # characters that are legal in a normalized hostname _RE_LEGAL_CHARS = re.compile(rb"^[a-z0-9\-_.]+$") pattern_class = DNSPattern error_on_mismatch = DNSMismatch def __init__(self, hostname: str): if not isinstance(hostname, str): raise TypeError("DNS-ID must be a text string.") hostname = hostname.strip() if not hostname or _is_ip_address(hostname): raise ValueError("Invalid DNS-ID.") if any(ord(c) > 127 for c in hostname): if idna: ascii_id = idna.encode(hostname) else: raise ImportError( "idna library is required for non-ASCII IDs." ) else: ascii_id = hostname.encode("ascii") self.hostname = ascii_id.translate(_TRANS_TO_LOWER) if self._RE_LEGAL_CHARS.match(self.hostname) is None: raise ValueError("Invalid DNS-ID.") def verify(self, pattern: CertificatePattern) -> bool: """ https://tools.ietf.org/search/rfc6125#section-6.4 """ if isinstance(pattern, self.pattern_class): return _hostname_matches(pattern.pattern, self.hostname) return False @attr.s(slots=True) class IPAddress_ID: """ An IP address service ID. """ ip: ipaddress.IPv4Address | ipaddress.IPv6Address = attr.ib( converter=ipaddress.ip_address ) pattern_class = IPAddressPattern error_on_mismatch = IPAddressMismatch def verify(self, pattern: CertificatePattern) -> bool: """ https://tools.ietf.org/search/rfc2818#section-3.1 """ if isinstance(pattern, self.pattern_class): return self.ip == pattern.pattern return False @attr.s(init=False, slots=True) class URI_ID: """ An URI service ID. """ protocol: bytes = attr.ib() dns_id: DNS_ID = attr.ib() pattern_class = URIPattern error_on_mismatch = URIMismatch def __init__(self, uri: str): if not isinstance(uri, str): raise TypeError("URI-ID must be a text string.") uri = uri.strip() if ":" not in uri or _is_ip_address(uri): raise ValueError("Invalid URI-ID.") prot, hostname = uri.split(":") self.protocol = prot.encode("ascii").translate(_TRANS_TO_LOWER) self.dns_id = DNS_ID(hostname.strip("/")) def verify(self, pattern: CertificatePattern) -> bool: """ https://tools.ietf.org/search/rfc6125#section-6.5.2 """ if isinstance(pattern, self.pattern_class): return ( pattern.protocol_pattern == self.protocol and self.dns_id.verify(pattern.dns_pattern) ) return False @attr.s(init=False, slots=True) class SRV_ID: """ An SRV service ID. """ name: bytes = attr.ib() dns_id: DNS_ID = attr.ib() pattern_class = SRVPattern error_on_mismatch = SRVMismatch def __init__(self, srv: str): if not isinstance(srv, str): raise TypeError("SRV-ID must be a text string.") srv = srv.strip() if "." not in srv or _is_ip_address(srv) or srv[0] != "_": raise ValueError("Invalid SRV-ID.") name, hostname = srv.split(".", 1) self.name = name[1:].encode("ascii").translate(_TRANS_TO_LOWER) self.dns_id = DNS_ID(hostname) def verify(self, pattern: CertificatePattern) -> bool: """ https://tools.ietf.org/search/rfc6125#section-6.5.1 """ if isinstance(pattern, self.pattern_class): return self.name == pattern.name_pattern and self.dns_id.verify( pattern.dns_pattern ) return False def _hostname_matches(cert_pattern: bytes, actual_hostname: bytes) -> bool: """ :return: `True` if *cert_pattern* matches *actual_hostname*, else `False`. """ if b"*" in cert_pattern: cert_head, cert_tail = cert_pattern.split(b".", 1) actual_head, actual_tail = actual_hostname.split(b".", 1) if cert_tail != actual_tail: return False # No patterns for IDNA if actual_head.startswith(b"xn--"): return False return cert_head == b"*" or cert_head == actual_head return cert_pattern == actual_hostname def _validate_pattern(cert_pattern: bytes) -> None: """ Check whether the usage of wildcards within *cert_pattern* conforms with our expectations. """ cnt = cert_pattern.count(b"*") if cnt > 1: raise CertificateError( "Certificate's DNS-ID {!r} contains too many wildcards.".format( cert_pattern ) ) parts = cert_pattern.split(b".") if len(parts) < 3: raise CertificateError( "Certificate's DNS-ID {!r} has too few host components for " "wildcard usage.".format(cert_pattern) ) # We assume there will always be only one wildcard allowed. if b"*" not in parts[0]: raise CertificateError( "Certificate's DNS-ID {!r} has a wildcard outside the left-most " "part.".format(cert_pattern) ) if any(not len(p) for p in parts): raise CertificateError( "Certificate's DNS-ID {!r} contains empty parts.".format( cert_pattern ) ) # Ensure no locale magic interferes. _TRANS_TO_LOWER = bytes.maketrans( b"ABCDEFGHIJKLMNOPQRSTUVWXYZ", b"abcdefghijklmnopqrstuvwxyz" )