12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090 |
- import json
- import mimetypes
- import os
- import sys
- from copy import copy
- from functools import partial
- from http import HTTPStatus
- from importlib import import_module
- from io import BytesIO
- from urllib.parse import unquote_to_bytes, urljoin, urlparse, urlsplit
-
- from asgiref.sync import sync_to_async
-
- from django.conf import settings
- from django.core.handlers.asgi import ASGIRequest
- from django.core.handlers.base import BaseHandler
- from django.core.handlers.wsgi import WSGIRequest
- from django.core.serializers.json import DjangoJSONEncoder
- from django.core.signals import got_request_exception, request_finished, request_started
- from django.db import close_old_connections
- from django.http import HttpRequest, QueryDict, SimpleCookie
- from django.test import signals
- from django.test.utils import ContextList
- from django.urls import resolve
- from django.utils.encoding import force_bytes
- from django.utils.functional import SimpleLazyObject
- from django.utils.http import urlencode
- from django.utils.itercompat import is_iterable
- from django.utils.regex_helper import _lazy_re_compile
-
# Names exported by this module.
__all__ = (
    "AsyncClient",
    "AsyncRequestFactory",
    "Client",
    "RedirectCycleError",
    "RequestFactory",
    "encode_file",
    "encode_multipart",
)


# Boundary marker used when encoding multipart form data.
BOUNDARY = "BoUnDaRyStRiNg"
# Default content type of POST requests. RequestFactory._encode_data()
# compares against this object by identity, so pass this constant itself to
# request multipart encoding.
MULTIPART_CONTENT = "multipart/form-data; boundary=%s" % (BOUNDARY,)
# Captures the charset parameter of a Content-Type value.
CONTENT_TYPE_RE = _lazy_re_compile(r".*; charset=([\w-]+);?")
# Structured suffix spec: https://tools.ietf.org/html/rfc6838#section-4.2.8
JSON_CONTENT_TYPE_RE = _lazy_re_compile(r"^application\/(.+\+)?json")
-
-
class RedirectCycleError(Exception):
    """Raised when the test client is asked to follow a redirect loop."""

    def __init__(self, message, last_response):
        """
        Record the response that triggered the error.

        ``last_response`` must expose a ``redirect_chain`` attribute; it is
        mirrored on the exception for convenience.
        """
        super().__init__(message)
        self.last_response = last_response
        chain = last_response.redirect_chain
        self.redirect_chain = chain
-
-
class FakePayload:
    """
    A wrapper around BytesIO that restricts what can be read since data from
    the network can't be sought and cannot be read outside of its content
    length. This makes sure that views can't do anything under the test client
    that wouldn't work in real life.
    """

    def __init__(self, content=None):
        """Create the payload, optionally pre-filled with ``content``."""
        self.__content = BytesIO()
        # Number of bytes that are still available for reading.
        self.__len = 0
        self.read_started = False
        if content is not None:
            self.write(content)

    def __len__(self):
        """Return the number of bytes that have not been read yet."""
        return self.__len

    def read(self, num_bytes=None):
        """
        Return up to ``num_bytes`` bytes and mark the payload as read.

        ``None`` or a negative ``num_bytes`` reads everything that is left,
        matching the usual file-object convention. Asking for more than the
        remaining bytes fails with an AssertionError.
        """
        if not self.read_started:
            # First read: rewind the buffer that write() has been filling.
            self.__content.seek(0)
            self.read_started = True
        if num_bytes is None or num_bytes < 0:
            num_bytes = self.__len
        assert (
            self.__len >= num_bytes
        ), "Cannot read more than the available bytes from the HTTP incoming data."
        content = self.__content.read(num_bytes)
        # Account for what was actually handed out so the remaining length
        # can never drift away from the buffer.
        self.__len -= len(content)
        return content

    def write(self, content):
        """
        Append ``content`` (coerced to bytes) to the payload.

        Raise ValueError once reading has started.
        """
        if self.read_started:
            raise ValueError("Unable to write a payload after it's been read")
        content = force_bytes(content)
        self.__content.write(content)
        self.__len += len(content)
-
-
def closing_iterator_wrapper(iterable, close):
    """
    Yield everything from ``iterable``, then call ``close``.

    ``close_old_connections`` is disconnected from ``request_finished`` while
    ``close()`` runs and is reconnected afterwards, even if ``close()``
    raises, so a failing close cannot leave the signal permanently
    disconnected.
    """
    try:
        yield from iterable
    finally:
        request_finished.disconnect(close_old_connections)
        try:
            close()  # will fire request_finished
        finally:
            request_finished.connect(close_old_connections)
-
-
def conditional_content_removal(request, response):
    """
    Strip the body from responses that must not carry one.

    This mimics what most web servers do for HEAD requests and for 1xx, 204,
    and 304 responses (RFC 7230, section 3.3.3). The response is modified in
    place and returned.
    """
    status = response.status_code
    bodiless_status = 100 <= status < 200 or status in (204, 304)
    if bodiless_status or request.method == "HEAD":
        if response.streaming:
            response.streaming_content = []
        else:
            response.content = b""
    return response
-
-
class ClientHandler(BaseHandler):
    """
    An HTTP Handler that can be used for testing purposes. Use the WSGI
    interface to compose requests, but return the raw HttpResponse object with
    the originating WSGIRequest attached to its ``wsgi_request`` attribute.
    """

    def __init__(self, enforce_csrf_checks=True, *args, **kwargs):
        """Remember whether CSRF checks should be enforced on requests."""
        self.enforce_csrf_checks = enforce_csrf_checks
        super().__init__(*args, **kwargs)

    def __call__(self, environ):
        """
        Build a WSGIRequest from ``environ``, run it through the middleware
        chain, and return the resulting response.
        """
        # Set up middleware if needed. We couldn't do this earlier, because
        # settings weren't available.
        if self._middleware_chain is None:
            self.load_middleware()

        # Fire request_started without close_old_connections listening, and
        # always restore the receiver, even if another receiver raises.
        request_started.disconnect(close_old_connections)
        try:
            request_started.send(sender=self.__class__, environ=environ)
        finally:
            request_started.connect(close_old_connections)
        request = WSGIRequest(environ)
        # sneaky little hack so that we can easily get round
        # CsrfViewMiddleware. This makes life easier, and is probably
        # required for backwards compatibility with external tests against
        # admin views.
        request._dont_enforce_csrf_checks = not self.enforce_csrf_checks

        # Request goes through middleware.
        response = self.get_response(request)

        # Simulate behaviors of most web servers.
        conditional_content_removal(request, response)

        # Attach the originating request to the response so that it could be
        # later retrieved.
        response.wsgi_request = request

        # Emulate a WSGI server by calling the close method on completion.
        if response.streaming:
            response.streaming_content = closing_iterator_wrapper(
                response.streaming_content, response.close
            )
        else:
            request_finished.disconnect(close_old_connections)
            try:
                response.close()  # will fire request_finished
            finally:
                # Reconnect even if close() (or a receiver) raised.
                request_finished.connect(close_old_connections)

        return response
-
-
class AsyncClientHandler(BaseHandler):
    """An async version of ClientHandler."""

    def __init__(self, enforce_csrf_checks=True, *args, **kwargs):
        """Remember whether CSRF checks should be enforced on requests."""
        self.enforce_csrf_checks = enforce_csrf_checks
        super().__init__(*args, **kwargs)

    async def __call__(self, scope):
        """
        Build an ASGIRequest from ``scope``, run it through the async
        middleware chain, and return the resulting response.

        A ``_body_file`` entry, if present, is removed from ``scope`` and
        used as the request body.
        """
        # Set up middleware if needed. We couldn't do this earlier, because
        # settings weren't available.
        if self._middleware_chain is None:
            self.load_middleware(is_async=True)
        # Extract body file from the scope, if provided.
        if "_body_file" in scope:
            body_file = scope.pop("_body_file")
        else:
            body_file = FakePayload("")

        # Fire request_started without close_old_connections listening, and
        # always restore the receiver, even if another receiver raises.
        request_started.disconnect(close_old_connections)
        try:
            await sync_to_async(request_started.send, thread_sensitive=False)(
                sender=self.__class__, scope=scope
            )
        finally:
            request_started.connect(close_old_connections)
        request = ASGIRequest(scope, body_file)
        # Sneaky little hack so that we can easily get round
        # CsrfViewMiddleware. This makes life easier, and is probably required
        # for backwards compatibility with external tests against admin views.
        request._dont_enforce_csrf_checks = not self.enforce_csrf_checks
        # Request goes through middleware.
        response = await self.get_response_async(request)
        # Simulate behaviors of most web servers.
        conditional_content_removal(request, response)
        # Attach the originating ASGI request to the response so that it could
        # be later retrieved.
        response.asgi_request = request
        # Emulate a server by calling the close method on completion.
        if response.streaming:
            response.streaming_content = await sync_to_async(
                closing_iterator_wrapper, thread_sensitive=False
            )(
                response.streaming_content,
                response.close,
            )
        else:
            request_finished.disconnect(close_old_connections)
            try:
                # Will fire request_finished.
                await sync_to_async(response.close, thread_sensitive=False)()
            finally:
                # Reconnect even if close() (or a receiver) raised.
                request_finished.connect(close_old_connections)
        return response
-
-
def store_rendered_templates(store, signal, sender, template, context, **kwargs):
    """
    Record a rendered template and its context in ``store``.

    Templates accumulate under ``store["templates"]`` and contexts under
    ``store["context"]``. The context is copied so that it reflects its state
    at render time.
    """
    rendered = store.setdefault("templates", [])
    rendered.append(template)
    try:
        contexts = store["context"]
    except KeyError:
        contexts = store["context"] = ContextList()
    contexts.append(copy(context))
-
-
def encode_multipart(boundary, data):
    """
    Encode multipart POST data from a dictionary of form values.

    Each key becomes a form field name and its value the field content. File
    values are encoded by encode_file(); other values are converted to bytes.
    A non-string iterable value produces one part per item, since HTTP field
    names can be repeated. A value of None raises TypeError.
    """

    def as_bytes(text):
        return force_bytes(text, settings.DEFAULT_CHARSET)

    # Not by any means perfect, but good enough for our purposes.
    def looks_like_file(candidate):
        return hasattr(candidate, "read") and callable(candidate.read)

    def encode_part(name, content):
        # Encoded lines for a single part: a file upload or a plain value.
        if looks_like_file(content):
            return encode_file(boundary, name, content)
        return [
            as_bytes(piece)
            for piece in (
                "--%s" % boundary,
                'Content-Disposition: form-data; name="%s"' % name,
                "",
                content,
            )
        ]

    chunks = []
    for key, value in data.items():
        if value is None:
            raise TypeError(
                "Cannot encode None for key '%s' as POST data. Did you mean "
                "to pass an empty string or omit the value?" % key
            )
        # Files and strings are single parts even though they are iterable.
        is_multi_valued = (
            not looks_like_file(value)
            and not isinstance(value, str)
            and is_iterable(value)
        )
        items = value if is_multi_valued else (value,)
        for item in items:
            chunks.extend(encode_part(key, item))

    # Closing boundary, followed by a trailing CRLF.
    chunks.append(as_bytes("--%s--" % boundary))
    chunks.append(b"")
    return b"\r\n".join(chunks)
-
-
def encode_file(boundary, key, file):
    """
    Return the list of byte lines forming one multipart file part.

    The filename is the basename of ``file.name`` when that is a string,
    otherwise ``key``. The content type comes from ``file.content_type`` if
    the attribute exists, else it is guessed from the filename, and it falls
    back to application/octet-stream.
    """

    def as_bytes(text):
        return force_bytes(text, settings.DEFAULT_CHARSET)

    # file.name might not be a string. For example, it's an int for
    # tempfile.TemporaryFile().
    if hasattr(file, "name") and isinstance(file.name, str):
        filename = os.path.basename(file.name)
    else:
        filename = ""

    content_type = None
    if hasattr(file, "content_type"):
        content_type = file.content_type
    elif filename:
        content_type = mimetypes.guess_type(filename)[0]
    if content_type is None:
        content_type = "application/octet-stream"

    disposition = 'Content-Disposition: form-data; name="%s"; filename="%s"' % (
        key,
        filename or key,
    )
    return [
        as_bytes("--%s" % boundary),
        as_bytes(disposition),
        as_bytes("Content-Type: %s" % content_type),
        b"",
        as_bytes(file.read()),
    ]
-
-
class RequestFactory:
    """
    Build mock request objects for use in tests.

    Usage:

        rf = RequestFactory()
        get_request = rf.get('/hello/')
        post_request = rf.post('/submit/', {'foo': 'bar'})

    A request built here can be passed directly to any view function, just
    as if that view had been hooked up using a URLconf.
    """

    def __init__(self, *, json_encoder=DjangoJSONEncoder, **defaults):
        """
        Store the JSON encoder class and the environ defaults applied to
        every request, and start with an empty cookie jar.
        """
        self.json_encoder = json_encoder
        self.defaults = defaults
        self.cookies = SimpleCookie()
        self.errors = BytesIO()

    def _base_environ(self, **request):
        """
        Return the base environment for a request.

        Factory defaults override the built-in values, and ``request``
        overrides both.
        """
        # This is a minimal valid WSGI environ dictionary, plus:
        # - HTTP_COOKIE: for cookie support,
        # - REMOTE_ADDR: often useful, see #8551.
        # See https://www.python.org/dev/peps/pep-3333/#environ-variables
        cookie_header = "; ".join(
            sorted(
                "%s=%s" % (morsel.key, morsel.coded_value)
                for morsel in self.cookies.values()
            )
        )
        environ = {
            "HTTP_COOKIE": cookie_header,
            "PATH_INFO": "/",
            "REMOTE_ADDR": "127.0.0.1",
            "REQUEST_METHOD": "GET",
            "SCRIPT_NAME": "",
            "SERVER_NAME": "testserver",
            "SERVER_PORT": "80",
            "SERVER_PROTOCOL": "HTTP/1.1",
            "wsgi.version": (1, 0),
            "wsgi.url_scheme": "http",
            "wsgi.input": FakePayload(b""),
            "wsgi.errors": self.errors,
            "wsgi.multiprocess": True,
            "wsgi.multithread": False,
            "wsgi.run_once": False,
        }
        environ.update(self.defaults)
        environ.update(request)
        return environ

    def request(self, **request):
        """Construct a generic request object."""
        environ = self._base_environ(**request)
        return WSGIRequest(environ)

    def _encode_data(self, data, content_type):
        """
        Return ``data`` as bytes: multipart-encoded when ``content_type`` is
        the MULTIPART_CONTENT constant, otherwise encoded with the charset
        named in ``content_type`` (or the default charset).
        """
        # Identity, not equality: only the MULTIPART_CONTENT object itself
        # selects multipart encoding.
        if content_type is MULTIPART_CONTENT:
            return encode_multipart(BOUNDARY, data)
        # Encode the content so that the byte representation is correct.
        charset_match = CONTENT_TYPE_RE.match(content_type)
        charset = charset_match[1] if charset_match else settings.DEFAULT_CHARSET
        return force_bytes(data, encoding=charset)

    def _encode_json(self, data, content_type):
        """
        Return encoded JSON if data is a dict, list, or tuple and content_type
        is application/json. Otherwise return data unchanged.
        """
        if JSON_CONTENT_TYPE_RE.match(content_type) and isinstance(
            data, (dict, list, tuple)
        ):
            return json.dumps(data, cls=self.json_encoder)
        return data

    def _get_path(self, parsed):
        """Return the path (with any params) of a parsed URL for PATH_INFO."""
        raw_path = parsed.path
        # If there are parameters, add them
        if parsed.params:
            raw_path += ";" + parsed.params
        path_bytes = unquote_to_bytes(raw_path)
        # Replace the behavior where non-ASCII values in the WSGI environ are
        # arbitrarily decoded with ISO-8859-1.
        # Refs comment in `get_bytes_from_wsgi()`.
        return path_bytes.decode("iso-8859-1")

    def get(self, path, data=None, secure=False, **extra):
        """Construct a GET request."""
        if data is None:
            data = {}
        environ_extra = {"QUERY_STRING": urlencode(data, doseq=True), **extra}
        return self.generic("GET", path, secure=secure, **environ_extra)

    def post(
        self, path, data=None, content_type=MULTIPART_CONTENT, secure=False, **extra
    ):
        """Construct a POST request."""
        if data is None:
            data = {}
        json_or_data = self._encode_json(data, content_type)
        post_data = self._encode_data(json_or_data, content_type)
        return self.generic(
            "POST", path, post_data, content_type, secure=secure, **extra
        )

    def head(self, path, data=None, secure=False, **extra):
        """Construct a HEAD request."""
        if data is None:
            data = {}
        environ_extra = {"QUERY_STRING": urlencode(data, doseq=True), **extra}
        return self.generic("HEAD", path, secure=secure, **environ_extra)

    def trace(self, path, secure=False, **extra):
        """Construct a TRACE request."""
        return self.generic("TRACE", path, secure=secure, **extra)

    def options(
        self,
        path,
        data="",
        content_type="application/octet-stream",
        secure=False,
        **extra,
    ):
        """Construct an OPTIONS request."""
        return self.generic("OPTIONS", path, data, content_type, secure=secure, **extra)

    def put(
        self,
        path,
        data="",
        content_type="application/octet-stream",
        secure=False,
        **extra,
    ):
        """Construct a PUT request."""
        payload = self._encode_json(data, content_type)
        return self.generic("PUT", path, payload, content_type, secure=secure, **extra)

    def patch(
        self,
        path,
        data="",
        content_type="application/octet-stream",
        secure=False,
        **extra,
    ):
        """Construct a PATCH request."""
        payload = self._encode_json(data, content_type)
        return self.generic(
            "PATCH", path, payload, content_type, secure=secure, **extra
        )

    def delete(
        self,
        path,
        data="",
        content_type="application/octet-stream",
        secure=False,
        **extra,
    ):
        """Construct a DELETE request."""
        payload = self._encode_json(data, content_type)
        return self.generic(
            "DELETE", path, payload, content_type, secure=secure, **extra
        )

    def generic(
        self,
        method,
        path,
        data="",
        content_type="application/octet-stream",
        secure=False,
        **extra,
    ):
        """
        Construct an arbitrary HTTP request.

        ``extra`` entries are written into the environ last, so they override
        the values derived from the other arguments.
        """
        parsed = urlparse(str(path))  # path can be lazy
        body = force_bytes(data, settings.DEFAULT_CHARSET)
        environ = {
            "PATH_INFO": self._get_path(parsed),
            "REQUEST_METHOD": method,
            "SERVER_PORT": "443" if secure else "80",
            "wsgi.url_scheme": "https" if secure else "http",
        }
        if body:
            environ["CONTENT_LENGTH"] = str(len(body))
            environ["CONTENT_TYPE"] = content_type
            environ["wsgi.input"] = FakePayload(body)
        environ.update(extra)
        # If QUERY_STRING is absent or empty, we want to extract it from the URL.
        if not environ.get("QUERY_STRING"):
            # WSGI requires latin-1 encoded strings. See get_path_info().
            environ["QUERY_STRING"] = parsed.query.encode().decode("iso-8859-1")
        return self.request(**environ)
-
-
class AsyncRequestFactory(RequestFactory):
    """
    Build mock ASGI-like request objects for use in tests. Usage:

        rf = AsyncRequestFactory()
        get_request = await rf.get('/hello/')
        post_request = await rf.post('/submit/', {'foo': 'bar'})

    Once you have a request object you can pass it to any view function,
    including synchronous ones. The reason we have a separate class here is:
    a) this makes ASGIRequest subclasses, and
    b) AsyncTestClient can subclass it.
    """

    def _base_scope(self, **request):
        """
        Return the base scope for a request.

        Factory defaults override the built-in values, and ``request``
        overrides both; a cookie header is then appended to the headers.
        """
        # This is a minimal valid ASGI scope, plus:
        # - headers['cookie'] for cookie support,
        # - 'client' often useful, see #8551.
        scope = {
            "asgi": {"version": "3.0"},
            "type": "http",
            "http_version": "1.1",
            "client": ["127.0.0.1", 0],
            "server": ("testserver", "80"),
            "scheme": "http",
            "method": "GET",
            "headers": [],
        }
        scope.update(self.defaults)
        scope.update(request)
        cookie_value = b"; ".join(
            sorted(
                ("%s=%s" % (morsel.key, morsel.coded_value)).encode("ascii")
                for morsel in self.cookies.values()
            )
        )
        scope["headers"].append((b"cookie", cookie_value))
        return scope

    def request(self, **request):
        """Construct a generic request object."""
        # This is synchronous, which means all methods on this class are.
        # AsyncClient, however, has an async request function, which makes all
        # its methods async.
        try:
            body_file = request.pop("_body_file")
        except KeyError:
            body_file = FakePayload("")
        return ASGIRequest(self._base_scope(**request), body_file)

    def generic(
        self,
        method,
        path,
        data="",
        content_type="application/octet-stream",
        secure=False,
        **extra,
    ):
        """
        Construct an arbitrary HTTP request.

        ``follow`` and ``QUERY_STRING`` are taken out of ``extra``; every
        remaining ``extra`` entry is added to the scope as a header.
        """
        parsed = urlparse(str(path))  # path can be lazy.
        body = force_bytes(data, settings.DEFAULT_CHARSET)
        headers = [(b"host", b"testserver")]
        scope = {
            "method": method,
            "path": self._get_path(parsed),
            "server": ("127.0.0.1", "443" if secure else "80"),
            "scheme": "https" if secure else "http",
            "headers": headers,
        }
        if body:
            headers.append((b"content-length", str(len(body)).encode("ascii")))
            headers.append((b"content-type", content_type.encode("ascii")))
            scope["_body_file"] = FakePayload(body)
        follow = extra.pop("follow", None)
        if follow is not None:
            scope["follow"] = follow
        query_string = extra.pop("QUERY_STRING", None)
        if query_string:
            scope["query_string"] = query_string
        extra_headers = [
            (name.lower().encode("ascii"), value.encode("latin1"))
            for name, value in extra.items()
        ]
        headers.extend(extra_headers)
        # If QUERY_STRING is absent or empty, we want to extract it from the
        # URL.
        if not scope.get("query_string"):
            scope["query_string"] = parsed.query
        return self.request(**scope)
-
-
class ClientMixin:
    """
    Mixin with common methods between Client and AsyncClient.

    The methods rely on attributes supplied by the concrete client:
    ``cookies``, ``exc_info``, and ``raise_request_exception``.
    """

    def store_exc_info(self, **kwargs):
        """Store exceptions when they are generated by a view."""
        self.exc_info = sys.exc_info()

    def check_exception(self, response):
        """
        Look for a signaled exception, clear the current context exception
        data, re-raise the signaled exception, and clear the signaled exception
        from the local cache.

        The stored exception info is always attached to ``response.exc_info``;
        the exception is re-raised only if ``raise_request_exception`` is set.
        """
        response.exc_info = self.exc_info
        if self.exc_info:
            _, exc_value, _ = self.exc_info
            self.exc_info = None
            if self.raise_request_exception:
                raise exc_value

    @property
    def session(self):
        """
        Return the current session variables.

        If the client has no session cookie yet, a new session is created,
        saved, and its key is stored in the client's cookies.
        """
        engine = import_module(settings.SESSION_ENGINE)
        cookie = self.cookies.get(settings.SESSION_COOKIE_NAME)
        if cookie:
            return engine.SessionStore(cookie.value)
        session = engine.SessionStore()
        session.save()
        self.cookies[settings.SESSION_COOKIE_NAME] = session.session_key
        return session

    def login(self, **credentials):
        """
        Set the Factory to appear as if it has successfully logged into a site.

        Return True if login is possible or False if the provided credentials
        are incorrect.
        """
        from django.contrib.auth import authenticate

        user = authenticate(**credentials)
        if user:
            self._login(user)
            return True
        return False

    def force_login(self, user, backend=None):
        """
        Log ``user`` in without checking credentials.

        When ``backend`` is None, use the first entry of
        ``settings.AUTHENTICATION_BACKENDS`` whose loaded backend has a
        ``get_user`` attribute.
        """

        def get_backend():
            from django.contrib.auth import load_backend

            for backend_path in settings.AUTHENTICATION_BACKENDS:
                loaded_backend = load_backend(backend_path)
                if hasattr(loaded_backend, "get_user"):
                    return backend_path

        if backend is None:
            backend = get_backend()
        user.backend = backend
        self._login(user, backend)

    def _login(self, user, backend=None):
        """Log ``user`` in on a fake request and store the session cookie."""
        from django.contrib.auth import login

        # Create a fake request to store login details.
        request = HttpRequest()
        if self.session:
            request.session = self.session
        else:
            engine = import_module(settings.SESSION_ENGINE)
            request.session = engine.SessionStore()
        login(request, user, backend)
        # Save the session values.
        request.session.save()
        # Set the cookie to represent the session.
        session_cookie = settings.SESSION_COOKIE_NAME
        self.cookies[session_cookie] = request.session.session_key
        cookie_data = {
            "max-age": None,
            "path": "/",
            "domain": settings.SESSION_COOKIE_DOMAIN,
            "secure": settings.SESSION_COOKIE_SECURE or None,
            "expires": None,
        }
        self.cookies[session_cookie].update(cookie_data)

    def logout(self):
        """Log out the user by removing the cookies and session object."""
        from django.contrib.auth import get_user, logout

        request = HttpRequest()
        if self.session:
            request.session = self.session
            request.user = get_user(request)
        else:
            engine = import_module(settings.SESSION_ENGINE)
            request.session = engine.SessionStore()
        logout(request)
        self.cookies = SimpleCookie()

    def _parse_json(self, response, **extra):
        """
        Return the response body parsed as JSON, caching it on the response.

        ``extra`` is passed through to ``json.loads()``. Raise ValueError if
        the response has no Content-Type header or the header is not a JSON
        content type.
        """
        if not hasattr(response, "_json"):
            content_type = response.get("Content-Type")
            # A missing header gives None, which the regex cannot match
            # against; report it as the same ValueError as a wrong type.
            if content_type is None or not JSON_CONTENT_TYPE_RE.match(content_type):
                raise ValueError(
                    'Content-Type header is "%s", not "application/json"'
                    % content_type
                )
            response._json = json.loads(
                response.content.decode(response.charset), **extra
            )
        return response._json
-
-
class Client(ClientMixin, RequestFactory):
    """
    A class that can act as a client for testing purposes.

    It lets the user compose GET and POST requests and obtain the response
    the server gave to them. Responses are annotated with the contexts and
    templates that were rendered while serving the request.

    Client objects are stateful: they retain cookie (and thus session)
    details for the lifetime of the Client instance.

    This is not intended as a replacement for Twill/Selenium or the like; it
    is here to allow testing against the contexts and templates produced by
    a view, rather than the HTML rendered to the end-user.
    """

    def __init__(
        self, enforce_csrf_checks=False, raise_request_exception=True, **defaults
    ):
        """Set up the request handler and the exception-tracking state."""
        super().__init__(**defaults)
        self.handler = ClientHandler(enforce_csrf_checks)
        self.raise_request_exception = raise_request_exception
        self.exc_info = None
        self.extra = None

    def request(self, **request):
        """
        Make a generic request. Compose the environment dictionary and pass
        to the handler, return the result of the handler. Assume defaults for
        the query environment, which can be overridden using the arguments to
        the request.
        """
        environ = self._base_environ(**request)

        # Curry a data dictionary into an instance of the template renderer
        # callback function. It is held in a local for the duration of the
        # request, as in the original code.
        rendered = {}
        on_template_render = partial(store_rendered_templates, rendered)
        request_id = id(request)
        signal_uid = "template-render-%s" % request_id
        exception_uid = "request-exception-%s" % request_id
        signals.template_rendered.connect(on_template_render, dispatch_uid=signal_uid)
        # Capture exceptions created by the handler.
        got_request_exception.connect(self.store_exc_info, dispatch_uid=exception_uid)
        try:
            response = self.handler(environ)
        finally:
            signals.template_rendered.disconnect(dispatch_uid=signal_uid)
            got_request_exception.disconnect(dispatch_uid=exception_uid)
        # Check for signaled exceptions.
        self.check_exception(response)
        # Save the client and request that stimulated the response.
        response.client = self
        response.request = request
        # Add any rendered template detail to the response.
        response.templates = rendered.get("templates", [])
        response.context = rendered.get("context")
        response.json = partial(self._parse_json, response)
        # Attach the ResolverMatch instance to the response.
        urlconf = getattr(response.wsgi_request, "urlconf", None)
        response.resolver_match = SimpleLazyObject(
            lambda: resolve(request["PATH_INFO"], urlconf=urlconf),
        )
        # Flatten a single context. Not really necessary anymore thanks to the
        # __getattr__ flattening in ContextList, but has some edge case
        # backwards compatibility implications.
        if response.context and len(response.context) == 1:
            response.context = response.context[0]
        # Update persistent cookie data.
        if response.cookies:
            self.cookies.update(response.cookies)
        return response

    def get(self, path, data=None, follow=False, secure=False, **extra):
        """Request a response from the server using GET."""
        self.extra = extra
        response = super().get(path, data=data, secure=secure, **extra)
        if not follow:
            return response
        return self._handle_redirects(response, data=data, **extra)

    def post(
        self,
        path,
        data=None,
        content_type=MULTIPART_CONTENT,
        follow=False,
        secure=False,
        **extra,
    ):
        """Request a response from the server using POST."""
        self.extra = extra
        response = super().post(
            path, data=data, content_type=content_type, secure=secure, **extra
        )
        if not follow:
            return response
        return self._handle_redirects(
            response, data=data, content_type=content_type, **extra
        )

    def head(self, path, data=None, follow=False, secure=False, **extra):
        """Request a response from the server using HEAD."""
        self.extra = extra
        response = super().head(path, data=data, secure=secure, **extra)
        if not follow:
            return response
        return self._handle_redirects(response, data=data, **extra)

    def options(
        self,
        path,
        data="",
        content_type="application/octet-stream",
        follow=False,
        secure=False,
        **extra,
    ):
        """Request a response from the server using OPTIONS."""
        self.extra = extra
        response = super().options(
            path, data=data, content_type=content_type, secure=secure, **extra
        )
        if not follow:
            return response
        return self._handle_redirects(
            response, data=data, content_type=content_type, **extra
        )

    def put(
        self,
        path,
        data="",
        content_type="application/octet-stream",
        follow=False,
        secure=False,
        **extra,
    ):
        """Send a resource to the server using PUT."""
        self.extra = extra
        response = super().put(
            path, data=data, content_type=content_type, secure=secure, **extra
        )
        if not follow:
            return response
        return self._handle_redirects(
            response, data=data, content_type=content_type, **extra
        )

    def patch(
        self,
        path,
        data="",
        content_type="application/octet-stream",
        follow=False,
        secure=False,
        **extra,
    ):
        """Send a resource to the server using PATCH."""
        self.extra = extra
        response = super().patch(
            path, data=data, content_type=content_type, secure=secure, **extra
        )
        if not follow:
            return response
        return self._handle_redirects(
            response, data=data, content_type=content_type, **extra
        )

    def delete(
        self,
        path,
        data="",
        content_type="application/octet-stream",
        follow=False,
        secure=False,
        **extra,
    ):
        """Send a DELETE request to the server."""
        self.extra = extra
        response = super().delete(
            path, data=data, content_type=content_type, secure=secure, **extra
        )
        if not follow:
            return response
        return self._handle_redirects(
            response, data=data, content_type=content_type, **extra
        )

    def trace(self, path, data="", follow=False, secure=False, **extra):
        """Send a TRACE request to the server."""
        self.extra = extra
        response = super().trace(path, data=data, secure=secure, **extra)
        if not follow:
            return response
        return self._handle_redirects(response, data=data, **extra)

    def _handle_redirects(self, response, data="", content_type="", **extra):
        """
        Follow any redirects by requesting responses from the server using GET.

        307 and 308 responses are re-requested with the original method
        instead. The final response carries the list of followed
        ``(url, status_code)`` pairs as ``redirect_chain``. Raise
        RedirectCycleError on a repeated hop or after more than 20 hops.
        """
        response.redirect_chain = []
        redirect_status_codes = (
            HTTPStatus.MOVED_PERMANENTLY,
            HTTPStatus.FOUND,
            HTTPStatus.SEE_OTHER,
            HTTPStatus.TEMPORARY_REDIRECT,
            HTTPStatus.PERMANENT_REDIRECT,
        )
        method_preserving_codes = (
            HTTPStatus.TEMPORARY_REDIRECT,
            HTTPStatus.PERMANENT_REDIRECT,
        )
        while response.status_code in redirect_status_codes:
            target = response.url
            chain = response.redirect_chain
            chain.append((target, response.status_code))

            url = urlsplit(target)
            if url.scheme:
                extra["wsgi.url_scheme"] = url.scheme
            if url.hostname:
                extra["SERVER_NAME"] = url.hostname
            if url.port:
                extra["SERVER_PORT"] = str(url.port)

            path = url.path
            # RFC 2616: bare domains without path are treated as the root.
            if not path and url.netloc:
                path = "/"
            # Prepend the request path to handle relative path redirects
            if not path.startswith("/"):
                path = urljoin(response.request["PATH_INFO"], path)

            if response.status_code in method_preserving_codes:
                # Preserve request method and query string (if needed)
                # post-redirect for 307/308 responses.
                verb = response.request["REQUEST_METHOD"].lower()
                if verb not in ("get", "head"):
                    extra["QUERY_STRING"] = url.query
                send = getattr(self, verb)
            else:
                send = self.get
                data = QueryDict(url.query)
                content_type = None

            response = send(
                path, data=data, content_type=content_type, follow=False, **extra
            )
            response.redirect_chain = chain

            if chain[-1] in chain[:-1]:
                # Check that we're not redirecting to somewhere we've already
                # been to, to prevent loops.
                raise RedirectCycleError(
                    "Redirect loop detected.", last_response=response
                )
            if len(chain) > 20:
                # Such a lengthy chain likely also means a loop, but one with
                # a growing path, changing view, or changing query argument;
                # 20 is the value of "network.http.redirection-limit" from Firefox.
                raise RedirectCycleError("Too many redirects.", last_response=response)

        return response
-
-
class AsyncClient(ClientMixin, AsyncRequestFactory):
    """
    Asynchronous counterpart of Client: builds ASGIRequests and drives them
    through an async request path.

    The "follow" argument is not currently supported by its methods.
    """

    def __init__(
        self, enforce_csrf_checks=False, raise_request_exception=True, **defaults
    ):
        """
        Set up the async handler and the per-client bookkeeping attributes.

        ``defaults`` is forwarded to the parent initializer.
        """
        super().__init__(**defaults)
        self.handler = AsyncClientHandler(enforce_csrf_checks)
        self.raise_request_exception = raise_request_exception
        self.exc_info = None
        self.extra = None

    async def request(self, **request):
        """
        Build a scope from the given keyword arguments, run it through the
        handler and return the handler's response.

        The response is annotated with the client, the originating request,
        any rendered templates and contexts, a ``json`` helper and a lazily
        resolved ``resolver_match``. Cookies set by the response are merged
        into the client's cookie jar.

        Raise NotImplementedError if ``follow`` is passed.
        """
        if "follow" in request:
            raise NotImplementedError(
                "AsyncClient request methods do not accept the follow parameter."
            )
        scope = self._base_scope(**request)

        # Both signal receivers are keyed on the identity of this call's
        # request dict so they can be disconnected precisely afterwards.
        request_id = id(request)
        template_uid = "template-render-%s" % request_id
        exception_uid = "request-exception-%s" % request_id

        # The template receiver accumulates rendered templates and contexts
        # into this dict.
        rendered = {}
        signals.template_rendered.connect(
            partial(store_rendered_templates, rendered),
            dispatch_uid=template_uid,
        )
        # Record exceptions raised while handling the request.
        got_request_exception.connect(
            self.store_exc_info, dispatch_uid=exception_uid
        )
        try:
            response = await self.handler(scope)
        finally:
            signals.template_rendered.disconnect(dispatch_uid=template_uid)
            got_request_exception.disconnect(dispatch_uid=exception_uid)

        # Surface any exception that was signalled during handling.
        self.check_exception(response)

        # Remember which client and request produced this response.
        response.client = self
        response.request = request
        # Expose whatever was rendered while producing the response.
        response.templates = rendered.get("templates", [])
        response.context = rendered.get("context")
        response.json = partial(self._parse_json, response)

        # The ResolverMatch is only computed on first access.
        urlconf = getattr(response.asgi_request, "urlconf", None)

        def _resolve_match():
            return resolve(request["path"], urlconf=urlconf)

        response.resolver_match = SimpleLazyObject(_resolve_match)

        # Unwrap a lone context. ContextList's __getattr__ flattening makes
        # this mostly redundant, but it is kept for edge-case backwards
        # compatibility.
        context = response.context
        if context and len(context) == 1:
            response.context = context[0]

        # Persist cookies set by the response on the client.
        if response.cookies:
            self.cookies.update(response.cookies)
        return response
|