12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
712781279128012811282128312841285128612871288128912901291129212931294129512961297129812991300130113021303130413051306130713081309131013111312131313141315131613171318131913201321132213231324132513261327132813291330133113321333133413351336133713381339134013411342134313441345134613471348134913501351135213531354135513561357135813591360136113621363136413651366136713681369137013711372137313741375137613771378137913801381138213831384138513861387138813891390139113921393139413951396139713981399140014011402140314041405140614071408140914101411141214131414141514161417141814191420142114221423142414251426142714281429143014311432143314341435143614371438143914401441144214431444144514461447144814491450145114521453145414551456145714581459146014611462146314641465146614671468146914701471147214731474147514761477147814791480148114821483148414851486148714881489149014911492149314941495149614971498149915001501150215031504150515061507150815091510151115121513151415151516 |
- import difflib
- import json
- import posixpath
- import sys
- import threading
- import unittest
- import warnings
- from collections import Counter
- from contextlib import contextmanager
- from copy import copy
- from difflib import get_close_matches
- from functools import wraps
- from unittest.util import safe_repr
- from urllib.parse import (
- parse_qsl, unquote, urlencode, urljoin, urlparse, urlsplit, urlunparse,
- )
- from urllib.request import url2pathname
-
- from django.apps import apps
- from django.conf import settings
- from django.core import mail
- from django.core.exceptions import ImproperlyConfigured, ValidationError
- from django.core.files import locks
- from django.core.handlers.wsgi import WSGIHandler, get_path_info
- from django.core.management import call_command
- from django.core.management.color import no_style
- from django.core.management.sql import emit_post_migrate_signal
- from django.core.servers.basehttp import ThreadedWSGIServer, WSGIRequestHandler
- from django.db import DEFAULT_DB_ALIAS, connection, connections, transaction
- from django.forms.fields import CharField
- from django.http import QueryDict
- from django.http.request import split_domain_port, validate_host
- from django.test.client import Client
- from django.test.html import HTMLParseError, parse_html
- from django.test.signals import setting_changed, template_rendered
- from django.test.utils import (
- CaptureQueriesContext, ContextList, compare_xml, modify_settings,
- override_settings,
- )
- from django.utils.decorators import classproperty
- from django.utils.deprecation import RemovedInDjango31Warning
- from django.views.static import serve
-
# Public names exported by a wildcard import of this module.
__all__ = ('TestCase', 'TransactionTestCase',
           'SimpleTestCase', 'skipIfDBFeature', 'skipUnlessDBFeature')
-
-
def to_list(value):
    """
    Normalize ``value`` to a list.

    ``None`` becomes a fresh empty list, an existing list is returned
    unchanged (same object), and any other value is wrapped in a
    one-element list.
    """
    if value is None:
        return []
    return value if isinstance(value, list) else [value]
-
-
def assert_and_parse_html(self, html, user_msg, msg):
    """
    Parse ``html`` with ``parse_html()`` and return the resulting DOM.

    ``self`` is the running test case. When parsing raises HTMLParseError,
    the test is failed with ``msg`` followed by the parser error on a new
    line, combined with ``user_msg`` through ``self._formatMessage()``.
    """
    try:
        parsed = parse_html(html)
    except HTMLParseError as parse_error:
        failure_text = '%s\n%s' % (msg, parse_error)
        self.fail(self._formatMessage(user_msg, failure_text))
    return parsed
-
-
class _AssertNumQueriesContext(CaptureQueriesContext):
    """
    Query-capturing context manager that, when its block exits without an
    exception, asserts through ``test_case`` that exactly ``num`` queries
    were captured on ``connection``.
    """

    def __init__(self, test_case, num, connection):
        self.test_case = test_case
        self.num = num
        super().__init__(connection)

    def __exit__(self, exc_type, exc_value, traceback):
        super().__exit__(exc_type, exc_value, traceback)
        # Let an exception raised inside the block propagate untouched
        # instead of reporting a query-count mismatch.
        if exc_type is not None:
            return
        executed = len(self)
        numbered_queries = [
            '%d. %s' % (position, query['sql'])
            for position, query in enumerate(self.captured_queries, start=1)
        ]
        failure_msg = "%d queries executed, %d expected\nCaptured queries were:\n%s" % (
            executed, self.num, '\n'.join(numbered_queries),
        )
        self.test_case.assertEqual(executed, self.num, failure_msg)
-
-
class _AssertTemplateUsedContext:
    """
    Context manager that records every template rendered while it is active
    (through the ``template_rendered`` signal) and, if the block exits
    without an exception, fails ``test_case`` unless ``test()`` is true.
    """

    def __init__(self, test_case, template_name):
        self.test_case = test_case
        self.template_name = template_name
        self.rendered_templates = []
        self.rendered_template_names = []
        self.context = ContextList()

    def on_template_render(self, sender, signal, template, context, **kwargs):
        # Signal receiver: remember the template, its name and a shallow copy
        # of the context it was rendered with.
        self.rendered_templates.append(template)
        self.rendered_template_names.append(template.name)
        self.context.append(copy(context))

    def test(self):
        """Return True when ``template_name`` was among the rendered names."""
        return self.template_name in self.rendered_template_names

    def message(self):
        """Return the base failure message used when ``test()`` is False."""
        return '%s was not rendered.' % self.template_name

    def __enter__(self):
        template_rendered.connect(self.on_template_render)
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        template_rendered.disconnect(self.on_template_render)
        # An exception from the block takes precedence; a passing test()
        # means there is nothing to report.
        if exc_type is not None or self.test():
            return
        failure = self.message()
        if not self.rendered_templates:
            failure += ' No template was rendered.'
        else:
            failure += ' Following templates were rendered: %s' % (
                ', '.join(self.rendered_template_names)
            )
        self.test_case.fail(failure)
-
-
class _AssertTemplateNotUsedContext(_AssertTemplateUsedContext):
    """Inverse of _AssertTemplateUsedContext: fail if ``template_name`` was rendered."""

    def test(self):
        was_rendered = self.template_name in self.rendered_template_names
        return not was_rendered

    def message(self):
        return '%s was rendered.' % self.template_name
-
-
- class _DatabaseFailure:
- def __init__(self, wrapped, message):
- self.wrapped = wrapped
- self.message = message
-
- def __call__(self):
- raise AssertionError(self.message)
-
-
- class _SimpleTestCaseDatabasesDescriptor:
- """Descriptor for SimpleTestCase.allow_database_queries deprecation."""
- def __get__(self, instance, cls=None):
- try:
- allow_database_queries = cls.allow_database_queries
- except AttributeError:
- pass
- else:
- msg = (
- '`SimpleTestCase.allow_database_queries` is deprecated. '
- 'Restrict the databases available during the execution of '
- '%s.%s with the `databases` attribute instead.'
- ) % (cls.__module__, cls.__qualname__)
- warnings.warn(msg, RemovedInDjango31Warning)
- if allow_database_queries:
- return {DEFAULT_DB_ALIAS}
- return set()
-
-
- class SimpleTestCase(unittest.TestCase):
-
    # The class we'll use for the test client self.client.
    # Can be overridden in derived classes.
    client_class = Client
    # Class-wide settings changes: setUpClass() enables them through
    # override_settings()/modify_settings() when set, and tearDownClass()
    # disables them again.
    _overridden_settings = None
    _modified_settings = None

    # Database aliases this test case may use. Reading it goes through the
    # allow_database_queries deprecation descriptor until
    # _add_databases_failures() replaces it with a validated frozenset.
    databases = _SimpleTestCaseDatabasesDescriptor()
    # Message raised on access to a disallowed alias; interpolated with the
    # operation label, the alias and the test's dotted class name.
    _disallowed_database_msg = (
        'Database %(operation)s to %(alias)r are not allowed in SimpleTestCase '
        'subclasses. Either subclass TestCase or TransactionTestCase to ensure '
        'proper test isolation or add %(alias)r to %(test)s.databases to silence '
        'this failure.'
    )
    # (connection method name, operation label) pairs that are replaced with
    # _DatabaseFailure for every alias not listed in ``databases``.
    _disallowed_connection_methods = [
        ('connect', 'connections'),
        ('temporary_connection', 'connections'),
        ('cursor', 'queries'),
        ('chunked_cursor', 'queries'),
    ]
-
- @classmethod
- def setUpClass(cls):
- super().setUpClass()
- if cls._overridden_settings:
- cls._cls_overridden_context = override_settings(**cls._overridden_settings)
- cls._cls_overridden_context.enable()
- if cls._modified_settings:
- cls._cls_modified_context = modify_settings(cls._modified_settings)
- cls._cls_modified_context.enable()
- cls._add_databases_failures()
-
- @classmethod
- def _validate_databases(cls):
- if cls.databases == '__all__':
- return frozenset(connections)
- for alias in cls.databases:
- if alias not in connections:
- message = '%s.%s.databases refers to %r which is not defined in settings.DATABASES.' % (
- cls.__module__,
- cls.__qualname__,
- alias,
- )
- close_matches = get_close_matches(alias, list(connections))
- if close_matches:
- message += ' Did you mean %r?' % close_matches[0]
- raise ImproperlyConfigured(message)
- return frozenset(cls.databases)
-
- @classmethod
- def _add_databases_failures(cls):
- cls.databases = cls._validate_databases()
- for alias in connections:
- if alias in cls.databases:
- continue
- connection = connections[alias]
- for name, operation in cls._disallowed_connection_methods:
- message = cls._disallowed_database_msg % {
- 'test': '%s.%s' % (cls.__module__, cls.__qualname__),
- 'alias': alias,
- 'operation': operation,
- }
- method = getattr(connection, name)
- setattr(connection, name, _DatabaseFailure(method, message))
-
- @classmethod
- def _remove_databases_failures(cls):
- for alias in connections:
- if alias in cls.databases:
- continue
- connection = connections[alias]
- for name, _ in cls._disallowed_connection_methods:
- method = getattr(connection, name)
- setattr(connection, name, method.wrapped)
-
- @classmethod
- def tearDownClass(cls):
- cls._remove_databases_failures()
- if hasattr(cls, '_cls_modified_context'):
- cls._cls_modified_context.disable()
- delattr(cls, '_cls_modified_context')
- if hasattr(cls, '_cls_overridden_context'):
- cls._cls_overridden_context.disable()
- delattr(cls, '_cls_overridden_context')
- super().tearDownClass()
-
- def __call__(self, result=None):
- """
- Wrapper around default __call__ method to perform common Django test
- set up. This means that user-defined Test Cases aren't required to
- include a call to super().setUp().
- """
- testMethod = getattr(self, self._testMethodName)
- skipped = (
- getattr(self.__class__, "__unittest_skip__", False) or
- getattr(testMethod, "__unittest_skip__", False)
- )
-
- if not skipped:
- try:
- self._pre_setup()
- except Exception:
- result.addError(self, sys.exc_info())
- return
- super().__call__(result)
- if not skipped:
- try:
- self._post_teardown()
- except Exception:
- result.addError(self, sys.exc_info())
- return
-
- def _pre_setup(self):
- """
- Perform pre-test setup:
- * Create a test client.
- * Clear the mail test outbox.
- """
- self.client = self.client_class()
- mail.outbox = []
-
- def _post_teardown(self):
- """Perform post-test things."""
- pass
-
- def settings(self, **kwargs):
- """
- A context manager that temporarily sets a setting and reverts to the
- original value when exiting the context.
- """
- return override_settings(**kwargs)
-
- def modify_settings(self, **kwargs):
- """
- A context manager that temporarily applies changes a list setting and
- reverts back to the original value when exiting the context.
- """
- return modify_settings(**kwargs)
-
    def assertRedirects(self, response, expected_url, status_code=302,
                        target_status_code=200, msg_prefix='',
                        fetch_redirect_response=True):
        """
        Assert that a response redirected to a specific URL and that the
        redirect URL can be loaded.

        Won't work for external links since it uses the test client to do a
        request (use fetch_redirect_response=False to check such links without
        fetching them).

        Args:
            response: the response to check. If it has a ``redirect_chain``
                (the redirect was followed), the first hop's status is
                compared to ``status_code`` and the final response's status
                to ``target_status_code``.
            expected_url: URL the response should redirect to; compared with
                assertURLEqual().
            status_code: expected status of the initial redirect response.
            target_status_code: expected status of the redirect target.
            msg_prefix: optional prefix for failure messages.
            fetch_redirect_response: when the redirect was not followed,
                request the target with ``response.client`` and check its
                status code.

        Raises:
            ValueError: the target must be fetched but its host doesn't
                validate against settings.ALLOWED_HOSTS.
        """
        if msg_prefix:
            msg_prefix += ": "

        if hasattr(response, 'redirect_chain'):
            # The request was a followed redirect
            self.assertTrue(
                response.redirect_chain,
                msg_prefix + "Response didn't redirect as expected: Response code was %d (expected %d)"
                % (response.status_code, status_code)
            )

            self.assertEqual(
                response.redirect_chain[0][1], status_code,
                msg_prefix + "Initial response didn't redirect as expected: Response code was %d (expected %d)"
                % (response.redirect_chain[0][1], status_code)
            )

            # Take the final hop's URL. This rebinds status_code; the
            # argument's value was already checked against the first hop.
            url, status_code = response.redirect_chain[-1]
            scheme, netloc, path, query, fragment = urlsplit(url)

            self.assertEqual(
                response.status_code, target_status_code,
                msg_prefix + "Response didn't redirect as expected: Final Response code was %d (expected %d)"
                % (response.status_code, target_status_code)
            )

        else:
            # Not a followed redirect
            self.assertEqual(
                response.status_code, status_code,
                msg_prefix + "Response didn't redirect as expected: Response code was %d (expected %d)"
                % (response.status_code, status_code)
            )

            url = response.url
            scheme, netloc, path, query, fragment = urlsplit(url)

            # Prepend the request path to handle relative path redirects.
            if not path.startswith('/'):
                url = urljoin(response.request['PATH_INFO'], url)
                path = urljoin(response.request['PATH_INFO'], path)

            if fetch_redirect_response:
                # netloc might be empty, or in cases where Django tests the
                # HTTP scheme, the convention is for netloc to be 'testserver'.
                # Trust both as "internal" URLs here.
                domain, port = split_domain_port(netloc)
                if domain and not validate_host(domain, settings.ALLOWED_HOSTS):
                    raise ValueError(
                        "The test client is unable to fetch remote URLs (got %s). "
                        "If the host is served by Django, add '%s' to ALLOWED_HOSTS. "
                        "Otherwise, use assertRedirects(..., fetch_redirect_response=False)."
                        % (url, domain)
                    )
                # Only the path and query are requested; the scheme is
                # carried over through the ``secure`` flag.
                redirect_response = response.client.get(path, QueryDict(query), secure=(scheme == 'https'))

                # Get the redirection page, using the same client that was used
                # to obtain the original response.
                self.assertEqual(
                    redirect_response.status_code, target_status_code,
                    msg_prefix + "Couldn't retrieve redirection page '%s': response code was %d (expected %d)"
                    % (path, redirect_response.status_code, target_status_code)
                )

        self.assertURLEqual(
            url, expected_url,
            msg_prefix + "Response redirected to '%s', expected '%s'" % (url, expected_url)
        )
-
- def assertURLEqual(self, url1, url2, msg_prefix=''):
- """
- Assert that two URLs are the same, ignoring the order of query string
- parameters except for parameters with the same name.
-
- For example, /path/?x=1&y=2 is equal to /path/?y=2&x=1, but
- /path/?a=1&a=2 isn't equal to /path/?a=2&a=1.
- """
- def normalize(url):
- """Sort the URL's query string parameters."""
- url = str(url) # Coerce reverse_lazy() URLs.
- scheme, netloc, path, params, query, fragment = urlparse(url)
- query_parts = sorted(parse_qsl(query))
- return urlunparse((scheme, netloc, path, params, urlencode(query_parts), fragment))
-
- self.assertEqual(
- normalize(url1), normalize(url2),
- msg_prefix + "Expected '%s' to equal '%s'." % (url1, url2)
- )
-
- def _assert_contains(self, response, text, status_code, msg_prefix, html):
- # If the response supports deferred rendering and hasn't been rendered
- # yet, then ensure that it does get rendered before proceeding further.
- if hasattr(response, 'render') and callable(response.render) and not response.is_rendered:
- response.render()
-
- if msg_prefix:
- msg_prefix += ": "
-
- self.assertEqual(
- response.status_code, status_code,
- msg_prefix + "Couldn't retrieve content: Response code was %d"
- " (expected %d)" % (response.status_code, status_code)
- )
-
- if response.streaming:
- content = b''.join(response.streaming_content)
- else:
- content = response.content
- if not isinstance(text, bytes) or html:
- text = str(text)
- content = content.decode(response.charset)
- text_repr = "'%s'" % text
- else:
- text_repr = repr(text)
- if html:
- content = assert_and_parse_html(self, content, None, "Response's content is not valid HTML:")
- text = assert_and_parse_html(self, text, None, "Second argument is not valid HTML:")
- real_count = content.count(text)
- return (text_repr, real_count, msg_prefix)
-
- def assertContains(self, response, text, count=None, status_code=200, msg_prefix='', html=False):
- """
- Assert that a response indicates that some content was retrieved
- successfully, (i.e., the HTTP status code was as expected) and that
- ``text`` occurs ``count`` times in the content of the response.
- If ``count`` is None, the count doesn't matter - the assertion is true
- if the text occurs at least once in the response.
- """
- text_repr, real_count, msg_prefix = self._assert_contains(
- response, text, status_code, msg_prefix, html)
-
- if count is not None:
- self.assertEqual(
- real_count, count,
- msg_prefix + "Found %d instances of %s in response (expected %d)" % (real_count, text_repr, count)
- )
- else:
- self.assertTrue(real_count != 0, msg_prefix + "Couldn't find %s in response" % text_repr)
-
- def assertNotContains(self, response, text, status_code=200, msg_prefix='', html=False):
- """
- Assert that a response indicates that some content was retrieved
- successfully, (i.e., the HTTP status code was as expected) and that
- ``text`` doesn't occurs in the content of the response.
- """
- text_repr, real_count, msg_prefix = self._assert_contains(
- response, text, status_code, msg_prefix, html)
-
- self.assertEqual(real_count, 0, msg_prefix + "Response should not contain %s" % text_repr)
-
- def assertFormError(self, response, form, field, errors, msg_prefix=''):
- """
- Assert that a form used to render the response has a specific field
- error.
- """
- if msg_prefix:
- msg_prefix += ": "
-
- # Put context(s) into a list to simplify processing.
- contexts = to_list(response.context)
- if not contexts:
- self.fail(msg_prefix + "Response did not use any contexts to render the response")
-
- # Put error(s) into a list to simplify processing.
- errors = to_list(errors)
-
- # Search all contexts for the error.
- found_form = False
- for i, context in enumerate(contexts):
- if form not in context:
- continue
- found_form = True
- for err in errors:
- if field:
- if field in context[form].errors:
- field_errors = context[form].errors[field]
- self.assertTrue(
- err in field_errors,
- msg_prefix + "The field '%s' on form '%s' in"
- " context %d does not contain the error '%s'"
- " (actual errors: %s)" %
- (field, form, i, err, repr(field_errors))
- )
- elif field in context[form].fields:
- self.fail(
- msg_prefix + "The field '%s' on form '%s' in context %d contains no errors" %
- (field, form, i)
- )
- else:
- self.fail(
- msg_prefix + "The form '%s' in context %d does not contain the field '%s'" %
- (form, i, field)
- )
- else:
- non_field_errors = context[form].non_field_errors()
- self.assertTrue(
- err in non_field_errors,
- msg_prefix + "The form '%s' in context %d does not"
- " contain the non-field error '%s'"
- " (actual errors: %s)" %
- (form, i, err, non_field_errors or 'none')
- )
- if not found_form:
- self.fail(msg_prefix + "The form '%s' was not used to render the response" % form)
-
    def assertFormsetError(self, response, formset, form_index, field, errors,
                           msg_prefix=''):
        """
        Assert that a formset used to render the response has a specific error.

        For field errors, specify the ``form_index`` and the ``field``.
        For non-field errors, specify the ``form_index`` and the ``field`` as
        None.
        For non-form errors, specify ``form_index`` as None and the ``field``
        as None.

        Every context in the response that contains ``formset`` is checked.
        ``errors`` may be a single error or a list. The test fails if no
        context used the formset or if any expected error is missing.
        """
        # Add punctuation to msg_prefix
        if msg_prefix:
            msg_prefix += ": "

        # Put context(s) into a list to simplify processing.
        contexts = to_list(response.context)
        if not contexts:
            self.fail(msg_prefix + 'Response did not use any contexts to '
                      'render the response')

        # Put error(s) into a list to simplify processing.
        errors = to_list(errors)

        # Search all contexts for the error.
        found_formset = False
        for i, context in enumerate(contexts):
            if formset not in context:
                continue
            found_formset = True
            for err in errors:
                if field is not None:
                    # Field error on form ``form_index``: the field must exist
                    # on that form and list ``err`` among its errors.
                    if field in context[formset].forms[form_index].errors:
                        field_errors = context[formset].forms[form_index].errors[field]
                        self.assertTrue(
                            err in field_errors,
                            msg_prefix + "The field '%s' on formset '%s', "
                            "form %d in context %d does not contain the "
                            "error '%s' (actual errors: %s)" %
                            (field, formset, form_index, i, err, repr(field_errors))
                        )
                    elif field in context[formset].forms[form_index].fields:
                        self.fail(
                            msg_prefix + "The field '%s' on formset '%s', form %d in context %d contains no errors"
                            % (field, formset, form_index, i)
                        )
                    else:
                        self.fail(
                            msg_prefix + "The formset '%s', form %d in context %d does not contain the field '%s'"
                            % (formset, form_index, i, field)
                        )
                elif form_index is not None:
                    # Non-field error on form ``form_index``. First require
                    # that the form has any non-field errors at all.
                    non_field_errors = context[formset].forms[form_index].non_field_errors()
                    self.assertFalse(
                        not non_field_errors,
                        msg_prefix + "The formset '%s', form %d in context %d "
                        "does not contain any non-field errors." % (formset, form_index, i)
                    )
                    self.assertTrue(
                        err in non_field_errors,
                        msg_prefix + "The formset '%s', form %d in context %d "
                        "does not contain the non-field error '%s' (actual errors: %s)"
                        % (formset, form_index, i, err, repr(non_field_errors))
                    )
                else:
                    # Non-form error on the formset itself. form_index is None
                    # here, so ``forms`` is never indexed on this path.
                    non_form_errors = context[formset].non_form_errors()
                    self.assertFalse(
                        not non_form_errors,
                        msg_prefix + "The formset '%s' in context %d does not "
                        "contain any non-form errors." % (formset, i)
                    )
                    self.assertTrue(
                        err in non_form_errors,
                        msg_prefix + "The formset '%s' in context %d does not "
                        "contain the non-form error '%s' (actual errors: %s)"
                        % (formset, i, err, repr(non_form_errors))
                    )
        if not found_formset:
            self.fail(msg_prefix + "The formset '%s' was not used to render the response" % formset)
-
- def _assert_template_used(self, response, template_name, msg_prefix):
-
- if response is None and template_name is None:
- raise TypeError('response and/or template_name argument must be provided')
-
- if msg_prefix:
- msg_prefix += ": "
-
- if template_name is not None and response is not None and not hasattr(response, 'templates'):
- raise ValueError(
- "assertTemplateUsed() and assertTemplateNotUsed() are only "
- "usable on responses fetched using the Django test Client."
- )
-
- if not hasattr(response, 'templates') or (response is None and template_name):
- if response:
- template_name = response
- response = None
- # use this template with context manager
- return template_name, None, msg_prefix
-
- template_names = [t.name for t in response.templates if t.name is not None]
- return None, template_names, msg_prefix
-
- def assertTemplateUsed(self, response=None, template_name=None, msg_prefix='', count=None):
- """
- Assert that the template with the provided name was used in rendering
- the response. Also usable as context manager.
- """
- context_mgr_template, template_names, msg_prefix = self._assert_template_used(
- response, template_name, msg_prefix)
-
- if context_mgr_template:
- # Use assertTemplateUsed as context manager.
- return _AssertTemplateUsedContext(self, context_mgr_template)
-
- if not template_names:
- self.fail(msg_prefix + "No templates used to render the response")
- self.assertTrue(
- template_name in template_names,
- msg_prefix + "Template '%s' was not a template used to render"
- " the response. Actual template(s) used: %s"
- % (template_name, ', '.join(template_names))
- )
-
- if count is not None:
- self.assertEqual(
- template_names.count(template_name), count,
- msg_prefix + "Template '%s' was expected to be rendered %d "
- "time(s) but was actually rendered %d time(s)."
- % (template_name, count, template_names.count(template_name))
- )
-
- def assertTemplateNotUsed(self, response=None, template_name=None, msg_prefix=''):
- """
- Assert that the template with the provided name was NOT used in
- rendering the response. Also usable as context manager.
- """
- context_mgr_template, template_names, msg_prefix = self._assert_template_used(
- response, template_name, msg_prefix
- )
- if context_mgr_template:
- # Use assertTemplateNotUsed as context manager.
- return _AssertTemplateNotUsedContext(self, context_mgr_template)
-
- self.assertFalse(
- template_name in template_names,
- msg_prefix + "Template '%s' was used unexpectedly in rendering the response" % template_name
- )
-
- @contextmanager
- def _assert_raises_or_warns_cm(self, func, cm_attr, expected_exception, expected_message):
- with func(expected_exception) as cm:
- yield cm
- self.assertIn(expected_message, str(getattr(cm, cm_attr)))
-
- def _assertFooMessage(self, func, cm_attr, expected_exception, expected_message, *args, **kwargs):
- callable_obj = None
- if args:
- callable_obj, *args = args
- cm = self._assert_raises_or_warns_cm(func, cm_attr, expected_exception, expected_message)
- # Assertion used in context manager fashion.
- if callable_obj is None:
- return cm
- # Assertion was passed a callable.
- with cm:
- callable_obj(*args, **kwargs)
-
- def assertRaisesMessage(self, expected_exception, expected_message, *args, **kwargs):
- """
- Assert that expected_message is found in the message of a raised
- exception.
-
- Args:
- expected_exception: Exception class expected to be raised.
- expected_message: expected error message string value.
- args: Function to be called and extra positional args.
- kwargs: Extra kwargs.
- """
- return self._assertFooMessage(
- self.assertRaises, 'exception', expected_exception, expected_message,
- *args, **kwargs
- )
-
- def assertWarnsMessage(self, expected_warning, expected_message, *args, **kwargs):
- """
- Same as assertRaisesMessage but for assertWarns() instead of
- assertRaises().
- """
- return self._assertFooMessage(
- self.assertWarns, 'warning', expected_warning, expected_message,
- *args, **kwargs
- )
-
- def assertFieldOutput(self, fieldclass, valid, invalid, field_args=None,
- field_kwargs=None, empty_value=''):
- """
- Assert that a form field behaves correctly with various inputs.
-
- Args:
- fieldclass: the class of the field to be tested.
- valid: a dictionary mapping valid inputs to their expected
- cleaned values.
- invalid: a dictionary mapping invalid inputs to one or more
- raised error messages.
- field_args: the args passed to instantiate the field
- field_kwargs: the kwargs passed to instantiate the field
- empty_value: the expected clean output for inputs in empty_values
- """
- if field_args is None:
- field_args = []
- if field_kwargs is None:
- field_kwargs = {}
- required = fieldclass(*field_args, **field_kwargs)
- optional = fieldclass(*field_args, **{**field_kwargs, 'required': False})
- # test valid inputs
- for input, output in valid.items():
- self.assertEqual(required.clean(input), output)
- self.assertEqual(optional.clean(input), output)
- # test invalid inputs
- for input, errors in invalid.items():
- with self.assertRaises(ValidationError) as context_manager:
- required.clean(input)
- self.assertEqual(context_manager.exception.messages, errors)
-
- with self.assertRaises(ValidationError) as context_manager:
- optional.clean(input)
- self.assertEqual(context_manager.exception.messages, errors)
- # test required inputs
- error_required = [required.error_messages['required']]
- for e in required.empty_values:
- with self.assertRaises(ValidationError) as context_manager:
- required.clean(e)
- self.assertEqual(context_manager.exception.messages, error_required)
- self.assertEqual(optional.clean(e), empty_value)
- # test that max_length and min_length are always accepted
- if issubclass(fieldclass, CharField):
- field_kwargs.update({'min_length': 2, 'max_length': 20})
- self.assertIsInstance(fieldclass(*field_args, **field_kwargs), fieldclass)
-
- def assertHTMLEqual(self, html1, html2, msg=None):
- """
- Assert that two HTML snippets are semantically the same.
- Whitespace in most cases is ignored, and attribute ordering is not
- significant. The arguments must be valid HTML.
- """
- dom1 = assert_and_parse_html(self, html1, msg, 'First argument is not valid HTML:')
- dom2 = assert_and_parse_html(self, html2, msg, 'Second argument is not valid HTML:')
-
- if dom1 != dom2:
- standardMsg = '%s != %s' % (
- safe_repr(dom1, True), safe_repr(dom2, True))
- diff = ('\n' + '\n'.join(difflib.ndiff(
- str(dom1).splitlines(), str(dom2).splitlines(),
- )))
- standardMsg = self._truncateMessage(standardMsg, diff)
- self.fail(self._formatMessage(msg, standardMsg))
-
- def assertHTMLNotEqual(self, html1, html2, msg=None):
- """Assert that two HTML snippets are not semantically equivalent."""
- dom1 = assert_and_parse_html(self, html1, msg, 'First argument is not valid HTML:')
- dom2 = assert_and_parse_html(self, html2, msg, 'Second argument is not valid HTML:')
-
- if dom1 == dom2:
- standardMsg = '%s == %s' % (
- safe_repr(dom1, True), safe_repr(dom2, True))
- self.fail(self._formatMessage(msg, standardMsg))
-
def assertInHTML(self, needle, haystack, count=None, msg_prefix=''):
    """
    Assert that the HTML fragment needle occurs inside haystack.

    With count=None, at least one occurrence is required; otherwise the
    number of occurrences must equal count. Both arguments are parsed as
    HTML first, and msg_prefix is prepended to the failure message.
    """
    parsed_needle = assert_and_parse_html(self, needle, None, 'First argument is not valid HTML:')
    parsed_haystack = assert_and_parse_html(self, haystack, None, 'Second argument is not valid HTML:')
    occurrences = parsed_haystack.count(parsed_needle)

    if count is None:
        self.assertTrue(
            occurrences != 0,
            msg_prefix + "Couldn't find '%s' in response" % parsed_needle,
        )
    else:
        self.assertEqual(
            occurrences, count,
            msg_prefix + "Found %d instances of '%s' in response (expected %d)" % (
                occurrences, parsed_needle, count,
            ),
        )
-
def assertJSONEqual(self, raw, expected_data, msg=None):
    """
    Fail unless the JSON document raw decodes to a value equal to expected_data.

    A str expected_data is decoded as JSON first; any other value is compared
    as-is. JSON whitespace is irrelevant because both sides are compared after
    decoding with the json module.
    """
    try:
        actual = json.loads(raw)
    except json.JSONDecodeError:
        self.fail("First argument is not valid JSON: %r" % raw)

    expected = expected_data
    if isinstance(expected, str):
        try:
            expected = json.loads(expected)
        except ValueError:
            # Broader than JSONDecodeError on purpose: any decoding ValueError
            # is reported as a bad second argument.
            self.fail("Second argument is not valid JSON: %r" % expected)
    self.assertEqual(actual, expected, msg=msg)
-
def assertJSONNotEqual(self, raw, expected_data, msg=None):
    """
    Fail if the JSON document raw decodes to a value equal to expected_data.

    A str expected_data is decoded as JSON first; any other value is compared
    as-is. JSON whitespace is irrelevant because both sides are compared after
    decoding with the json module.
    """
    try:
        actual = json.loads(raw)
    except json.JSONDecodeError:
        self.fail("First argument is not valid JSON: %r" % raw)

    expected = expected_data
    if isinstance(expected, str):
        try:
            expected = json.loads(expected)
        except json.JSONDecodeError:
            self.fail("Second argument is not valid JSON: %r" % expected)
    self.assertNotEqual(actual, expected, msg=msg)
-
def assertXMLEqual(self, xml1, xml2, msg=None):
    """
    Fail unless xml1 and xml2 are semantically equal XML, as judged by
    compare_xml().

    Both arguments must be valid XML. Whitespace is mostly insignificant
    and attribute order doesn't matter. Any exception raised by
    compare_xml() is turned into a test failure that includes its text.
    """
    try:
        equivalent = compare_xml(xml1, xml2)
    except Exception as exc:
        self.fail(self._formatMessage(
            msg, 'First or second argument is not valid XML\n%s' % exc,
        ))
    else:
        if not equivalent:
            summary = '%s != %s' % (safe_repr(xml1, True), safe_repr(xml2, True))
            delta = '\n' + '\n'.join(
                difflib.ndiff(xml1.splitlines(), xml2.splitlines())
            )
            self.fail(self._formatMessage(msg, self._truncateMessage(summary, delta)))
-
def assertXMLNotEqual(self, xml1, xml2, msg=None):
    """
    Fail if xml1 and xml2 are semantically equivalent XML, as judged by
    compare_xml().

    Both arguments must be valid XML. Whitespace is mostly insignificant
    and attribute order doesn't matter. Any exception raised by
    compare_xml() is turned into a test failure that includes its text.
    """
    try:
        equivalent = compare_xml(xml1, xml2)
    except Exception as exc:
        self.fail(self._formatMessage(
            msg, 'First or second argument is not valid XML\n%s' % exc,
        ))
    else:
        if equivalent:
            summary = '%s == %s' % (safe_repr(xml1, True), safe_repr(xml2, True))
            self.fail(self._formatMessage(msg, summary))
-
-
class _TransactionTestCaseDatabasesDescriptor:
    """
    Descriptor for TransactionTestCase.multi_db deprecation.

    When the test class still defines the deprecated ``multi_db`` attribute,
    emit a RemovedInDjango31Warning. Resolve to every alias in
    ``connections`` if ``multi_db`` is truthy. In every other case
    (falsy or missing ``multi_db``) resolve to just the default alias.
    """
    msg = (
        '`TransactionTestCase.multi_db` is deprecated. Databases available '
        'during this test can be defined using %s.%s.databases.'
    )

    def __get__(self, instance, cls=None):
        try:
            multi_db = cls.multi_db
        except AttributeError:
            # No legacy attribute: only the default database is allowed.
            return {DEFAULT_DB_ALIAS}
        warnings.warn(
            self.msg % (cls.__module__, cls.__qualname__),
            RemovedInDjango31Warning,
        )
        return set(connections) if multi_db else {DEFAULT_DB_ALIAS}
-
-
class TransactionTestCase(SimpleTestCase):
    """
    Test case that isolates tests by flushing the databases it may use after
    every test, instead of rolling back a transaction.
    """

    # Subclasses can ask for resetting of auto increment sequence before each
    # test case
    reset_sequences = False

    # Subclasses can enable only a subset of apps for faster tests
    available_apps = None

    # Subclasses can define fixtures which will be automatically installed.
    fixtures = None

    databases = _TransactionTestCaseDatabasesDescriptor()
    _disallowed_database_msg = (
        'Database %(operation)s to %(alias)r are not allowed in this test. '
        'Add %(alias)r to %(test)s.databases to ensure proper test isolation '
        'and silence this failure.'
    )

    # If transactions aren't available, Django will serialize the database
    # contents into a fixture during setup and flush and reload them
    # during teardown (as flush does not restore data from migrations).
    # This can be slow; this flag allows enabling on a per-case basis.
    serialized_rollback = False

    def _pre_setup(self):
        """
        Perform pre-test setup:
        * If the class has an 'available_apps' attribute, restrict the app
          registry to these applications, then fire the post_migrate signal --
          it must run with the correct set of applications for the test case.
        * If the class has a 'fixtures' attribute, install those fixtures.
        """
        super()._pre_setup()
        if self.available_apps is not None:
            apps.set_available_apps(self.available_apps)
            setting_changed.send(
                sender=settings._wrapped.__class__,
                setting='INSTALLED_APPS',
                value=self.available_apps,
                enter=True,
            )
            for db_name in self._databases_names(include_mirrors=False):
                emit_post_migrate_signal(verbosity=0, interactive=False, db=db_name)
        try:
            self._fixture_setup()
        except Exception:
            # Undo the registry restriction made above before propagating.
            if self.available_apps is not None:
                apps.unset_available_apps()
                setting_changed.send(
                    sender=settings._wrapped.__class__,
                    setting='INSTALLED_APPS',
                    value=settings.INSTALLED_APPS,
                    enter=False,
                )
            raise
        # Clear the queries_log so that it's less likely to overflow (a single
        # test probably won't execute 9K queries). If queries_log overflows,
        # then assertNumQueries() doesn't work.
        for db_name in self._databases_names(include_mirrors=False):
            connections[db_name].queries_log.clear()

    @classmethod
    def _databases_names(cls, include_mirrors=True):
        """
        Return the aliases in ``connections`` that this test may use.

        Mirror databases (settings ``TEST['MIRROR']``) are left out unless
        include_mirrors is true.
        """
        return [
            alias for alias in connections
            if alias in cls.databases and (
                include_mirrors or not connections[alias].settings_dict['TEST']['MIRROR']
            )
        ]

    def _reset_sequences(self, db_name):
        """Reset the sequences of database db_name, if its backend supports it."""
        conn = connections[db_name]
        if conn.features.supports_sequence_reset:
            sql_list = conn.ops.sequence_reset_by_name_sql(
                no_style(), conn.introspection.sequence_list())
            if sql_list:
                with transaction.atomic(using=db_name):
                    with conn.cursor() as cursor:
                        for sql in sql_list:
                            cursor.execute(sql)

    def _fixture_setup(self):
        """
        For each allowed non-mirror database: optionally reset sequences,
        reload serialized contents (serialized_rollback), then load fixtures.
        """
        for db_name in self._databases_names(include_mirrors=False):
            # Reset sequences
            if self.reset_sequences:
                self._reset_sequences(db_name)

            # Provide replica initial data from migrated apps, if needed.
            if self.serialized_rollback and hasattr(connections[db_name], "_test_serialized_contents"):
                if self.available_apps is not None:
                    apps.unset_available_apps()
                try:
                    connections[db_name].creation.deserialize_db_from_string(
                        connections[db_name]._test_serialized_contents
                    )
                finally:
                    # Re-restrict the registry even when deserialization fails,
                    # so the unset_available_apps() call in _pre_setup()'s error
                    # handler stays paired with a preceding set.
                    if self.available_apps is not None:
                        apps.set_available_apps(self.available_apps)

            if self.fixtures:
                call_command('loaddata', *self.fixtures, verbosity=0, database=db_name)

    def _should_reload_connections(self):
        return True

    def _post_teardown(self):
        """
        Perform post-test things:
        * Flush the contents of the database to leave a clean slate. If the
          class has an 'available_apps' attribute, don't fire post_migrate.
        * Force-close the connection so the next test gets a clean cursor.
        """
        try:
            self._fixture_teardown()
            super()._post_teardown()
            if self._should_reload_connections():
                # Some DB cursors include SQL statements as part of cursor
                # creation. If you have a test that does a rollback, the effect
                # of these statements is lost, which can affect the operation of
                # tests (e.g., losing a timezone setting causing objects to be
                # created with the wrong time). To make sure this doesn't
                # happen, get a clean connection at the start of every test.
                for conn in connections.all():
                    conn.close()
        finally:
            if self.available_apps is not None:
                apps.unset_available_apps()
                setting_changed.send(sender=settings._wrapped.__class__,
                                     setting='INSTALLED_APPS',
                                     value=settings.INSTALLED_APPS,
                                     enter=False)

    def _fixture_teardown(self):
        """Flush every allowed non-mirror database."""
        # Allow TRUNCATE ... CASCADE and don't emit the post_migrate signal
        # when flushing only a subset of the apps
        for db_name in self._databases_names(include_mirrors=False):
            # Flush the database
            inhibit_post_migrate = (
                self.available_apps is not None or
                (   # Inhibit the post_migrate signal when using serialized
                    # rollback to avoid trying to recreate the serialized data.
                    self.serialized_rollback and
                    hasattr(connections[db_name], '_test_serialized_contents')
                )
            )
            call_command('flush', verbosity=0, interactive=False,
                         database=db_name, reset_sequences=False,
                         allow_cascade=self.available_apps is not None,
                         inhibit_post_migrate=inhibit_post_migrate)

    def assertQuerysetEqual(self, qs, values, transform=repr, ordered=True, msg=None):
        """
        Assert that transform() applied to each item of qs yields values.

        With ordered=False the comparison ignores order (multiset equality).
        With ordered=True, comparing more than one value against a queryset
        whose ``ordered`` attribute is false raises ValueError.
        """
        items = map(transform, qs)
        if not ordered:
            return self.assertEqual(Counter(items), Counter(values), msg=msg)
        values = list(values)
        # For example qs.iterator() could be passed as qs, but it does not
        # have 'ordered' attribute.
        if len(values) > 1 and hasattr(qs, 'ordered') and not qs.ordered:
            raise ValueError("Trying to compare non-ordered queryset "
                             "against more than one ordered values")
        return self.assertEqual(list(items), values, msg=msg)

    def assertNumQueries(self, num, func=None, *args, using=DEFAULT_DB_ALIAS, **kwargs):
        """
        Assert that exactly num queries run on database ``using``.

        Without func, return a context manager; otherwise call
        func(*args, **kwargs) inside it.
        """
        conn = connections[using]

        context = _AssertNumQueriesContext(self, num, conn)
        if func is None:
            return context

        with context:
            func(*args, **kwargs)
-
-
def connections_support_transactions(aliases=None):
    """
    Return True when every connection supports transactions.

    With aliases=None, check all connections; otherwise check only the
    connections named in aliases. Stop at the first one that lacks support.
    """
    if aliases is None:
        candidates = connections.all()
    else:
        candidates = (connections[alias] for alias in aliases)
    for conn in candidates:
        if not conn.features.supports_transactions:
            return False
    return True
-
-
class _TestCaseDatabasesDescriptor(_TransactionTestCaseDatabasesDescriptor):
    """
    Descriptor for TestCase.multi_db deprecation.

    Inherits all behavior; only the deprecation message names TestCase.
    """
    msg = (
        '`TestCase.multi_db` is deprecated. '
        'Databases available during this test can be defined using %s.%s.databases.'
    )
-
-
class TestCase(TransactionTestCase):
    """
    TransactionTestCase variant that isolates tests with
    ``transaction.atomic()`` blocks instead of flushing databases.

    A class-wide atomic block per database wraps setUpClass() data (fixtures
    and setUpTestData()). A second, per-test block is rolled back after each
    test. Prefer this class for speed; use TransactionTestCase when a test
    must exercise real transactional behavior.

    When the databases don't support transactions, it falls back to
    TransactionTestCase behavior.
    """
    databases = _TestCaseDatabasesDescriptor()

    @classmethod
    def _enter_atomics(cls):
        """Enter one atomic block per allowed database and return them by alias."""
        opened = {}
        for alias in cls._databases_names():
            block = transaction.atomic(using=alias)
            opened[alias] = block
            block.__enter__()
        return opened

    @classmethod
    def _rollback_atomics(cls, atomics):
        """Mark each block from _enter_atomics() for rollback and exit it, newest first."""
        for alias in reversed(cls._databases_names()):
            transaction.set_rollback(True, using=alias)
            atomics[alias].__exit__(None, None, None)

    @classmethod
    def _databases_support_transactions(cls):
        return connections_support_transactions(cls.databases)

    @classmethod
    def setUpClass(cls):
        super().setUpClass()
        if not cls._databases_support_transactions():
            return
        cls.cls_atomics = cls._enter_atomics()

        def abort():
            # unittest skips tearDownClass() when setUpClass() raises, so
            # undo the class-level atomics and database guards here.
            cls._rollback_atomics(cls.cls_atomics)
            cls._remove_databases_failures()

        if cls.fixtures:
            for db_name in cls._databases_names(include_mirrors=False):
                try:
                    call_command('loaddata', *cls.fixtures, verbosity=0, database=db_name)
                except Exception:
                    abort()
                    raise
        try:
            cls.setUpTestData()
        except Exception:
            abort()
            raise

    @classmethod
    def tearDownClass(cls):
        if cls._databases_support_transactions():
            cls._rollback_atomics(cls.cls_atomics)
            for conn in connections.all():
                conn.close()
        super().tearDownClass()

    @classmethod
    def setUpTestData(cls):
        """Load initial data for the TestCase."""
        pass

    def _should_reload_connections(self):
        # Connections only need reloading when tests aren't wrapped in atomics.
        return (
            not self._databases_support_transactions() and
            super()._should_reload_connections()
        )

    def _fixture_setup(self):
        if self._databases_support_transactions():
            assert not self.reset_sequences, 'reset_sequences cannot be used on TestCase instances'
            self.atomics = self._enter_atomics()
            return None
        # Without transactions, class data can't be rolled back, so it is
        # reloaded before every test.
        self.setUpTestData()
        return super()._fixture_setup()

    def _fixture_teardown(self):
        if not self._databases_support_transactions():
            return super()._fixture_teardown()
        try:
            for db_name in reversed(self._databases_names()):
                conn = connections[db_name]
                if self._should_check_constraints(conn):
                    conn.check_constraints()
        finally:
            # Always discard the per-test atomics, even if a check failed.
            self._rollback_atomics(self.atomics)

    def _should_check_constraints(self, connection):
        # Deferred constraints are only checked on a healthy connection that
        # isn't already marked for rollback.
        return (
            connection.features.can_defer_constraint_checks and
            not connection.needs_rollback and
            connection.is_usable()
        )
-
-
class CheckCondition:
    """
    Descriptor for deferred condition checking, installed as a class's
    ``__unittest_skip__``.

    ``conditions`` is a tuple of ``(callable, reason)`` pairs. Evaluation is
    deferred until the attribute is read on a class.
    """
    def __init__(self, *conditions):
        self.conditions = conditions

    def add_condition(self, condition, reason):
        """Return a new descriptor with (condition, reason) appended."""
        return type(self)(*self.conditions, (condition, reason))

    def __get__(self, instance, cls=None):
        # A skipped parent class makes this class skipped too.
        for parent in cls.__bases__:
            if getattr(parent, '__unittest_skip__', False):
                return True
        for check, why in self.conditions:
            if not check():
                continue
            # Replace this descriptor on the class and record the reason.
            cls.__unittest_skip__ = True
            cls.__unittest_skip_why__ = why
            return True
        return False
-
-
def _deferredSkip(condition, reason, name):
    """
    Build a decorator that skips a test function or TestCase subclass when
    condition() is true, evaluating it only when the test is run.

    ``name`` is the public decorator name used in error messages. If the
    decorated test's ``databases`` doesn't include ``connection.alias``, a
    ValueError is raised. For functions this happens when the wrapper is
    called; for classes, when the class's ``__unittest_skip__`` is read, so
    importing the test module still works.
    """
    def decorator(test_func):
        if not (isinstance(test_func, type) and
                issubclass(test_func, unittest.TestCase)):
            @wraps(test_func)
            def skip_wrapper(*args, **kwargs):
                if (args and isinstance(args[0], unittest.TestCase) and
                        connection.alias not in getattr(args[0], 'databases', {})):
                    raise ValueError(
                        "%s cannot be used on %s as %s doesn't allow queries "
                        "against the %r database." % (
                            name,
                            args[0],
                            args[0].__class__.__qualname__,
                            connection.alias,
                        )
                    )
                if condition():
                    raise unittest.SkipTest(reason)
                return test_func(*args, **kwargs)
            return skip_wrapper

        # Assume a class is decorated.
        test_item = test_func
        # Keep the replacement check local to this decoration. Rebinding the
        # enclosing ``condition`` would leak into every later use of this same
        # decorator, including function wrappers that read it at call time.
        check = condition
        databases = getattr(test_item, 'databases', None)
        if not databases or connection.alias not in databases:
            # Defer raising to allow importing test class's module.
            def check():
                raise ValueError(
                    "%s cannot be used on %s as it doesn't allow queries "
                    "against the '%s' database." % (
                        name, test_item, connection.alias,
                    )
                )
        # Retrieve the possibly existing value from the class's dict to
        # avoid triggering the descriptor.
        skip = test_func.__dict__.get('__unittest_skip__')
        if isinstance(skip, CheckCondition):
            test_item.__unittest_skip__ = skip.add_condition(check, reason)
        elif skip is not True:
            test_item.__unittest_skip__ = CheckCondition((check, reason))
        return test_item
    return decorator
-
-
def skipIfDBFeature(*features):
    """Skip a test if a database has at least one of the named features."""
    def has_any_feature():
        return any(getattr(connection.features, feature, False) for feature in features)

    return _deferredSkip(
        has_any_feature,
        "Database has feature(s) %s" % ", ".join(features),
        'skipIfDBFeature',
    )
-
-
def skipUnlessDBFeature(*features):
    """Skip a test unless a database has all the named features."""
    def lacks_some_feature():
        return not all(getattr(connection.features, feature, False) for feature in features)

    return _deferredSkip(
        lacks_some_feature,
        "Database doesn't support feature(s): %s" % ", ".join(features),
        'skipUnlessDBFeature',
    )
-
-
def skipUnlessAnyDBFeature(*features):
    """Skip a test unless a database has any of the named features."""
    def lacks_every_feature():
        return not any(getattr(connection.features, feature, False) for feature in features)

    return _deferredSkip(
        lacks_every_feature,
        "Database doesn't support any of the feature(s): %s" % ", ".join(features),
        'skipUnlessAnyDBFeature',
    )
-
-
class QuietWSGIRequestHandler(WSGIRequestHandler):
    """
    WSGIRequestHandler that swallows its per-request log messages, keeping
    test result output uncluttered.
    """
    def log_message(*args):
        """Discard the log message instead of emitting it."""
-
-
class FSFilesHandler(WSGIHandler):
    """
    WSGI middleware that serves files from a directory given by one of the
    *_ROOT settings, published under the matching *_URL prefix. Requests
    outside that prefix are passed to the wrapped application.

    Subclasses supply get_base_dir() and get_base_url().
    """
    def __init__(self, application):
        self.application = application
        self.base_url = urlparse(self.get_base_url())
        super().__init__()

    def _should_handle(self, path):
        """
        Return True only when the base URL has no host part and path lies
        under the base URL's path (or equals it).
        """
        # urlparse() result: index 1 is the netloc, index 2 is the path.
        if self.base_url[1]:
            return False
        return path.startswith(self.base_url[2])

    def file_path(self, url):
        """Return the relative path to the file on disk for the given URL."""
        prefix_length = len(self.base_url[2])
        return url2pathname(url[prefix_length:])

    def get_response(self, request):
        from django.http import Http404

        if self._should_handle(request.path):
            try:
                return self.serve(request)
            except Http404:
                pass  # No such file: let the regular handler respond.
        return super().get_response(request)

    def serve(self, request):
        relative_path = posixpath.normpath(unquote(self.file_path(request.path)))
        # Emulate behavior of django.contrib.staticfiles.views.serve() when it
        # invokes staticfiles' finders functionality.
        # TODO: Modify if/when that internal API is refactored
        url_path = relative_path.replace('\\', '/').lstrip('/')
        return serve(request, url_path, document_root=self.get_base_dir())

    def __call__(self, environ, start_response):
        if self._should_handle(get_path_info(environ)):
            return super().__call__(environ, start_response)
        return self.application(environ, start_response)
-
-
class _StaticFilesHandler(FSFilesHandler):
    """
    Private FSFilesHandler that serves STATIC_ROOT under STATIC_URL; intended
    only as a convenience for LiveServerThread.
    """
    def get_base_dir(self):
        """Directory static files are served from."""
        return settings.STATIC_ROOT

    def get_base_url(self):
        """URL prefix static files are published under."""
        return settings.STATIC_URL
-
-
class _MediaFilesHandler(FSFilesHandler):
    """
    Private FSFilesHandler that serves MEDIA_ROOT under MEDIA_URL; intended
    only as a convenience for LiveServerThread.
    """
    def get_base_dir(self):
        """Directory media files are served from."""
        return settings.MEDIA_ROOT

    def get_base_url(self):
        """URL prefix media files are published under."""
        return settings.MEDIA_URL
-
-
class LiveServerThread(threading.Thread):
    """
    Thread that runs a live HTTP server while the tests are running.

    ``is_ready`` is set once the server is listening, or once startup has
    failed, in which case the exception is stored in ``error``.
    """

    def __init__(self, host, static_handler, connections_override=None, port=0):
        self.host = host
        self.port = port
        self.is_ready = threading.Event()
        self.error = None
        self.static_handler = static_handler
        self.connections_override = connections_override
        super().__init__()

    def run(self):
        """
        Install the overriding database connections, start the server and
        serve HTTP requests until shut down. Always close this thread's
        connections on exit.
        """
        # Reuse the main thread's connections where provided.
        for alias, conn in (self.connections_override or {}).items():
            connections[alias] = conn
        try:
            # Static files wrap media files, which wrap the Django app.
            app = self.static_handler(_MediaFilesHandler(WSGIHandler()))
            self.httpd = self._create_server()
            if self.port == 0:
                # Port 0 asks the OS for a free port; record the real one.
                self.port = self.httpd.server_address[1]
            self.httpd.set_app(app)
            self.is_ready.set()
            self.httpd.serve_forever()
        except Exception as exc:
            self.error = exc
            self.is_ready.set()
        finally:
            connections.close_all()

    def _create_server(self):
        return ThreadedWSGIServer(
            (self.host, self.port),
            QuietWSGIRequestHandler,
            allow_reuse_address=False,
        )

    def terminate(self):
        """Stop the WSGI server if it was created, then wait for the thread."""
        if hasattr(self, 'httpd'):
            self.httpd.shutdown()
            self.httpd.server_close()
        self.join()
-
-
class LiveServerTestCase(TransactionTestCase):
    """
    Do basically the same as TransactionTestCase but also launch a live HTTP
    server in a separate thread so that the tests may use another testing
    framework, such as Selenium for example, instead of the built-in dummy
    client.
    It inherits from TransactionTestCase instead of TestCase because the
    threads don't share the same transactions (unless if using in-memory sqlite)
    and each thread needs to commit all their transactions so that the other
    thread can see the changes.
    """
    host = 'localhost'
    port = 0
    server_thread_class = LiveServerThread
    static_handler = _StaticFilesHandler

    @classproperty
    def live_server_url(cls):
        return 'http://%s:%s' % (cls.host, cls.server_thread.port)

    @classproperty
    def allowed_host(cls):
        return cls.host

    @classmethod
    def setUpClass(cls):
        """
        Share in-memory SQLite connections with the server thread, allow
        ``allowed_host`` in ALLOWED_HOSTS, then start the server thread and
        wait until it is ready.

        If the server fails to start, undo all of this (including the parent
        class setup) and re-raise the server's error.
        """
        super().setUpClass()
        connections_override = {}
        for conn in connections.all():
            # If using in-memory sqlite databases, pass the connections to
            # the server thread.
            if conn.vendor == 'sqlite' and conn.is_in_memory_db():
                # Explicitly enable thread-shareability for this connection
                conn.inc_thread_sharing()
                connections_override[conn.alias] = conn

        cls._live_server_modified_settings = modify_settings(
            ALLOWED_HOSTS={'append': cls.allowed_host},
        )
        cls._live_server_modified_settings.enable()
        cls.server_thread = cls._create_server_thread(connections_override)
        cls.server_thread.daemon = True
        cls.server_thread.start()

        # Wait for the live server to be ready
        cls.server_thread.is_ready.wait()
        if cls.server_thread.error:
            # tearDownClass() won't get called when setUpClass() raises, so
            # undo everything that it would have undone before re-raising.
            cls._tearDownClassInternal()
            cls._live_server_modified_settings.disable()
            super().tearDownClass()
            raise cls.server_thread.error

    @classmethod
    def _create_server_thread(cls, connections_override):
        return cls.server_thread_class(
            cls.host,
            cls.static_handler,
            connections_override=connections_override,
            port=cls.port,
        )

    @classmethod
    def _tearDownClassInternal(cls):
        # There may not be a 'server_thread' attribute if setUpClass() for some
        # reasons has raised an exception. The whole cleanup is guarded:
        # reading server_thread.connections_override outside this check would
        # raise AttributeError in exactly that case.
        if hasattr(cls, 'server_thread'):
            # Terminate the live server's thread
            cls.server_thread.terminate()

            # Restore sqlite in-memory database connections' non-shareability.
            for conn in cls.server_thread.connections_override.values():
                conn.dec_thread_sharing()

    @classmethod
    def tearDownClass(cls):
        cls._tearDownClassInternal()
        cls._live_server_modified_settings.disable()
        super().tearDownClass()
-
-
class SerializeMixin:
    """
    Enforce serialization of TestCases that share a common resource.

    Define a common 'lockfile' for each set of TestCases to serialize. This
    file must exist on the filesystem.

    Place it early in the MRO in order to isolate setUpClass()/tearDownClass().
    """
    lockfile = None

    @classmethod
    def setUpClass(cls):
        """
        Take an exclusive lock on ``cls.lockfile``, then run the parent
        class setup.

        Raises ValueError if ``lockfile`` isn't set. If locking or the parent
        setUpClass() fails, the lock file is closed before the exception
        propagates. Closing the file is how tearDownClass() releases it too,
        and unittest doesn't call tearDownClass() after a failed setUpClass().
        """
        if cls.lockfile is None:
            raise ValueError(
                "{}.lockfile isn't set. Set it to a unique value "
                "in the base class.".format(cls.__name__))
        cls._lockfile = open(cls.lockfile)
        try:
            locks.lock(cls._lockfile, locks.LOCK_EX)
            super().setUpClass()
        except BaseException:
            cls._lockfile.close()
            raise

    @classmethod
    def tearDownClass(cls):
        """Run the parent class teardown, then close (and so unlock) the lock file."""
        try:
            super().tearDownClass()
        finally:
            cls._lockfile.close()
|