from enum import Enum

from django.core.exceptions import FieldError, ValidationError
from django.db import connections
from django.db.models.expressions import Exists, ExpressionList, F, OrderBy
from django.db.models.indexes import IndexExpression
from django.db.models.lookups import Exact
from django.db.models.query_utils import Q
from django.db.models.sql.query import Query
from django.db.utils import DEFAULT_DB_ALIAS
from django.utils.translation import gettext_lazy as _

__all__ = ["BaseConstraint", "CheckConstraint", "Deferrable", "UniqueConstraint"]


class BaseConstraint:
    """
    Base class for named model constraints.

    Subclasses must implement constraint_sql(), create_sql(), remove_sql() and
    validate(); the versions here only raise NotImplementedError.
    """

    # Real typographic quotes (U+201C / U+201D). This literal is also the
    # gettext msgid used for translation lookup, so it must be encoded
    # correctly for translations to be found.
    default_violation_error_message = _("Constraint “%(name)s” is violated.")
    violation_error_message = None

    def __init__(self, name, violation_error_message=None):
        """
        Store the constraint name and its violation message.

        When violation_error_message is None, the class-level
        default_violation_error_message is used instead.
        """
        self.name = name
        if violation_error_message is not None:
            self.violation_error_message = violation_error_message
        else:
            self.violation_error_message = self.default_violation_error_message

    @property
    def contains_expressions(self):
        """Return whether the constraint is built from expressions (never here)."""
        return False

    def constraint_sql(self, model, schema_editor):
        raise NotImplementedError("This method must be implemented by a subclass.")

    def create_sql(self, model, schema_editor):
        raise NotImplementedError("This method must be implemented by a subclass.")

    def remove_sql(self, model, schema_editor):
        raise NotImplementedError("This method must be implemented by a subclass.")

    def validate(self, model, instance, exclude=None, using=DEFAULT_DB_ALIAS):
        raise NotImplementedError("This method must be implemented by a subclass.")

    def get_violation_error_message(self):
        """Return the violation message with %(name)s replaced by self.name."""
        return self.violation_error_message % {"name": self.name}

    def deconstruct(self):
        """
        Return a (path, args, kwargs) triple that recreates this constraint.

        The import path is shortened from django.db.models.constraints to
        django.db.models. violation_error_message is only included when it
        differs from the class default, keeping the output minimal.
        """
        path = "%s.%s" % (self.__class__.__module__, self.__class__.__name__)
        path = path.replace("django.db.models.constraints", "django.db.models")
        kwargs = {"name": self.name}
        if (
            self.violation_error_message is not None
            and self.violation_error_message != self.default_violation_error_message
        ):
            kwargs["violation_error_message"] = self.violation_error_message
        return (path, (), kwargs)

    def clone(self):
        """Return a new, equal constraint rebuilt from deconstruct()."""
        _, args, kwargs = self.deconstruct()
        return self.__class__(*args, **kwargs)


class CheckConstraint(BaseConstraint):
    """A constraint that requires a boolean condition (`check`) to hold."""

    def __init__(self, *, check, name, violation_error_message=None):
        """
        Raises:
            TypeError: if `check` lacks a truthy `conditional` attribute,
                i.e. it is not a Q object or a boolean expression.
        """
        self.check = check
        if not getattr(check, "conditional", False):
            raise TypeError(
                "CheckConstraint.check must be a Q instance or boolean expression."
            )
        super().__init__(name, violation_error_message=violation_error_message)

    def _get_check_sql(self, model, schema_editor):
        """Compile self.check to SQL with parameters inlined via quote_value()."""
        query = Query(model=model, alias_cols=False)
        where = query.build_where(self.check)
        compiler = query.get_compiler(connection=schema_editor.connection)
        sql, params = where.as_sql(compiler, schema_editor.connection)
        return sql % tuple(schema_editor.quote_value(p) for p in params)

    def constraint_sql(self, model, schema_editor):
        check = self._get_check_sql(model, schema_editor)
        return schema_editor._check_sql(self.name, check)

    def create_sql(self, model, schema_editor):
        check = self._get_check_sql(model, schema_editor)
        return schema_editor._create_check_sql(model, self.name, check)

    def remove_sql(self, model, schema_editor):
        return schema_editor._delete_check_sql(model, self.name)

    def validate(self, model, instance, exclude=None, using=DEFAULT_DB_ALIAS):
        """
        Raise ValidationError if the instance's field values fail the check.

        A FieldError raised while evaluating the check is deliberately
        swallowed and the check is skipped (presumably because it references
        a field left out of the value map via `exclude` — confirm).
        """
        against = instance._get_field_value_map(meta=model._meta, exclude=exclude)
        try:
            if not Q(self.check).check(against, using=using):
                raise ValidationError(self.get_violation_error_message())
        except FieldError:
            pass

    def __repr__(self):
        return "<%s: check=%s name=%s>" % (
            self.__class__.__qualname__,
            self.check,
            repr(self.name),
        )

    def __eq__(self, other):
        if isinstance(other, CheckConstraint):
            return (
                self.name == other.name
                and self.check == other.check
                and self.violation_error_message == other.violation_error_message
            )
        return super().__eq__(other)

    def deconstruct(self):
        """Extend the base deconstruction with the `check` keyword argument."""
        path, args, kwargs = super().deconstruct()
        kwargs["check"] = self.check
        return path, args, kwargs


class Deferrable(Enum):
    """Deferral modes accepted by UniqueConstraint(deferrable=...)."""

    DEFERRED = "deferred"
    IMMEDIATE = "immediate"

    # A similar format was proposed for Python 3.10.
    def __repr__(self):
        return f"{self.__class__.__qualname__}.{self._name_}"


class UniqueConstraint(BaseConstraint):
    """
    A uniqueness constraint over either model fields or expressions.

    Exactly one of `fields` / positional `expressions` must be given. The
    constraint may be partial (`condition`), deferrable, carry covering
    `include` columns, or use per-field `opclasses`; __init__ rejects the
    combinations that cannot be expressed.
    """

    def __init__(
        self,
        *expressions,
        fields=(),
        name=None,
        condition=None,
        deferrable=None,
        include=None,
        opclasses=(),
        violation_error_message=None,
    ):
        """
        Validate and store the constraint definition.

        String expressions are wrapped in F(); fields and include are stored
        as tuples.

        Raises:
            ValueError: if the name is missing, neither or both of
                fields/expressions are given, an argument has the wrong type,
                deferrable is combined with condition/include/opclasses/
                expressions, opclasses is combined with expressions, or
                opclasses and fields differ in length.
        """
        if not name:
            raise ValueError("A unique constraint must be named.")
        if not expressions and not fields:
            raise ValueError(
                "At least one field or expression is required to define a "
                "unique constraint."
            )
        if expressions and fields:
            raise ValueError(
                "UniqueConstraint.fields and expressions are mutually exclusive."
            )
        if not isinstance(condition, (type(None), Q)):
            raise ValueError("UniqueConstraint.condition must be a Q instance.")
        if condition and deferrable:
            raise ValueError("UniqueConstraint with conditions cannot be deferred.")
        if include and deferrable:
            raise ValueError("UniqueConstraint with include fields cannot be deferred.")
        if opclasses and deferrable:
            raise ValueError("UniqueConstraint with opclasses cannot be deferred.")
        if expressions and deferrable:
            raise ValueError("UniqueConstraint with expressions cannot be deferred.")
        if expressions and opclasses:
            raise ValueError(
                "UniqueConstraint.opclasses cannot be used with expressions. "
                "Use django.contrib.postgres.indexes.OpClass() instead."
            )
        if not isinstance(deferrable, (type(None), Deferrable)):
            raise ValueError(
                "UniqueConstraint.deferrable must be a Deferrable instance."
            )
        if not isinstance(include, (type(None), list, tuple)):
            raise ValueError("UniqueConstraint.include must be a list or tuple.")
        if not isinstance(opclasses, (list, tuple)):
            raise ValueError("UniqueConstraint.opclasses must be a list or tuple.")
        if opclasses and len(fields) != len(opclasses):
            raise ValueError(
                "UniqueConstraint.fields and UniqueConstraint.opclasses must "
                "have the same number of elements."
            )
        self.fields = tuple(fields)
        self.condition = condition
        self.deferrable = deferrable
        self.include = tuple(include) if include else ()
        self.opclasses = opclasses
        self.expressions = tuple(
            F(expression) if isinstance(expression, str) else expression
            for expression in expressions
        )
        super().__init__(name, violation_error_message=violation_error_message)

    @property
    def contains_expressions(self):
        """Return True when the constraint was defined with expressions."""
        return bool(self.expressions)

    def _get_condition_sql(self, model, schema_editor):
        """Compile self.condition to inlined SQL, or return None if unset."""
        if self.condition is None:
            return None
        query = Query(model=model, alias_cols=False)
        where = query.build_where(self.condition)
        compiler = query.get_compiler(connection=schema_editor.connection)
        sql, params = where.as_sql(compiler, schema_editor.connection)
        return sql % tuple(schema_editor.quote_value(p) for p in params)

    def _get_index_expressions(self, model, schema_editor):
        """
        Wrap each expression in IndexExpression and resolve them against a
        Query for `model`; return None for field-based constraints.
        """
        if not self.expressions:
            return None
        index_expressions = []
        for expression in self.expressions:
            index_expression = IndexExpression(expression)
            index_expression.set_wrapper_classes(schema_editor.connection)
            index_expressions.append(index_expression)
        return ExpressionList(*index_expressions).resolve_expression(
            Query(model, alias_cols=False),
        )

    def constraint_sql(self, model, schema_editor):
        fields = [model._meta.get_field(field_name) for field_name in self.fields]
        include = [
            model._meta.get_field(field_name).column for field_name in self.include
        ]
        condition = self._get_condition_sql(model, schema_editor)
        expressions = self._get_index_expressions(model, schema_editor)
        return schema_editor._unique_sql(
            model,
            fields,
            self.name,
            condition=condition,
            deferrable=self.deferrable,
            include=include,
            opclasses=self.opclasses,
            expressions=expressions,
        )

    def create_sql(self, model, schema_editor):
        fields = [model._meta.get_field(field_name) for field_name in self.fields]
        include = [
            model._meta.get_field(field_name).column for field_name in self.include
        ]
        condition = self._get_condition_sql(model, schema_editor)
        expressions = self._get_index_expressions(model, schema_editor)
        return schema_editor._create_unique_sql(
            model,
            fields,
            self.name,
            condition=condition,
            deferrable=self.deferrable,
            include=include,
            opclasses=self.opclasses,
            expressions=expressions,
        )

    def remove_sql(self, model, schema_editor):
        condition = self._get_condition_sql(model, schema_editor)
        include = [
            model._meta.get_field(field_name).column for field_name in self.include
        ]
        expressions = self._get_index_expressions(model, schema_editor)
        return schema_editor._delete_unique_sql(
            model,
            self.name,
            condition=condition,
            deferrable=self.deferrable,
            include=include,
            opclasses=self.opclasses,
            expressions=expressions,
        )

    def __repr__(self):
        return "<%s:%s%s%s%s%s%s%s>" % (
            self.__class__.__qualname__,
            "" if not self.fields else " fields=%s" % repr(self.fields),
            "" if not self.expressions else " expressions=%s" % repr(self.expressions),
            " name=%s" % repr(self.name),
            "" if self.condition is None else " condition=%s" % self.condition,
            "" if self.deferrable is None else " deferrable=%r" % self.deferrable,
            "" if not self.include else " include=%s" % repr(self.include),
            "" if not self.opclasses else " opclasses=%s" % repr(self.opclasses),
        )

    def __eq__(self, other):
        if isinstance(other, UniqueConstraint):
            return (
                self.name == other.name
                and self.fields == other.fields
                and self.condition == other.condition
                and self.deferrable == other.deferrable
                and self.include == other.include
                and self.opclasses == other.opclasses
                and self.expressions == other.expressions
                and self.violation_error_message == other.violation_error_message
            )
        return super().__eq__(other)

    def deconstruct(self):
        """
        Extend the base deconstruction with the non-empty options; the
        expressions are returned as the positional args.
        """
        path, args, kwargs = super().deconstruct()
        if self.fields:
            kwargs["fields"] = self.fields
        if self.condition:
            kwargs["condition"] = self.condition
        if self.deferrable:
            kwargs["deferrable"] = self.deferrable
        if self.include:
            kwargs["include"] = self.include
        if self.opclasses:
            kwargs["opclasses"] = self.opclasses
        return path, self.expressions, kwargs

    def validate(self, model, instance, exclude=None, using=DEFAULT_DB_ALIAS):
        """
        Raise ValidationError if `instance` would violate this constraint.

        Returns without checking when any constrained field/F() reference is
        in `exclude`, or when a constrained field value is None (or "" on
        backends that treat empty strings as NULL). The instance's own row
        (by pk) is excluded when it is not being added.
        """
        queryset = model._default_manager.using(using)
        if self.fields:
            lookup_kwargs = {}
            for field_name in self.fields:
                if exclude and field_name in exclude:
                    return
                field = model._meta.get_field(field_name)
                lookup_value = getattr(instance, field.attname)
                if lookup_value is None or (
                    lookup_value == ""
                    and connections[using].features.interprets_empty_strings_as_nulls
                ):
                    # A composite constraint containing NULL value cannot cause
                    # a violation since NULL != NULL in SQL.
                    return
                lookup_kwargs[field.name] = lookup_value
            queryset = queryset.filter(**lookup_kwargs)
        else:
            # Ignore constraints with excluded fields.
            if exclude:
                for expression in self.expressions:
                    if hasattr(expression, "flatten"):
                        for expr in expression.flatten():
                            if isinstance(expr, F) and expr.name in exclude:
                                return
                    elif isinstance(expression, F) and expression.name in exclude:
                        return
            replacement_map = instance._get_field_value_map(
                meta=model._meta, exclude=exclude
            )
            expressions = []
            for expr in self.expressions:
                # Ignore ordering.
                if isinstance(expr, OrderBy):
                    expr = expr.expression
                # Compare each expression to itself evaluated on the
                # instance's values (F() references replaced via the map).
                expressions.append(Exact(expr, expr.replace_references(replacement_map)))
            queryset = queryset.filter(*expressions)
        model_class_pk = instance._get_pk_val(model._meta)
        if not instance._state.adding and model_class_pk is not None:
            queryset = queryset.exclude(pk=model_class_pk)
        if not self.condition:
            if queryset.exists():
                if self.expressions:
                    raise ValidationError(self.get_violation_error_message())
                # When fields are defined, use the unique_error_message() for
                # backward compatibility. `model_class` is the model that
                # declares this constraint; named so it does not rebind the
                # `model` parameter.
                for model_class, constraints in instance.get_constraints():
                    for constraint in constraints:
                        if constraint is self:
                            raise ValidationError(
                                instance.unique_error_message(model_class, self.fields)
                            )
        else:
            # Partial constraint: a violation requires the instance to satisfy
            # the condition AND a conflicting row that also satisfies it.
            against = instance._get_field_value_map(meta=model._meta, exclude=exclude)
            try:
                if (self.condition & Exists(queryset.filter(self.condition))).check(
                    against, using=using
                ):
                    raise ValidationError(self.get_violation_error_message())
            except FieldError:
                # Best-effort: skip validation if the condition cannot be
                # evaluated against the (possibly excluded) value map.
                pass