You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

constraints.py 4.6KB

5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121
  1. from django.db.models.query_utils import Q
  2. from django.db.models.sql.query import Query
  # Public names exported by ``from <this module> import *``.
  __all__ = [
      'CheckConstraint',
      'UniqueConstraint',
  ]
  class BaseConstraint:
      """Common base for model constraints.

      Stores the constraint name and provides ``deconstruct``/``clone``
      helpers; the three SQL hooks must be overridden by subclasses.
      """

      def __init__(self, name):
          self.name = name

      def constraint_sql(self, model, schema_editor):
          """Return this constraint's SQL fragment; subclasses must override."""
          raise NotImplementedError('This method must be implemented by a subclass.')

      def create_sql(self, model, schema_editor):
          """Return SQL that adds this constraint; subclasses must override."""
          raise NotImplementedError('This method must be implemented by a subclass.')

      def remove_sql(self, model, schema_editor):
          """Return SQL that drops this constraint; subclasses must override."""
          raise NotImplementedError('This method must be implemented by a subclass.')

      def deconstruct(self):
          """Return a ``(path, args, kwargs)`` triple that can rebuild this object.

          Classes living in ``django.db.models.constraints`` are reported under
          the shorter ``django.db.models`` path.
          """
          cls = type(self)
          dotted = '{}.{}'.format(cls.__module__, cls.__name__)
          dotted = dotted.replace('django.db.models.constraints', 'django.db.models')
          return dotted, (), {'name': self.name}

      def clone(self):
          """Return a fresh instance rebuilt from ``deconstruct()``'s args/kwargs."""
          _unused_path, positional, keyword = self.deconstruct()
          return type(self)(*positional, **keyword)
  class CheckConstraint(BaseConstraint):
      """A CHECK constraint whose condition is compiled from ``check``.

      ``check`` is passed to ``Query.build_where``; presumably a ``Q`` object or
      similar filter expression (it is not validated here — confirm at callers).
      """

      def __init__(self, *, check, name):
          self.check = check
          super().__init__(name)

      def _get_check_sql(self, model, schema_editor):
          """Render ``self.check`` as a SQL string for ``model``.

          Builds a WHERE node from the check on a fresh ``Query(model=model)``,
          compiles it for ``schema_editor.connection``, then interpolates the
          parameters (quoted with ``schema_editor.quote_value``) directly into
          the SQL text instead of returning them separately.
          """
          query = Query(model=model)
          where = query.build_where(self.check)
          compiler = query.get_compiler(connection=schema_editor.connection)
          sql, params = where.as_sql(compiler, schema_editor.connection)
          return sql % tuple(schema_editor.quote_value(p) for p in params)

      def constraint_sql(self, model, schema_editor):
          """Return this constraint's SQL via ``schema_editor._check_sql``."""
          check = self._get_check_sql(model, schema_editor)
          return schema_editor._check_sql(self.name, check)

      def create_sql(self, model, schema_editor):
          """Return SQL adding this constraint via ``schema_editor._create_check_sql``."""
          check = self._get_check_sql(model, schema_editor)
          return schema_editor._create_check_sql(model, self.name, check)

      def remove_sql(self, model, schema_editor):
          """Return SQL dropping this constraint via ``schema_editor._delete_check_sql``."""
          return schema_editor._delete_check_sql(model, self.name)

      def __repr__(self):
          return "<%s: check='%s' name=%r>" % (self.__class__.__name__, self.check, self.name)

      def __eq__(self, other):
          """Equal when ``other`` is a CheckConstraint with the same name and check.

          For any other operand, defer to ``super().__eq__`` (``object.__eq__``
          with the base class as written), which yields ``NotImplemented`` so
          Python can try the reflected comparison (e.g. ``mock.ANY``) instead
          of being handed a hard ``False``.
          """
          # NOTE: defining __eq__ without __hash__ leaves instances unhashable,
          # exactly as before this change.
          if isinstance(other, CheckConstraint):
              return self.name == other.name and self.check == other.check
          return super().__eq__(other)

      def deconstruct(self):
          """Extend the base deconstruction with the ``check`` keyword."""
          path, args, kwargs = super().deconstruct()
          kwargs['check'] = self.check
          return path, args, kwargs
  class UniqueConstraint(BaseConstraint):
      """A uniqueness constraint over one or more model fields.

      An optional ``condition`` (a ``Q`` instance) is compiled to SQL and
      handed to the schema editor together with the column list.
      """

      def __init__(self, *, fields, name, condition=None):
          """Store the constrained field names and the optional condition.

          :param fields: iterable of model field names; must be non-empty.
          :param name: name of the constraint.
          :param condition: ``None`` or a ``Q`` instance.
          :raises ValueError: if ``fields`` is empty, or ``condition`` is neither
              ``None`` nor a ``Q`` instance.
          """
          if not fields:
              raise ValueError('At least one field is required to define a unique constraint.')
          if not isinstance(condition, (type(None), Q)):
              raise ValueError('UniqueConstraint.condition must be a Q instance.')
          self.fields = tuple(fields)
          self.condition = condition
          super().__init__(name)

      def _get_condition_sql(self, model, schema_editor):
          """Return ``self.condition`` rendered as SQL, or ``None`` if it is unset.

          Parameters are quoted with ``schema_editor.quote_value`` and
          interpolated directly into the returned SQL string.
          """
          if self.condition is None:
              return None
          query = Query(model=model)
          where = query.build_where(self.condition)
          compiler = query.get_compiler(connection=schema_editor.connection)
          sql, params = where.as_sql(compiler, schema_editor.connection)
          return sql % tuple(schema_editor.quote_value(p) for p in params)

      def constraint_sql(self, model, schema_editor):
          """Return this constraint's SQL via ``schema_editor._unique_sql``."""
          fields = [model._meta.get_field(field_name).column for field_name in self.fields]
          condition = self._get_condition_sql(model, schema_editor)
          return schema_editor._unique_sql(model, fields, self.name, condition=condition)

      def create_sql(self, model, schema_editor):
          """Return SQL adding this constraint via ``schema_editor._create_unique_sql``."""
          fields = [model._meta.get_field(field_name).column for field_name in self.fields]
          condition = self._get_condition_sql(model, schema_editor)
          return schema_editor._create_unique_sql(model, fields, self.name, condition=condition)

      def remove_sql(self, model, schema_editor):
          """Return SQL dropping this constraint via ``schema_editor._delete_unique_sql``."""
          condition = self._get_condition_sql(model, schema_editor)
          return schema_editor._delete_unique_sql(model, self.name, condition=condition)

      def __repr__(self):
          return '<%s: fields=%r name=%r%s>' % (
              self.__class__.__name__, self.fields, self.name,
              '' if self.condition is None else ' condition=%s' % self.condition,
          )

      def __eq__(self, other):
          """Equal when ``other`` is a UniqueConstraint with the same name, fields and condition.

          For any other operand, defer to ``super().__eq__`` (``object.__eq__``
          with the base class as written), which yields ``NotImplemented`` so
          Python can try the reflected comparison instead of a hard ``False``.
          """
          # NOTE: defining __eq__ without __hash__ leaves instances unhashable,
          # exactly as before this change.
          if isinstance(other, UniqueConstraint):
              return (
                  self.name == other.name and
                  self.fields == other.fields and
                  self.condition == other.condition
              )
          return super().__eq__(other)

      def deconstruct(self):
          """Extend the base deconstruction with ``fields`` and, if set, ``condition``."""
          path, args, kwargs = super().deconstruct()
          kwargs['fields'] = self.fields
          # Compare with None rather than relying on truthiness: a Q with no
          # children is falsy, so it would be silently dropped and a clone
          # would no longer compare equal to the original. This also matches
          # the `is None` checks used by __repr__ and _get_condition_sql.
          if self.condition is not None:
              kwargs['condition'] = self.condition
          return path, args, kwargs