Funktionierender Prototyp des Serious Games zur Vermittlung von Wissen zu Software-Engineering-Arbeitsmodellen.
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

constraints.py 15KB

1 year ago

  1. from enum import Enum
  2. from django.core.exceptions import FieldError, ValidationError
  3. from django.db import connections
  4. from django.db.models.expressions import Exists, ExpressionList, F, OrderBy
  5. from django.db.models.indexes import IndexExpression
  6. from django.db.models.lookups import Exact
  7. from django.db.models.query_utils import Q
  8. from django.db.models.sql.query import Query
  9. from django.db.utils import DEFAULT_DB_ALIAS
  10. from django.utils.translation import gettext_lazy as _
  # Public API of this module: the names exported by ``from ... import *``.
  __all__ = [
      "BaseConstraint",
      "CheckConstraint",
      "Deferrable",
      "UniqueConstraint",
  ]
  class BaseConstraint:
      """Base class for model constraints.

      Stores the constraint ``name`` and the message reported when
      validation fails. Subclasses supply the SQL hooks and ``validate()``.
      """

      default_violation_error_message = _("Constraint “%(name)s” is violated.")
      violation_error_message = None

      def __init__(self, name, violation_error_message=None):
          self.name = name
          # Fall back to the class-level default unless a message is supplied.
          self.violation_error_message = (
              self.default_violation_error_message
              if violation_error_message is None
              else violation_error_message
          )

      @property
      def contains_expressions(self):
          """Return whether the constraint is built from expressions (False here)."""
          return False

      def constraint_sql(self, model, schema_editor):
          """Subclass hook: return this constraint's SQL."""
          raise NotImplementedError("This method must be implemented by a subclass.")

      def create_sql(self, model, schema_editor):
          """Subclass hook: return SQL for creating this constraint."""
          raise NotImplementedError("This method must be implemented by a subclass.")

      def remove_sql(self, model, schema_editor):
          """Subclass hook: return SQL for removing this constraint."""
          raise NotImplementedError("This method must be implemented by a subclass.")

      def validate(self, model, instance, exclude=None, using=DEFAULT_DB_ALIAS):
          """Subclass hook: raise ValidationError if ``instance`` violates the constraint."""
          raise NotImplementedError("This method must be implemented by a subclass.")

      def get_violation_error_message(self):
          """Return the violation message with ``%(name)s`` filled in."""
          message = self.violation_error_message
          return message % {"name": self.name}

      def deconstruct(self):
          """Return ``(path, args, kwargs)`` describing how to rebuild this constraint.

          Classes living in ``django.db.models.constraints`` are reported under
          the shorter ``django.db.models`` path. ``violation_error_message`` is
          only included when it differs from the class default.
          """
          cls = self.__class__
          path = f"{cls.__module__}.{cls.__name__}".replace(
              "django.db.models.constraints", "django.db.models"
          )
          kwargs = {"name": self.name}
          message = self.violation_error_message
          if message is not None and message != self.default_violation_error_message:
              kwargs["violation_error_message"] = message
          return path, (), kwargs

      def clone(self):
          """Return a new constraint of the same class built from ``deconstruct()``."""
          _path, args, kwargs = self.deconstruct()
          return type(self)(*args, **kwargs)
  class CheckConstraint(BaseConstraint):
      """Constraint requiring every row to satisfy a boolean condition.

      ``check`` must be a ``Q`` object or any expression whose ``conditional``
      attribute is truthy; otherwise ``TypeError`` is raised.
      """

      def __init__(self, *, check, name, violation_error_message=None):
          self.check = check
          is_conditional = getattr(check, "conditional", False)
          if not is_conditional:
              raise TypeError(
                  "CheckConstraint.check must be a Q instance or boolean expression."
              )
          super().__init__(name, violation_error_message=violation_error_message)

      def _get_check_sql(self, model, schema_editor):
          """Compile ``check`` to SQL with params inlined via ``quote_value``."""
          connection = schema_editor.connection
          query = Query(model=model, alias_cols=False)
          where = query.build_where(self.check)
          compiler = query.get_compiler(connection=connection)
          sql, params = where.as_sql(compiler, connection)
          quoted_params = tuple(map(schema_editor.quote_value, params))
          return sql % quoted_params

      def constraint_sql(self, model, schema_editor):
          """Return the check SQL wrapped by ``schema_editor._check_sql``."""
          return schema_editor._check_sql(
              self.name, self._get_check_sql(model, schema_editor)
          )

      def create_sql(self, model, schema_editor):
          """Return the statement produced by ``schema_editor._create_check_sql``."""
          return schema_editor._create_check_sql(
              model, self.name, self._get_check_sql(model, schema_editor)
          )

      def remove_sql(self, model, schema_editor):
          """Return the statement produced by ``schema_editor._delete_check_sql``."""
          return schema_editor._delete_check_sql(model, self.name)

      def validate(self, model, instance, exclude=None, using=DEFAULT_DB_ALIAS):
          """Raise ValidationError when ``instance``'s field values fail ``check``.

          A FieldError while evaluating the check is swallowed and the check
          is skipped (presumably because it references a field missing from
          the value map, e.g. one listed in ``exclude``).
          """
          against = instance._get_field_value_map(meta=model._meta, exclude=exclude)
          try:
              satisfied = Q(self.check).check(against, using=using)
          except FieldError:
              return
          if not satisfied:
              raise ValidationError(self.get_violation_error_message())

      def __repr__(self):
          return f"<{type(self).__qualname__}: check={self.check} name={self.name!r}>"

      def __eq__(self, other):
          if not isinstance(other, CheckConstraint):
              return super().__eq__(other)
          return (
              self.name == other.name
              and self.check == other.check
              and self.violation_error_message == other.violation_error_message
          )

      def deconstruct(self):
          """Extend the base deconstruction with the ``check`` keyword."""
          path, args, kwargs = super().deconstruct()
          return path, args, {**kwargs, "check": self.check}
  class Deferrable(Enum):
      """Deferral mode accepted by ``UniqueConstraint(deferrable=...)``."""

      DEFERRED = "deferred"
      IMMEDIATE = "immediate"

      # A similar format was proposed for Python 3.10.
      def __repr__(self):
          # Render compactly as e.g. ``Deferrable.DEFERRED`` instead of Enum's
          # default ``<Deferrable.DEFERRED: 'deferred'>``.
          return "%s.%s" % (type(self).__qualname__, self.name)
  class UniqueConstraint(BaseConstraint):
      """Constraint requiring a set of fields or expressions to be unique.

      Exactly one of positional ``expressions`` or ``fields`` must be given;
      bare strings among ``expressions`` are wrapped in ``F()``. ``condition``
      (a ``Q``) limits which rows are covered. ``deferrable``, ``include``
      and ``opclasses`` are passed through to the schema editor.
      ``deferrable`` cannot be combined with ``condition``, ``include``,
      ``opclasses`` or expressions, and ``opclasses`` cannot be combined with
      expressions.
      """

      def __init__(
          self,
          *expressions,
          fields=(),
          name=None,
          condition=None,
          deferrable=None,
          include=None,
          opclasses=(),
          violation_error_message=None,
      ):
          """Validate the arguments and store them.

          Raises:
              ValueError: if ``name`` is missing, neither or both of
                  ``fields``/``expressions`` are given, an argument has the
                  wrong type, options are combined incompatibly, or
                  ``opclasses`` and ``fields`` differ in length.
          """
          if not name:
              raise ValueError("A unique constraint must be named.")
          if not expressions and not fields:
              raise ValueError(
                  "At least one field or expression is required to define a "
                  "unique constraint."
              )
          if expressions and fields:
              raise ValueError(
                  "UniqueConstraint.fields and expressions are mutually exclusive."
              )
          if not isinstance(condition, (type(None), Q)):
              raise ValueError("UniqueConstraint.condition must be a Q instance.")
          if condition and deferrable:
              raise ValueError("UniqueConstraint with conditions cannot be deferred.")
          if include and deferrable:
              raise ValueError("UniqueConstraint with include fields cannot be deferred.")
          if opclasses and deferrable:
              raise ValueError("UniqueConstraint with opclasses cannot be deferred.")
          if expressions and deferrable:
              raise ValueError("UniqueConstraint with expressions cannot be deferred.")
          if expressions and opclasses:
              raise ValueError(
                  "UniqueConstraint.opclasses cannot be used with expressions. "
                  "Use django.contrib.postgres.indexes.OpClass() instead."
              )
          if not isinstance(deferrable, (type(None), Deferrable)):
              raise ValueError(
                  "UniqueConstraint.deferrable must be a Deferrable instance."
              )
          if not isinstance(include, (type(None), list, tuple)):
              raise ValueError("UniqueConstraint.include must be a list or tuple.")
          if not isinstance(opclasses, (list, tuple)):
              raise ValueError("UniqueConstraint.opclasses must be a list or tuple.")
          if opclasses and len(fields) != len(opclasses):
              raise ValueError(
                  "UniqueConstraint.fields and UniqueConstraint.opclasses must "
                  "have the same number of elements."
              )
          self.fields = tuple(fields)
          self.condition = condition
          self.deferrable = deferrable
          self.include = tuple(include) if include else ()
          self.opclasses = opclasses
          # Bare strings are shorthand for field references.
          self.expressions = tuple(
              F(expression) if isinstance(expression, str) else expression
              for expression in expressions
          )
          super().__init__(name, violation_error_message=violation_error_message)

      @property
      def contains_expressions(self):
          """Return True when the constraint was defined with expressions."""
          return bool(self.expressions)

      def _get_condition_sql(self, model, schema_editor):
          """Compile ``condition`` to SQL with params inlined, or return None."""
          if self.condition is None:
              return None
          query = Query(model=model, alias_cols=False)
          where = query.build_where(self.condition)
          compiler = query.get_compiler(connection=schema_editor.connection)
          sql, params = where.as_sql(compiler, schema_editor.connection)
          return sql % tuple(schema_editor.quote_value(p) for p in params)

      def _get_index_expressions(self, model, schema_editor):
          """Return the expressions resolved for ``model``, or None without any.

          Each expression is wrapped in ``IndexExpression`` and given the
          connection's wrapper classes before the list is resolved.
          """
          if not self.expressions:
              return None
          index_expressions = []
          for expression in self.expressions:
              index_expression = IndexExpression(expression)
              index_expression.set_wrapper_classes(schema_editor.connection)
              index_expressions.append(index_expression)
          return ExpressionList(*index_expressions).resolve_expression(
              Query(model, alias_cols=False),
          )

      def constraint_sql(self, model, schema_editor):
          """Return the SQL produced by ``schema_editor._unique_sql``."""
          fields = [model._meta.get_field(field_name) for field_name in self.fields]
          include = [
              model._meta.get_field(field_name).column for field_name in self.include
          ]
          condition = self._get_condition_sql(model, schema_editor)
          expressions = self._get_index_expressions(model, schema_editor)
          return schema_editor._unique_sql(
              model,
              fields,
              self.name,
              condition=condition,
              deferrable=self.deferrable,
              include=include,
              opclasses=self.opclasses,
              expressions=expressions,
          )

      def create_sql(self, model, schema_editor):
          """Return the statement produced by ``schema_editor._create_unique_sql``."""
          fields = [model._meta.get_field(field_name) for field_name in self.fields]
          include = [
              model._meta.get_field(field_name).column for field_name in self.include
          ]
          condition = self._get_condition_sql(model, schema_editor)
          expressions = self._get_index_expressions(model, schema_editor)
          return schema_editor._create_unique_sql(
              model,
              fields,
              self.name,
              condition=condition,
              deferrable=self.deferrable,
              include=include,
              opclasses=self.opclasses,
              expressions=expressions,
          )

      def remove_sql(self, model, schema_editor):
          """Return the statement produced by ``schema_editor._delete_unique_sql``."""
          condition = self._get_condition_sql(model, schema_editor)
          include = [
              model._meta.get_field(field_name).column for field_name in self.include
          ]
          expressions = self._get_index_expressions(model, schema_editor)
          return schema_editor._delete_unique_sql(
              model,
              self.name,
              condition=condition,
              deferrable=self.deferrable,
              include=include,
              opclasses=self.opclasses,
              expressions=expressions,
          )

      def __repr__(self):
          return "<%s:%s%s%s%s%s%s%s>" % (
              self.__class__.__qualname__,
              "" if not self.fields else " fields=%s" % repr(self.fields),
              "" if not self.expressions else " expressions=%s" % repr(self.expressions),
              " name=%s" % repr(self.name),
              "" if self.condition is None else " condition=%s" % self.condition,
              "" if self.deferrable is None else " deferrable=%r" % self.deferrable,
              "" if not self.include else " include=%s" % repr(self.include),
              "" if not self.opclasses else " opclasses=%s" % repr(self.opclasses),
          )

      def __eq__(self, other):
          if isinstance(other, UniqueConstraint):
              return (
                  self.name == other.name
                  and self.fields == other.fields
                  and self.condition == other.condition
                  and self.deferrable == other.deferrable
                  and self.include == other.include
                  and self.opclasses == other.opclasses
                  and self.expressions == other.expressions
                  and self.violation_error_message == other.violation_error_message
              )
          return super().__eq__(other)

      def deconstruct(self):
          """Return ``(path, expressions, kwargs)``; only non-empty options are kept.

          Expressions are returned as positional args to match
          ``__init__(*expressions, ...)``.
          """
          path, args, kwargs = super().deconstruct()
          if self.fields:
              kwargs["fields"] = self.fields
          if self.condition:
              kwargs["condition"] = self.condition
          if self.deferrable:
              kwargs["deferrable"] = self.deferrable
          if self.include:
              kwargs["include"] = self.include
          if self.opclasses:
              kwargs["opclasses"] = self.opclasses
          return path, self.expressions, kwargs

      def validate(self, model, instance, exclude=None, using=DEFAULT_DB_ALIAS):
          """Raise ValidationError if another row conflicts with ``instance``.

          Validation is skipped when a constrained field is in ``exclude``.
          For ``fields`` it is also skipped when any value is None, or "" on
          backends that interpret empty strings as NULL. If the instance is
          not being added and has a primary key, its own row is excluded from
          the lookup. Rows are filtered through ``model._default_manager``
          on the ``using`` database.
          """
          queryset = model._default_manager.using(using)
          if self.fields:
              lookup_kwargs = {}
              for field_name in self.fields:
                  if exclude and field_name in exclude:
                      return
                  field = model._meta.get_field(field_name)
                  lookup_value = getattr(instance, field.attname)
                  if lookup_value is None or (
                      lookup_value == ""
                      and connections[using].features.interprets_empty_strings_as_nulls
                  ):
                      # A composite constraint containing NULL value cannot cause
                      # a violation since NULL != NULL in SQL.
                      return
                  lookup_kwargs[field.name] = lookup_value
              queryset = queryset.filter(**lookup_kwargs)
          else:
              # Ignore constraints with excluded fields.
              if exclude:
                  for expression in self.expressions:
                      if hasattr(expression, "flatten"):
                          for expr in expression.flatten():
                              if isinstance(expr, F) and expr.name in exclude:
                                  return
                      elif isinstance(expression, F) and expression.name in exclude:
                          return
              replacement_map = instance._get_field_value_map(
                  meta=model._meta, exclude=exclude
              )
              filters = []
              for expr in self.expressions:
                  # Ignore ordering.
                  if isinstance(expr, OrderBy):
                      expr = expr.expression
                  # Match rows whose expression equals the value computed from
                  # the instance's own field values.
                  filters.append(Exact(expr, expr.replace_references(replacement_map)))
              queryset = queryset.filter(*filters)
          model_class_pk = instance._get_pk_val(model._meta)
          if not instance._state.adding and model_class_pk is not None:
              # An existing row must not conflict with itself.
              queryset = queryset.exclude(pk=model_class_pk)
          if not self.condition:
              if queryset.exists():
                  if self.expressions:
                      raise ValidationError(self.get_violation_error_message())
                  # When fields are defined, use unique_error_message() for
                  # backward compatibility, naming the model class that
                  # instance.get_constraints() pairs with this constraint.
                  # Fall back to ``model`` when the constraint is not found
                  # there, so the violation is still reported instead of
                  # silently passing.
                  error_model = model
                  for model_class, constraints in instance.get_constraints():
                      if any(constraint is self for constraint in constraints):
                          error_model = model_class
                          break
                  raise ValidationError(
                      instance.unique_error_message(error_model, self.fields)
                  )
          else:
              against = instance._get_field_value_map(meta=model._meta, exclude=exclude)
              try:
                  violated = (
                      self.condition & Exists(queryset.filter(self.condition))
                  ).check(against, using=using)
              except FieldError:
                  # Presumably the condition references a field absent from
                  # ``against`` (e.g. an excluded one); skip validation.
                  return
              if violated:
                  raise ValidationError(self.get_violation_error_message())