|
12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001 |
- import asyncio
- import collections
- import logging
- import os
- import re
- import sys
- import time
- import warnings
- from contextlib import contextmanager
- from functools import wraps
- from io import StringIO
- from itertools import chain
- from types import SimpleNamespace
- from unittest import TestCase, skipIf, skipUnless
- from xml.dom.minidom import Node, parseString
-
- from django.apps import apps
- from django.apps.registry import Apps
- from django.conf import UserSettingsHolder, settings
- from django.core import mail
- from django.core.exceptions import ImproperlyConfigured
- from django.core.signals import request_started, setting_changed
- from django.db import DEFAULT_DB_ALIAS, connections, reset_queries
- from django.db.models.options import Options
- from django.template import Template
- from django.test.signals import template_rendered
- from django.urls import get_script_prefix, set_script_prefix
- from django.utils.deprecation import RemovedInDjango50Warning
- from django.utils.translation import deactivate
-
- try:
- import jinja2
- except ImportError:
- jinja2 = None
-
-
# Public API exported by ``from <this module> import *``. Several helpers
# defined in this module (e.g. override_script_prefix, captured_stdout,
# register_lookup) are not listed here and must be imported by name.
__all__ = (
    "Approximate",
    "ContextList",
    "isolate_lru_cache",
    "get_runner",
    "CaptureQueriesContext",
    "ignore_warnings",
    "isolate_apps",
    "modify_settings",
    "override_settings",
    "override_system_checks",
    "tag",
    "requires_tz_support",
    "setup_databases",
    "setup_test_environment",
    "teardown_test_environment",
)

# True when this platform's time module provides tzset(); without it the
# time zone the process runs in can't be changed at runtime.
TZ_SUPPORT = hasattr(time, "tzset")
-
-
class Approximate:
    """
    Wrap a number so that ``==`` tolerates small rounding differences.

    An Approximate compares equal to ``other`` when the two are exactly equal
    or when their absolute difference rounds to zero at ``places`` decimal
    places. Its repr is the repr of the wrapped value, which keeps assertion
    failure messages readable.
    """

    def __init__(self, val, places=7):
        self.val, self.places = val, places

    def __repr__(self):
        return repr(self.val)

    def __eq__(self, other):
        exact = self.val == other
        # Fall back to the rounded-difference check only when not exactly equal.
        return exact or round(abs(self.val - other), self.places) == 0
-
-
class ContextList(list):
    """
    A list of template contexts that also supports lookup by key.

    Indexing with a string searches the contained contexts in order and
    returns the value from the first one that has the key (KeyError if none
    do). Any non-string index behaves exactly like normal list indexing.
    """

    def __getitem__(self, key):
        if not isinstance(key, str):
            return super().__getitem__(key)
        for subcontext in self:
            if key in subcontext:
                return subcontext[key]
        raise KeyError(key)

    def get(self, key, default=None):
        """Return the value for ``key``, or ``default`` if no context has it."""
        try:
            return self[key]
        except KeyError:
            return default

    def __contains__(self, key):
        try:
            self[key]
        except KeyError:
            return False
        return True

    def keys(self):
        """
        Flattened keys of subcontexts: the union of the keys of every mapping
        yielded by iterating each contained context.
        """
        return {key for subcontext in self for mapping in subcontext for key in mapping}
-
-
def instrumented_test_render(self, context):
    """
    Replacement for Template._render used during tests.

    Announce the render through the ``template_rendered`` signal (so the test
    Client can record which templates and contexts were used), then render
    the template's nodelist as usual.
    """
    template_rendered.send(sender=self, template=self, context=context)
    nodelist = self.nodelist
    return nodelist.render(context)
-
-
class _TestState:
    """Module-level holder for state saved by setup_test_environment()."""
-
-
def setup_test_environment(debug=None):
    """
    Perform global pre-test setup.

    Install the instrumented template renderer, switch email to the locmem
    backend (with an empty ``mail.outbox``), add the test client's
    "testserver" host to ALLOWED_HOSTS, set settings.DEBUG to ``debug``
    (defaulting to its current value) and deactivate translations. Replaced
    values are stashed on _TestState for teardown_test_environment().

    Raise RuntimeError if called again before teardown_test_environment().
    """
    # A second call would overwrite the values saved by the first one.
    if hasattr(_TestState, "saved_data"):
        raise RuntimeError(
            "setup_test_environment() was already called and can't be called "
            "again without first calling teardown_test_environment()."
        )

    debug = settings.DEBUG if debug is None else debug

    saved_data = SimpleNamespace()
    _TestState.saved_data = saved_data

    saved_data.allowed_hosts = settings.ALLOWED_HOSTS
    # Allow the default host used by the test client.
    settings.ALLOWED_HOSTS = list(saved_data.allowed_hosts) + ["testserver"]

    saved_data.debug = settings.DEBUG
    settings.DEBUG = debug

    saved_data.email_backend = settings.EMAIL_BACKEND
    settings.EMAIL_BACKEND = "django.core.mail.backends.locmem.EmailBackend"

    saved_data.template_render = Template._render
    Template._render = instrumented_test_render

    mail.outbox = []

    deactivate()
-
-
def teardown_test_environment():
    """
    Undo setup_test_environment(): restore ALLOWED_HOSTS, DEBUG, the email
    backend and Template._render from the saved state, then discard that
    state and ``mail.outbox``.
    """
    saved = _TestState.saved_data

    restore_pairs = (
        ("ALLOWED_HOSTS", "allowed_hosts"),
        ("DEBUG", "debug"),
        ("EMAIL_BACKEND", "email_backend"),
    )
    for setting_name, saved_attr in restore_pairs:
        setattr(settings, setting_name, getattr(saved, saved_attr))
    Template._render = saved.template_render

    del _TestState.saved_data
    del mail.outbox
-
-
def setup_databases(
    verbosity,
    interactive,
    *,
    time_keeper=None,
    keepdb=False,
    debug_sql=False,
    parallel=0,
    aliases=None,
    serialized_aliases=None,
    **kwargs,
):
    """
    Create the test databases.

    For each group of aliases that share one underlying database (as computed
    by get_unique_databases_and_mirrors()), create the test database through
    the group's first alias, clone it ``parallel`` times when ``parallel > 1``,
    and configure the group's remaining aliases as test mirrors of the first.
    Aliases configured with TEST["MIRROR"] are then pointed at their target.

    ``time_keeper`` defaults to a NullTimeKeeper (no timing output).
    ``serialized_aliases`` restricts which databases are serialized; None
    means all of them. Extra ``kwargs`` are accepted but not used here.

    Return a list of ``(connection, db_name, is_first_alias)`` tuples, where
    ``db_name`` is the database NAME recorded before creation; this matches
    the ``old_config`` shape consumed by teardown_databases().
    """
    if time_keeper is None:
        time_keeper = NullTimeKeeper()

    test_databases, mirrored_aliases = get_unique_databases_and_mirrors(aliases)

    old_names = []

    # NOTE: the loop variable ``aliases`` shadows the parameter of the same
    # name, which has already been consumed above.
    for db_name, aliases in test_databases.values():
        first_alias = None
        for alias in aliases:
            connection = connections[alias]
            # Only the first alias of each group is flagged for destruction.
            old_names.append((connection, db_name, first_alias is None))

            # Actually create the database for the first connection
            if first_alias is None:
                first_alias = alias
                with time_keeper.timed(" Creating '%s'" % alias):
                    # RemovedInDjango50Warning: when the deprecation ends,
                    # replace with:
                    # serialize_alias = (
                    #     serialized_aliases is None
                    #     or alias in serialized_aliases
                    # )
                    try:
                        serialize_alias = connection.settings_dict["TEST"]["SERIALIZE"]
                    except KeyError:
                        serialize_alias = (
                            serialized_aliases is None or alias in serialized_aliases
                        )
                    else:
                        # An explicit TEST["SERIALIZE"] still wins, but warns.
                        warnings.warn(
                            "The SERIALIZE test database setting is "
                            "deprecated as it can be inferred from the "
                            "TestCase/TransactionTestCase.databases that "
                            "enable the serialized_rollback feature.",
                            category=RemovedInDjango50Warning,
                        )
                    connection.creation.create_test_db(
                        verbosity=verbosity,
                        autoclobber=not interactive,
                        keepdb=keepdb,
                        serialize=serialize_alias,
                    )
                # Clones use suffixes "1".."parallel" (see teardown_databases).
                if parallel > 1:
                    for index in range(parallel):
                        with time_keeper.timed(" Cloning '%s'" % alias):
                            connection.creation.clone_test_db(
                                suffix=str(index + 1),
                                verbosity=verbosity,
                                keepdb=keepdb,
                            )
            # Configure all other connections as mirrors of the first one
            else:
                connections[alias].creation.set_as_test_mirror(
                    connections[first_alias].settings_dict
                )

    # Configure the test mirrors.
    for alias, mirror_alias in mirrored_aliases.items():
        connections[alias].creation.set_as_test_mirror(
            connections[mirror_alias].settings_dict
        )

    # Enable query logging on every connection, not just the created ones.
    if debug_sql:
        for alias in connections:
            connections[alias].force_debug_cursor = True

    return old_names
-
-
def iter_test_cases(tests):
    """
    Yield the individual unittest.TestCase objects contained in ``tests``.

    ``tests`` may be a test suite or any iterable whose items are TestCase
    instances or further (nested) iterables of them. Nesting is flattened
    depth-first, preserving order. Raise TypeError if a string is found.
    """
    for item in tests:
        # Strings are iterable too; recursing into one would end in an
        # unhelpful RecursionError, so reject them up front.
        if isinstance(item, str):
            raise TypeError(
                f"Test {item!r} must be a test case or test suite not string "
                f"(was found in {tests!r})."
            )
        if not isinstance(item, TestCase):
            # Anything else is assumed to be a (nested) test suite.
            yield from iter_test_cases(item)
            continue
        yield item
-
-
def dependency_ordered(test_databases, dependencies):
    """
    Reorder test_databases into an order that honors the dependencies
    described in TEST[DEPENDENCIES].

    ``test_databases`` is an iterable of ``(signature, (db_name, aliases))``
    pairs and ``dependencies`` maps an alias to the aliases it depends on.
    A database becomes ready once every alias its own aliases depend on has
    been resolved. Ready databases are emitted in input order, and each pass
    may resolve several. Raise ImproperlyConfigured if a database depends on
    one of its own aliases, or if no progress can be made.
    """
    # Signature -> union of the dependencies declared by all of its aliases.
    deps_by_signature = {}
    for signature, (_, aliases) in test_databases:
        combined = {dep for alias in aliases for dep in dependencies.get(alias, [])}
        # A database can't depend on an alias that points to itself.
        if not combined.isdisjoint(aliases):
            raise ImproperlyConfigured(
                "Circular dependency: databases %r depend on each other, "
                "but are aliases." % aliases
            )
        deps_by_signature[signature] = combined

    ordered = []
    resolved = set()
    pending = test_databases
    while pending:
        progressed = False
        still_waiting = []
        for signature, (db_name, aliases) in pending:
            if not deps_by_signature[signature] <= resolved:
                still_waiting.append((signature, (db_name, aliases)))
                continue
            # Resolved immediately, so later entries in this pass may use it.
            resolved.update(aliases)
            ordered.append((signature, (db_name, aliases)))
            progressed = True
        if not progressed:
            raise ImproperlyConfigured("Circular dependency in TEST[DEPENDENCIES]")
        pending = still_waiting
    return ordered
-
-
def get_unique_databases_and_mirrors(aliases=None):
    """
    Figure out which databases actually need to be created.

    Deduplicate entries in DATABASES that correspond to the same database or
    are configured as test mirrors. ``aliases`` limits which non-mirror
    aliases are considered and defaults to every configured connection.

    Return two values:
    - test_databases: ordered mapping of signatures to (name, list of aliases)
      where all aliases share the same underlying database, ordered by
      dependency_ordered() so that dependencies come first.
    - mirrored_aliases: mapping of mirror aliases to original aliases.
    """
    if aliases is None:
        aliases = connections
    default_sig = connections[DEFAULT_DB_ALIAS].creation.test_db_signature()
    mirrored_aliases = {}
    test_databases = {}
    dependencies = {}

    for alias in connections:
        connection = connections[alias]
        test_settings = connection.settings_dict["TEST"]
        mirror_of = test_settings["MIRROR"]

        if mirror_of:
            # A test mirror never gets its own database; remember its target.
            mirrored_aliases[alias] = mirror_of
            continue
        if alias not in aliases:
            continue

        # Aliases sharing a signature share one test database, which then
        # only needs to be created once.
        signature = connection.creation.test_db_signature()
        _, shared_aliases = test_databases.setdefault(
            signature, (connection.settings_dict["NAME"], [])
        )
        # The default database must be the first because data migrations
        # use the default alias by default.
        if alias == DEFAULT_DB_ALIAS:
            shared_aliases.insert(0, alias)
        else:
            shared_aliases.append(alias)

        if "DEPENDENCIES" in test_settings:
            dependencies[alias] = test_settings["DEPENDENCIES"]
        elif alias != DEFAULT_DB_ALIAS and signature != default_sig:
            # Without an explicit setting, a distinct non-default database
            # depends on the default one.
            dependencies[alias] = [DEFAULT_DB_ALIAS]

    ordered = dependency_ordered(test_databases.items(), dependencies)
    return dict(ordered), mirrored_aliases
-
-
def teardown_databases(old_config, verbosity, parallel=0, keepdb=False):
    """
    Destroy all the non-mirror databases.

    ``old_config`` is a sequence of ``(connection, old_name, destroy)``
    tuples; entries whose ``destroy`` flag is false are skipped. When
    ``parallel > 1``, the clones with suffixes "1".."parallel" are destroyed
    before the main test database.
    """
    for connection, old_name, destroy in old_config:
        if not destroy:
            continue
        creation = connection.creation
        if parallel > 1:
            for suffix in range(1, parallel + 1):
                creation.destroy_test_db(
                    suffix=str(suffix),
                    verbosity=verbosity,
                    keepdb=keepdb,
                )
        creation.destroy_test_db(old_name, verbosity, keepdb)
-
-
def get_runner(settings, test_runner_class=None):
    """
    Return the test runner class named by a dotted import path.

    ``test_runner_class`` defaults to ``settings.TEST_RUNNER``. The last
    dotted component is the class name and everything before it is the module
    to import; a bare name (no dot) is looked up via the module ".".
    """
    test_runner_class = test_runner_class or settings.TEST_RUNNER
    test_path = test_runner_class.split(".")
    class_name = test_path[-1]
    # Allow for relative paths
    if len(test_path) > 1:
        test_module_name = ".".join(test_path[:-1])
    else:
        test_module_name = "."
    # ``fromlist`` must be a sequence of names. A bare string would be iterated
    # character by character, and for packages the import system would then
    # try to import one-letter submodules.
    test_module = __import__(test_module_name, {}, {}, [class_name])
    return getattr(test_module, class_name)
-
-
class TestContextDecorator:
    """
    Base class for helpers that apply a temporary change and can be used as a
    context manager, as a function decorator, or as a decorator for
    unittest.TestCase subclasses.

    Subclasses implement enable() (apply the change and optionally return a
    value) and disable() (undo it).

    `attr_name`: when decorating a TestCase class, the value returned by
    enable() is set on each test instance under this attribute name.

    `kwarg_name`: when decorating a function, the value returned by enable()
    is passed to the function as this keyword argument.
    """

    def __init__(self, attr_name=None, kwarg_name=None):
        self.attr_name = attr_name
        self.kwarg_name = kwarg_name

    def enable(self):
        raise NotImplementedError

    def disable(self):
        raise NotImplementedError

    def __enter__(self):
        return self.enable()

    def __exit__(self, exc_type, exc_value, traceback):
        self.disable()

    def decorate_class(self, cls):
        """Wrap ``cls.setUp`` so each test runs with the change enabled."""
        if not issubclass(cls, TestCase):
            raise TypeError("Can only decorate subclasses of unittest.TestCase")

        original_setUp = cls.setUp

        def setUp(test_instance):
            value = self.enable()
            # Registered before the original setUp runs, so the change is
            # undone even when that setUp fails.
            test_instance.addCleanup(self.disable)
            if self.attr_name:
                setattr(test_instance, self.attr_name, value)
            original_setUp(test_instance)

        cls.setUp = setUp
        return cls

    def decorate_callable(self, func):
        """Return a wrapper that runs ``func`` inside ``with self``."""
        if not asyncio.iscoroutinefunction(func):

            @wraps(func)
            def inner(*args, **kwargs):
                with self as context:
                    if self.kwarg_name:
                        kwargs[self.kwarg_name] = context
                    return func(*args, **kwargs)

            return inner

        # Coroutine functions need an async wrapper so the ``with`` block
        # spans the awaited execution rather than just coroutine creation.
        @wraps(func)
        async def inner(*args, **kwargs):
            with self as context:
                if self.kwarg_name:
                    kwargs[self.kwarg_name] = context
                return await func(*args, **kwargs)

        return inner

    def __call__(self, decorated):
        if isinstance(decorated, type):
            return self.decorate_class(decorated)
        if not callable(decorated):
            raise TypeError("Cannot decorate object of type %s" % type(decorated))
        return self.decorate_callable(decorated)
-
-
class override_settings(TestContextDecorator):
    """
    Act as either a decorator or a context manager. If it's a decorator, take a
    function and return a wrapped function. If it's a contextmanager, use it
    with the ``with`` statement. In either event, entering/exiting are called
    before and after, respectively, the function/block is executed.

    Keyword arguments map setting names to their temporary values. Receivers
    of ``setting_changed`` are notified with ``enter=True`` when the override
    is applied and ``enter=False`` when it is removed.
    """

    # Exception raised by a setting_changed receiver during enable(); kept so
    # that disable() can re-raise it after the original settings are restored.
    enable_exception = None

    def __init__(self, **kwargs):
        # Mapping of setting name -> overriding value.
        self.options = kwargs
        super().__init__()

    def enable(self):
        # Keep this code at the beginning to leave the settings unchanged
        # in case it raises an exception because INSTALLED_APPS is invalid.
        if "INSTALLED_APPS" in self.options:
            try:
                apps.set_installed_apps(self.options["INSTALLED_APPS"])
            except Exception:
                apps.unset_installed_apps()
                raise
        # Layer the overrides on top of the currently active settings object
        # and remember that object so disable() can put it back.
        override = UserSettingsHolder(settings._wrapped)
        for key, new_value in self.options.items():
            setattr(override, key, new_value)
        self.wrapped = settings._wrapped
        settings._wrapped = override
        for key, new_value in self.options.items():
            try:
                setting_changed.send(
                    sender=settings._wrapped.__class__,
                    setting=key,
                    value=new_value,
                    enter=True,
                )
            except Exception as exc:
                # Roll back right away; disable() re-raises this exception,
                # so the loop does not continue past the failing receiver.
                self.enable_exception = exc
                self.disable()

    def disable(self):
        if "INSTALLED_APPS" in self.options:
            apps.unset_installed_apps()
        settings._wrapped = self.wrapped
        del self.wrapped
        # Notify receivers of the restored values. Responses are collected
        # rather than raised immediately so every setting gets announced;
        # exceptions among them are re-raised below.
        responses = []
        for key in self.options:
            new_value = getattr(settings, key, None)
            responses_for_setting = setting_changed.send_robust(
                sender=settings._wrapped.__class__,
                setting=key,
                value=new_value,
                enter=False,
            )
            responses.extend(responses_for_setting)
        # An exception from enable() takes priority over receiver errors here.
        if self.enable_exception is not None:
            exc = self.enable_exception
            self.enable_exception = None
            raise exc
        for _, response in responses:
            if isinstance(response, Exception):
                raise response

    def save_options(self, test_func):
        # NOTE(review): assumes the decorated class defines
        # _overridden_settings (None when nothing was overridden yet) —
        # presumably on SimpleTestCase; confirm.
        if test_func._overridden_settings is None:
            test_func._overridden_settings = self.options
        else:
            # Duplicate dict to prevent subclasses from altering their parent.
            test_func._overridden_settings = {
                **test_func._overridden_settings,
                **self.options,
            }

    def decorate_class(self, cls):
        # Class decoration only records the options on the class; applying
        # them is left to the test class itself (presumably
        # SimpleTestCase.setUpClass — confirm).
        # Imported locally, presumably to avoid a circular import; confirm.
        from django.test import SimpleTestCase

        if not issubclass(cls, SimpleTestCase):
            raise ValueError(
                "Only subclasses of Django SimpleTestCase can be decorated "
                "with override_settings"
            )
        self.save_options(cls)
        return cls
-
-
class modify_settings(override_settings):
    """
    Like override_settings, but makes it possible to append, prepend, or remove
    items instead of redefining the entire list.

    Each keyword argument names a list-like setting and maps to a dict whose
    keys are "append", "prepend" or "remove" and whose values are a single
    string or an iterable of items. Items already present are not added again.
    """

    def __init__(self, *args, **kwargs):
        if not args:
            self.operations = list(kwargs.items())
        else:
            # Hack used when instantiating from SimpleTestCase.setUpClass:
            # the (name, operations) pairs are passed positionally.
            assert not kwargs
            self.operations = args[0]
        # Deliberately bypass override_settings.__init__; self.options is
        # computed later, in enable().
        super(override_settings, self).__init__()

    def save_options(self, test_func):
        inherited = test_func._modified_settings
        if inherited is None:
            test_func._modified_settings = self.operations
            return
        # Duplicate list to prevent subclasses from altering their parent.
        test_func._modified_settings = list(inherited) + self.operations

    @staticmethod
    def _apply_action(value, action, items):
        """Return a new list: ``value`` with one append/prepend/remove applied."""
        # items may be a single value or an iterable.
        if isinstance(items, str):
            items = [items]
        if action == "append":
            return value + [item for item in items if item not in value]
        if action == "prepend":
            return [item for item in items if item not in value] + value
        if action == "remove":
            return [item for item in value if item not in items]
        raise ValueError("Unsupported action: %s" % action)

    def enable(self):
        self.options = {}
        for name, operations in self.operations:
            # When called from SimpleTestCase.setUpClass, values may be
            # overridden several times; cumulate changes.
            if name in self.options:
                value = self.options[name]
            else:
                value = list(getattr(settings, name, []))
            for action, items in operations.items():
                value = self._apply_action(value, action, items)
            self.options[name] = value
        super().enable()
-
-
class override_system_checks(TestContextDecorator):
    """
    Act as a decorator. Override list of registered system checks.
    Useful when you override `INSTALLED_APPS`, e.g. if you exclude `auth` app,
    you also need to exclude its system checks.

    ``deployment_checks``, when not None, likewise replaces the registered
    deployment checks; otherwise those are left untouched.
    """

    def __init__(self, new_checks, deployment_checks=None):
        from django.core.checks.registry import registry

        self.registry = registry
        self.new_checks = new_checks
        self.deployment_checks = deployment_checks
        super().__init__()

    def enable(self):
        registry = self.registry
        self.old_checks = registry.registered_checks
        registry.registered_checks = set()
        for check in self.new_checks:
            registry.register(check, *getattr(check, "tags", ()))
        # Always remember the deployment checks so disable() can restore them.
        self.old_deployment_checks = registry.deployment_checks
        if self.deployment_checks is None:
            return
        registry.deployment_checks = set()
        for check in self.deployment_checks:
            registry.register(check, *getattr(check, "tags", ()), deploy=True)

    def disable(self):
        self.registry.registered_checks = self.old_checks
        self.registry.deployment_checks = self.old_deployment_checks
-
-
def compare_xml(want, got):
    """
    Compare two XML strings structurally rather than textually.

    Plain string comparison doesn't always work because, for example,
    attribute ordering should not be important. Comment nodes, processing
    instructions, the document type node, and leading/trailing whitespace are
    ignored. Within text, runs of two or more whitespace characters collapse
    to a single space. Literal "\\n" sequences are turned into newlines
    first. When ``want`` has no XML declaration, both strings are wrapped in
    a root element, so fragments such as "<foo/><bar/>" can be compared.

    Based on https://github.com/lxml/lxml/blob/master/src/lxml/doctestcompare.py
    """
    whitespace_run = re.compile(r"[ \t\n][ \t\n]+")
    skipped_node_types = (
        Node.COMMENT_NODE,
        Node.DOCUMENT_TYPE_NODE,
        Node.PROCESSING_INSTRUCTION_NODE,
    )

    def norm_child_text(element):
        # Direct text children only, with whitespace runs normalized.
        text = "".join(
            node.data for node in element.childNodes if node.nodeType == Node.TEXT_NODE
        )
        return whitespace_run.sub(" ", text)

    def element_children(element):
        return [node for node in element.childNodes if node.nodeType == Node.ELEMENT_NODE]

    def same_element(expected, actual):
        if expected.tagName != actual.tagName:
            return False
        if norm_child_text(expected) != norm_child_text(actual):
            return False
        if dict(expected.attributes.items()) != dict(actual.attributes.items()):
            return False
        expected_kids = element_children(expected)
        actual_kids = element_children(actual)
        return len(expected_kids) == len(actual_kids) and all(
            map(same_element, expected_kids, actual_kids)
        )

    def root_element(document):
        # First top-level node that isn't a comment, doctype or processing
        # instruction; None if there is none.
        return next(
            (
                node
                for node in document.childNodes
                if node.nodeType not in skipped_node_types
            ),
            None,
        )

    want = want.strip().replace("\\n", "\n")
    got = got.strip().replace("\\n", "\n")

    if not want.startswith("<?xml"):
        wrapper = "<root>%s</root>"
        want, got = wrapper % want, wrapper % got

    return same_element(root_element(parseString(want)), root_element(parseString(got)))
-
-
class CaptureQueriesContext:
    """
    Context manager that records the queries run on ``connection`` while the
    ``with`` block is active.

    The instance behaves like a read-only sequence of the captured query
    entries (iteration, indexing, len()). If the block raises, the upper
    bound is left open, so every query logged after entry is included.
    """

    def __init__(self, connection):
        self.connection = connection

    @property
    def captured_queries(self):
        window = slice(self.initial_queries, self.final_queries)
        return self.connection.queries[window]

    def __iter__(self):
        return iter(self.captured_queries)

    def __getitem__(self, index):
        return self.captured_queries[index]

    def __len__(self):
        return len(self.captured_queries)

    def __enter__(self):
        connection = self.connection
        self.force_debug_cursor = connection.force_debug_cursor
        connection.force_debug_cursor = True
        # Run any initialization queries if needed so that they won't be
        # included as part of the count.
        connection.ensure_connection()
        self.initial_queries = len(connection.queries_log)
        self.final_queries = None
        # Keep request_started from clearing the query log mid-capture.
        request_started.disconnect(reset_queries)
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.connection.force_debug_cursor = self.force_debug_cursor
        request_started.connect(reset_queries)
        if exc_type is None:
            self.final_queries = len(self.connection.queries_log)
-
-
class ignore_warnings(TestContextDecorator):
    """
    Silence warnings while active.

    Keyword arguments are passed, together with the action "ignore", to
    warnings.filterwarnings() when they include "message" or "module", and to
    warnings.simplefilter() otherwise. The warning filters in place before
    enable() are restored by disable().
    """

    def __init__(self, **kwargs):
        self.ignore_kwargs = kwargs
        # Only filterwarnings() accepts message/module filters.
        uses_filterwarnings = "message" in kwargs or "module" in kwargs
        self.filter_func = (
            warnings.filterwarnings if uses_filterwarnings else warnings.simplefilter
        )
        super().__init__()

    def enable(self):
        self.catch_warnings = warnings.catch_warnings()
        self.catch_warnings.__enter__()
        self.filter_func("ignore", **self.ignore_kwargs)

    def disable(self):
        self.catch_warnings.__exit__(*sys.exc_info())
-
-
# On platforms without time.tzset() (e.g. Windows), the time zone the process
# runs in can't be changed at runtime. Tests that don't pin a specific time
# zone (with timezone.override() or equivalent), or that interpret naive
# datetimes in the default time zone, must therefore be skipped there.
# Usable as a decorator on test functions or classes.
requires_tz_support = skipUnless(
    TZ_SUPPORT,
    "This test relies on the ability to run a program in an arbitrary "
    "time zone, but your operating system isn't able to do that.",
)
-
-
@contextmanager
def extend_sys_path(*paths):
    """
    Temporarily append ``paths`` to sys.path.

    On exit, sys.path is rebound to a copy of the list taken on entry.
    """
    saved_path = list(sys.path)
    sys.path += paths
    try:
        yield
    finally:
        sys.path = saved_path
-
-
@contextmanager
def isolate_lru_cache(lru_cache_object):
    """Empty an lru_cache-wrapped function's cache on entry and again on exit."""
    clear = lru_cache_object.cache_clear
    clear()
    try:
        yield
    finally:
        clear()
-
-
@contextmanager
def captured_output(stream_name):
    """
    Temporarily replace ``sys.<stream_name>`` with a fresh StringIO and yield
    it; the original stream is put back on exit.

    Note: This function and the following ``captured_std*`` are copied
    from CPython's ``test.support`` module.
    """
    original = getattr(sys, stream_name)
    replacement = StringIO()
    setattr(sys, stream_name, replacement)
    try:
        yield replacement
    finally:
        setattr(sys, stream_name, original)
-
-
def captured_stdout():
    """
    Capture the output of sys.stdout::

        with captured_stdout() as stdout:
            print("hello")
        self.assertEqual(stdout.getvalue(), "hello\\n")
    """
    return captured_output(stream_name="stdout")
-
-
def captured_stderr():
    """
    Capture the output of sys.stderr::

        with captured_stderr() as stderr:
            print("hello", file=sys.stderr)
        self.assertEqual(stderr.getvalue(), "hello\\n")
    """
    return captured_output(stream_name="stderr")
-
-
def captured_stdin():
    """
    Capture the input to sys.stdin::

        with captured_stdin() as stdin:
            stdin.write("hello\\n")
            stdin.seek(0)
            # call test code that consumes from sys.stdin
            captured = input()
        self.assertEqual(captured, "hello")
    """
    return captured_output(stream_name="stdin")
-
-
@contextmanager
def freeze_time(t):
    """
    Context manager to temporarily freeze time.time() at ``t``.

    This patches the ``time`` attribute of the time module, so modules that
    imported the function directly (e.g. ``from time import time``) won't be
    affected. This isn't meant as a public API, but helps reduce some
    repetitive code in Django's test suite.
    """
    real_time = time.time

    def frozen():
        return t

    time.time = frozen
    try:
        yield
    finally:
        time.time = real_time
-
-
def require_jinja2(test_func):
    """
    Decorator to enable a Jinja2 template engine in addition to the regular
    Django template engine for a test, or skip it if Jinja2 isn't available.
    """
    skip_without_jinja2 = skipIf(jinja2 is None, "this test requires jinja2")
    skippable = skip_without_jinja2(test_func)

    django_engine = {
        "BACKEND": "django.template.backends.django.DjangoTemplates",
        "APP_DIRS": True,
    }
    jinja2_engine = {
        "BACKEND": "django.template.backends.jinja2.Jinja2",
        "APP_DIRS": True,
        "OPTIONS": {"keep_trailing_newline": True},
    }
    with_engines = override_settings(TEMPLATES=[django_engine, jinja2_engine])
    return with_engines(skippable)
-
-
class override_script_prefix(TestContextDecorator):
    """Decorator or context manager to temporarily override the script prefix."""

    def __init__(self, prefix):
        self.prefix = prefix
        super().__init__()

    def enable(self):
        # Remember the active prefix so disable() can reinstate it.
        previous = get_script_prefix()
        self.old_prefix = previous
        set_script_prefix(self.prefix)

    def disable(self):
        set_script_prefix(self.old_prefix)
-
-
class LoggingCaptureMixin:
    """
    Capture the output from the 'django' logger and store it on the class's
    logger_output attribute.

    setUp() redirects the stream of the logger's first handler to a fresh
    StringIO; tearDown() restores the previous stream.
    """

    def setUp(self):
        self.logger = logging.getLogger("django")
        handler = self.logger.handlers[0]
        self.old_stream = handler.stream
        self.logger_output = StringIO()
        handler.stream = self.logger_output

    def tearDown(self):
        first_handler = self.logger.handlers[0]
        first_handler.stream = self.old_stream
-
-
class isolate_apps(TestContextDecorator):
    """
    Act as either a decorator or a context manager to register models defined
    in its wrapped context to an isolated registry.

    The list of installed apps the isolated registry should contain must be
    passed as arguments.

    Two optional keyword arguments can be specified:

    `attr_name`: attribute assigned the isolated registry if used as a class
    decorator.

    `kwarg_name`: keyword argument passing the isolated registry if used as a
    function decorator.
    """

    def __init__(self, *installed_apps, **kwargs):
        self.installed_apps = installed_apps
        super().__init__(**kwargs)

    def enable(self):
        # Swap Options.default_apps for a fresh registry; return it so it can
        # be exposed via attr_name/kwarg_name.
        self.old_apps = Options.default_apps
        isolated_registry = Apps(self.installed_apps)
        Options.default_apps = isolated_registry
        return isolated_registry

    def disable(self):
        Options.default_apps = self.old_apps
-
-
class TimeKeeper:
    """Collect wall-clock durations of named steps and report them on stderr."""

    def __init__(self):
        # name -> list of elapsed times (seconds), one per timed() use.
        self.records = collections.defaultdict(list)

    @contextmanager
    def timed(self, name):
        """Time the enclosed block and append its duration under ``name``."""
        # Create the entry up front so report order follows start order.
        durations = self.records[name]
        started = time.perf_counter()
        try:
            yield
        finally:
            durations.append(time.perf_counter() - started)

    def print_results(self):
        """Write one "<name> took <seconds>s" line per recorded duration."""
        for name, durations in self.records.items():
            for elapsed in durations:
                sys.stderr.write("%s took %.3fs" % (name, elapsed) + os.linesep)
-
-
class NullTimeKeeper:
    """Drop-in stand-in for TimeKeeper that measures and reports nothing."""

    @contextmanager
    def timed(self, name):
        yield None

    def print_results(self):
        return None
-
-
def tag(*tags):
    """
    Decorator to add tags to a test class or method.

    Tags are merged into a new set, so tags inherited from a parent class are
    extended on the decorated object without mutating the parent's set.
    """

    def decorator(obj):
        obj.tags = obj.tags.union(tags) if hasattr(obj, "tags") else set(tags)
        return obj

    return decorator
-
-
@contextmanager
def register_lookup(field, *lookups, lookup_name=None):
    """
    Context manager to temporarily register lookups on a model field using
    lookup_name (or the lookup's lookup_name if not provided).

    Only lookups that were registered successfully are unregistered on exit.
    If registering one of them fails, that error therefore propagates unmasked
    instead of being replaced by a failure to unregister a lookup that was
    never added.
    """
    registered = []
    try:
        for lookup in lookups:
            field.register_lookup(lookup, lookup_name)
            # Track successes so cleanup undoes exactly what was done.
            registered.append(lookup)
        yield
    finally:
        for lookup in registered:
            field._unregister_lookup(lookup, lookup_name)
|