123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221 |
- # This file is dual licensed under the terms of the Apache License, Version
- # 2.0, and the BSD License. See the LICENSE file in the root of this repository
- # for complete details.
-
- from __future__ import annotations
-
- import base64
- import binascii
- import os
- import time
- import typing
-
- from cryptography import utils
- from cryptography.exceptions import InvalidSignature
- from cryptography.hazmat.primitives import hashes, padding
- from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
- from cryptography.hazmat.primitives.hmac import HMAC
-
-
class InvalidToken(Exception):
    """Signal that a Fernet token was rejected.

    Raised for tokens that are malformed, fail HMAC verification, fall
    outside the permitted time window, or cannot be decrypted/unpadded.
    """
-
-
# Tolerated clock skew, in seconds: when a TTL is enforced, a token whose
# timestamp is more than this far ahead of the current time is rejected.
- _MAX_CLOCK_SKEW = 60
-
-
class Fernet:
    """Symmetric authenticated encryption producing URL-safe base64 tokens.

    A token is the URL-safe base64 encoding of::

        0x80 || timestamp (8 bytes, big-endian) || IV (16 bytes)
             || AES-CBC ciphertext of the PKCS7-padded data
             || HMAC-SHA256 (32 bytes) over everything before it
    """

    def __init__(
        self,
        key: typing.Union[bytes, str],
        backend: typing.Any = None,
    ) -> None:
        """Create a Fernet instance from a URL-safe base64-encoded key.

        :param key: URL-safe base64 encoding of exactly 32 bytes. The first
            16 decoded bytes become the HMAC signing key, the last 16 the
            AES encryption key.
        :param backend: Accepted but not used by this class.
        :raises ValueError: If ``key`` is not valid base64 or does not
            decode to 32 bytes.
        """
        try:
            key = base64.urlsafe_b64decode(key)
        except ValueError as exc:
            # binascii.Error is a ValueError subclass; catching the parent
            # also covers the plain ValueError raised for a non-ASCII str,
            # so every malformed key yields the same message.
            raise ValueError(
                "Fernet key must be 32 url-safe base64-encoded bytes."
            ) from exc
        if len(key) != 32:
            raise ValueError(
                "Fernet key must be 32 url-safe base64-encoded bytes."
            )

        self._signing_key = key[:16]
        self._encryption_key = key[16:]

    @classmethod
    def generate_key(cls) -> bytes:
        """Return a fresh key: 32 random bytes, URL-safe base64-encoded."""
        return base64.urlsafe_b64encode(os.urandom(32))

    def encrypt(self, data: bytes) -> bytes:
        """Encrypt ``data`` into a token stamped with the current time."""
        return self.encrypt_at_time(data, int(time.time()))

    def encrypt_at_time(self, data: bytes, current_time: int) -> bytes:
        """Encrypt ``data`` into a token stamped with ``current_time``.

        A new random 16-byte IV is drawn for every call.
        """
        iv = os.urandom(16)
        return self._encrypt_from_parts(data, current_time, iv)

    def _encrypt_from_parts(
        self, data: bytes, current_time: int, iv: bytes
    ) -> bytes:
        """Build a token from explicit plaintext, timestamp and IV.

        :param data: Plaintext; validated with ``utils._check_bytes``.
        :param current_time: Timestamp stored in the token as an 8-byte
            big-endian integer.
        :param iv: Initialization vector for AES-CBC.
        :returns: The URL-safe base64-encoded token.
        """
        utils._check_bytes("data", data)

        padder = padding.PKCS7(algorithms.AES.block_size).padder()
        padded_data = padder.update(data) + padder.finalize()
        encryptor = Cipher(
            algorithms.AES(self._encryption_key),
            modes.CBC(iv),
        ).encryptor()
        ciphertext = encryptor.update(padded_data) + encryptor.finalize()

        # Version byte, timestamp, IV and ciphertext are all authenticated.
        basic_parts = (
            b"\x80"
            + current_time.to_bytes(length=8, byteorder="big")
            + iv
            + ciphertext
        )

        h = HMAC(self._signing_key, hashes.SHA256())
        h.update(basic_parts)
        hmac = h.finalize()
        return base64.urlsafe_b64encode(basic_parts + hmac)

    def decrypt(
        self, token: typing.Union[bytes, str], ttl: typing.Optional[int] = None
    ) -> bytes:
        """Verify and decrypt ``token``, returning the plaintext.

        :param token: Token as produced by :meth:`encrypt`.
        :param ttl: Optional lifetime in seconds. When given, the token's
            timestamp is checked against the current time; when ``None``
            no time check is made.
        :raises TypeError: If ``token`` is neither ``bytes`` nor ``str``.
        :raises InvalidToken: If the token is malformed, outside the time
            window, fails signature verification, or cannot be decrypted.
        """
        timestamp, data = Fernet._get_unverified_token_data(token)
        if ttl is None:
            time_info = None
        else:
            time_info = (ttl, int(time.time()))
        return self._decrypt_data(data, timestamp, time_info)

    def decrypt_at_time(
        self, token: typing.Union[bytes, str], ttl: int, current_time: int
    ) -> bytes:
        """Like :meth:`decrypt`, but measured against ``current_time``.

        :raises ValueError: If ``ttl`` is ``None``.
        :raises TypeError: If ``token`` is neither ``bytes`` nor ``str``.
        :raises InvalidToken: Under the same conditions as :meth:`decrypt`.
        """
        if ttl is None:
            raise ValueError(
                "decrypt_at_time() can only be used with a non-None ttl"
            )
        timestamp, data = Fernet._get_unverified_token_data(token)
        return self._decrypt_data(data, timestamp, (ttl, current_time))

    def extract_timestamp(self, token: typing.Union[bytes, str]) -> int:
        """Return the timestamp stored in ``token`` after checking its HMAC.

        :raises TypeError: If ``token`` is neither ``bytes`` nor ``str``.
        :raises InvalidToken: If the token is malformed or its signature
            does not verify.
        """
        timestamp, data = Fernet._get_unverified_token_data(token)
        # Verify the token was not tampered with.
        self._verify_signature(data)
        return timestamp

    @staticmethod
    def _get_unverified_token_data(
        token: typing.Union[bytes, str]
    ) -> typing.Tuple[int, bytes]:
        """Decode ``token`` and read its timestamp without authenticating it.

        :returns: ``(timestamp, decoded_token_bytes)``.
        :raises TypeError: If ``token`` is neither ``bytes`` nor ``str``.
        :raises InvalidToken: If the token is not valid base64, does not
            start with the ``0x80`` version byte, or is shorter than the
            9 bytes needed for version plus timestamp.
        """
        if not isinstance(token, (str, bytes)):
            raise TypeError("token must be bytes or str")

        try:
            data = base64.urlsafe_b64decode(token)
        except (TypeError, ValueError):
            # ValueError covers binascii.Error (bad padding) as well as the
            # plain ValueError raised for a str with non-ASCII characters.
            raise InvalidToken

        if not data or data[0] != 0x80:
            raise InvalidToken

        if len(data) < 9:
            raise InvalidToken

        timestamp = int.from_bytes(data[1:9], byteorder="big")
        return timestamp, data

    def _verify_signature(self, data: bytes) -> None:
        """Check the trailing 32-byte HMAC-SHA256 over the preceding bytes.

        :raises InvalidToken: If the signature does not match.
        """
        h = HMAC(self._signing_key, hashes.SHA256())
        h.update(data[:-32])
        try:
            h.verify(data[-32:])
        except InvalidSignature:
            raise InvalidToken

    def _decrypt_data(
        self,
        data: bytes,
        timestamp: int,
        time_info: typing.Optional[typing.Tuple[int, int]],
    ) -> bytes:
        """Validate and decrypt already-decoded token bytes.

        :param data: Decoded token bytes.
        :param timestamp: Timestamp previously read from ``data``.
        :param time_info: ``(ttl, current_time)`` to enforce a time window,
            or ``None`` to skip the time checks.
        :returns: The unpadded plaintext.
        :raises InvalidToken: If the token is expired, too far in the
            future, fails signature verification, or fails decryption or
            unpadding.
        """
        if time_info is not None:
            ttl, current_time = time_info
            # Expired: issued more than ``ttl`` seconds ago.
            if timestamp + ttl < current_time:
                raise InvalidToken

            # Issued further in the future than the tolerated clock skew.
            if current_time + _MAX_CLOCK_SKEW < timestamp:
                raise InvalidToken

        self._verify_signature(data)

        # Layout: [0] version, [1:9] timestamp, [9:25] IV,
        # [25:-32] ciphertext, [-32:] HMAC.
        iv = data[9:25]
        ciphertext = data[25:-32]
        decryptor = Cipher(
            algorithms.AES(self._encryption_key), modes.CBC(iv)
        ).decryptor()
        plaintext_padded = decryptor.update(ciphertext)
        try:
            plaintext_padded += decryptor.finalize()
        except ValueError:
            raise InvalidToken
        unpadder = padding.PKCS7(algorithms.AES.block_size).unpadder()

        unpadded = unpadder.update(plaintext_padded)
        try:
            unpadded += unpadder.finalize()
        except ValueError:
            raise InvalidToken
        return unpadded
-
-
class MultiFernet:
    """Try several Fernet instances in order.

    Encryption always uses the first instance; decryption walks the list
    and returns the result of the first instance that accepts the token.
    """

    def __init__(self, fernets: typing.Iterable[Fernet]):
        """Store ``fernets`` as a list.

        :raises ValueError: If the iterable yields no instances.
        """
        instances = list(fernets)
        if not instances:
            raise ValueError(
                "MultiFernet requires at least one Fernet instance"
            )
        self._fernets = instances

    def encrypt(self, msg: bytes) -> bytes:
        """Encrypt ``msg`` with the first instance at the current time."""
        now = int(time.time())
        return self.encrypt_at_time(msg, now)

    def encrypt_at_time(self, msg: bytes, current_time: int) -> bytes:
        """Encrypt ``msg`` with the first instance at ``current_time``."""
        primary = self._fernets[0]
        return primary.encrypt_at_time(msg, current_time)

    def rotate(self, msg: typing.Union[bytes, str]) -> bytes:
        """Re-encrypt the token ``msg`` under the first instance.

        The token is decrypted, with no time check, by the first instance
        that accepts it. The plaintext is then encrypted again by the first
        instance with a new random IV and the token's original timestamp.

        :raises InvalidToken: If no instance can decrypt ``msg``.
        """
        timestamp, data = Fernet._get_unverified_token_data(msg)
        for fernet in self._fernets:
            try:
                plaintext = fernet._decrypt_data(data, timestamp, None)
            except InvalidToken:
                continue
            else:
                break
        else:
            raise InvalidToken

        fresh_iv = os.urandom(16)
        primary = self._fernets[0]
        return primary._encrypt_from_parts(plaintext, timestamp, fresh_iv)

    def decrypt(
        self, msg: typing.Union[bytes, str], ttl: typing.Optional[int] = None
    ) -> bytes:
        """Return the plaintext from the first instance that decrypts ``msg``.

        :raises InvalidToken: If every instance rejects the token.
        """
        for fernet in self._fernets:
            try:
                return fernet.decrypt(msg, ttl)
            except InvalidToken:
                continue
        raise InvalidToken

    def decrypt_at_time(
        self, msg: typing.Union[bytes, str], ttl: int, current_time: int
    ) -> bytes:
        """Like :meth:`decrypt`, but measured against ``current_time``.

        :raises InvalidToken: If every instance rejects the token.
        """
        for fernet in self._fernets:
            try:
                return fernet.decrypt_at_time(msg, ttl, current_time)
            except InvalidToken:
                continue
        raise InvalidToken
|