Development of an internal social media platform with personalised dashboards for students
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

array.py 10KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303
  1. import json
  2. from django.contrib.postgres import lookups
  3. from django.contrib.postgres.forms import SimpleArrayField
  4. from django.contrib.postgres.validators import ArrayMaxLengthValidator
  5. from django.core import checks, exceptions
  6. from django.db.models import Field, IntegerField, Transform
  7. from django.db.models.lookups import Exact, In
  8. from django.utils.inspect import func_supports_parameter
  9. from django.utils.translation import gettext_lazy as _
  10. from ..utils import prefix_validation_error
  11. from .mixins import CheckFieldDefaultMixin
  12. from .utils import AttributeSetter
  13. __all__ = ['ArrayField']
  14. class ArrayField(CheckFieldDefaultMixin, Field):
  15. empty_strings_allowed = False
  16. default_error_messages = {
  17. 'item_invalid': _('Item %(nth)s in the array did not validate:'),
  18. 'nested_array_mismatch': _('Nested arrays must have the same length.'),
  19. }
  20. _default_hint = ('list', '[]')
  21. def __init__(self, base_field, size=None, **kwargs):
  22. self.base_field = base_field
  23. self.size = size
  24. if self.size:
  25. self.default_validators = self.default_validators[:]
  26. self.default_validators.append(ArrayMaxLengthValidator(self.size))
  27. # For performance, only add a from_db_value() method if the base field
  28. # implements it.
  29. if hasattr(self.base_field, 'from_db_value'):
  30. self.from_db_value = self._from_db_value
  31. super().__init__(**kwargs)
  32. @property
  33. def model(self):
  34. try:
  35. return self.__dict__['model']
  36. except KeyError:
  37. raise AttributeError("'%s' object has no attribute 'model'" % self.__class__.__name__)
  38. @model.setter
  39. def model(self, model):
  40. self.__dict__['model'] = model
  41. self.base_field.model = model
  42. def check(self, **kwargs):
  43. errors = super().check(**kwargs)
  44. if self.base_field.remote_field:
  45. errors.append(
  46. checks.Error(
  47. 'Base field for array cannot be a related field.',
  48. obj=self,
  49. id='postgres.E002'
  50. )
  51. )
  52. else:
  53. # Remove the field name checks as they are not needed here.
  54. base_errors = self.base_field.check()
  55. if base_errors:
  56. messages = '\n '.join('%s (%s)' % (error.msg, error.id) for error in base_errors)
  57. errors.append(
  58. checks.Error(
  59. 'Base field for array has errors:\n %s' % messages,
  60. obj=self,
  61. id='postgres.E001'
  62. )
  63. )
  64. return errors
  65. def set_attributes_from_name(self, name):
  66. super().set_attributes_from_name(name)
  67. self.base_field.set_attributes_from_name(name)
  68. @property
  69. def description(self):
  70. return 'Array of %s' % self.base_field.description
  71. def db_type(self, connection):
  72. size = self.size or ''
  73. return '%s[%s]' % (self.base_field.db_type(connection), size)
  74. def get_db_prep_value(self, value, connection, prepared=False):
  75. if isinstance(value, (list, tuple)):
  76. return [self.base_field.get_db_prep_value(i, connection, prepared=False) for i in value]
  77. return value
  78. def deconstruct(self):
  79. name, path, args, kwargs = super().deconstruct()
  80. if path == 'django.contrib.postgres.fields.array.ArrayField':
  81. path = 'django.contrib.postgres.fields.ArrayField'
  82. kwargs.update({
  83. 'base_field': self.base_field.clone(),
  84. 'size': self.size,
  85. })
  86. return name, path, args, kwargs
  87. def to_python(self, value):
  88. if isinstance(value, str):
  89. # Assume we're deserializing
  90. vals = json.loads(value)
  91. value = [self.base_field.to_python(val) for val in vals]
  92. return value
  93. def _from_db_value(self, value, expression, connection):
  94. if value is None:
  95. return value
  96. return [
  97. self.base_field.from_db_value(item, expression, connection, {})
  98. if func_supports_parameter(self.base_field.from_db_value, 'context') # RemovedInDjango30Warning
  99. else self.base_field.from_db_value(item, expression, connection)
  100. for item in value
  101. ]
  102. def value_to_string(self, obj):
  103. values = []
  104. vals = self.value_from_object(obj)
  105. base_field = self.base_field
  106. for val in vals:
  107. if val is None:
  108. values.append(None)
  109. else:
  110. obj = AttributeSetter(base_field.attname, val)
  111. values.append(base_field.value_to_string(obj))
  112. return json.dumps(values)
  113. def get_transform(self, name):
  114. transform = super().get_transform(name)
  115. if transform:
  116. return transform
  117. if '_' not in name:
  118. try:
  119. index = int(name)
  120. except ValueError:
  121. pass
  122. else:
  123. index += 1 # postgres uses 1-indexing
  124. return IndexTransformFactory(index, self.base_field)
  125. try:
  126. start, end = name.split('_')
  127. start = int(start) + 1
  128. end = int(end) # don't add one here because postgres slices are weird
  129. except ValueError:
  130. pass
  131. else:
  132. return SliceTransformFactory(start, end)
  133. def validate(self, value, model_instance):
  134. super().validate(value, model_instance)
  135. for index, part in enumerate(value):
  136. try:
  137. self.base_field.validate(part, model_instance)
  138. except exceptions.ValidationError as error:
  139. raise prefix_validation_error(
  140. error,
  141. prefix=self.error_messages['item_invalid'],
  142. code='item_invalid',
  143. params={'nth': index + 1},
  144. )
  145. if isinstance(self.base_field, ArrayField):
  146. if len({len(i) for i in value}) > 1:
  147. raise exceptions.ValidationError(
  148. self.error_messages['nested_array_mismatch'],
  149. code='nested_array_mismatch',
  150. )
  151. def run_validators(self, value):
  152. super().run_validators(value)
  153. for index, part in enumerate(value):
  154. try:
  155. self.base_field.run_validators(part)
  156. except exceptions.ValidationError as error:
  157. raise prefix_validation_error(
  158. error,
  159. prefix=self.error_messages['item_invalid'],
  160. code='item_invalid',
  161. params={'nth': index + 1},
  162. )
  163. def formfield(self, **kwargs):
  164. return super().formfield(**{
  165. 'form_class': SimpleArrayField,
  166. 'base_field': self.base_field.formfield(),
  167. 'max_length': self.size,
  168. **kwargs,
  169. })
  170. @ArrayField.register_lookup
  171. class ArrayContains(lookups.DataContains):
  172. def as_sql(self, qn, connection):
  173. sql, params = super().as_sql(qn, connection)
  174. sql = '%s::%s' % (sql, self.lhs.output_field.db_type(connection))
  175. return sql, params
  176. @ArrayField.register_lookup
  177. class ArrayContainedBy(lookups.ContainedBy):
  178. def as_sql(self, qn, connection):
  179. sql, params = super().as_sql(qn, connection)
  180. sql = '%s::%s' % (sql, self.lhs.output_field.db_type(connection))
  181. return sql, params
  182. @ArrayField.register_lookup
  183. class ArrayExact(Exact):
  184. def as_sql(self, qn, connection):
  185. sql, params = super().as_sql(qn, connection)
  186. sql = '%s::%s' % (sql, self.lhs.output_field.db_type(connection))
  187. return sql, params
  188. @ArrayField.register_lookup
  189. class ArrayOverlap(lookups.Overlap):
  190. def as_sql(self, qn, connection):
  191. sql, params = super().as_sql(qn, connection)
  192. sql = '%s::%s' % (sql, self.lhs.output_field.db_type(connection))
  193. return sql, params
  194. @ArrayField.register_lookup
  195. class ArrayLenTransform(Transform):
  196. lookup_name = 'len'
  197. output_field = IntegerField()
  198. def as_sql(self, compiler, connection):
  199. lhs, params = compiler.compile(self.lhs)
  200. # Distinguish NULL and empty arrays
  201. return (
  202. 'CASE WHEN %(lhs)s IS NULL THEN NULL ELSE '
  203. 'coalesce(array_length(%(lhs)s, 1), 0) END'
  204. ) % {'lhs': lhs}, params
  205. @ArrayField.register_lookup
  206. class ArrayInLookup(In):
  207. def get_prep_lookup(self):
  208. values = super().get_prep_lookup()
  209. if hasattr(self.rhs, '_prepare'):
  210. # Subqueries don't need further preparation.
  211. return values
  212. # In.process_rhs() expects values to be hashable, so convert lists
  213. # to tuples.
  214. prepared_values = []
  215. for value in values:
  216. if hasattr(value, 'resolve_expression'):
  217. prepared_values.append(value)
  218. else:
  219. prepared_values.append(tuple(value))
  220. return prepared_values
  221. class IndexTransform(Transform):
  222. def __init__(self, index, base_field, *args, **kwargs):
  223. super().__init__(*args, **kwargs)
  224. self.index = index
  225. self.base_field = base_field
  226. def as_sql(self, compiler, connection):
  227. lhs, params = compiler.compile(self.lhs)
  228. return '%s[%s]' % (lhs, self.index), params
  229. @property
  230. def output_field(self):
  231. return self.base_field
  232. class IndexTransformFactory:
  233. def __init__(self, index, base_field):
  234. self.index = index
  235. self.base_field = base_field
  236. def __call__(self, *args, **kwargs):
  237. return IndexTransform(self.index, self.base_field, *args, **kwargs)
  238. class SliceTransform(Transform):
  239. def __init__(self, start, end, *args, **kwargs):
  240. super().__init__(*args, **kwargs)
  241. self.start = start
  242. self.end = end
  243. def as_sql(self, compiler, connection):
  244. lhs, params = compiler.compile(self.lhs)
  245. return '%s[%s:%s]' % (lhs, self.start, self.end), params
  246. class SliceTransformFactory:
  247. def __init__(self, start, end):
  248. self.start = start
  249. self.end = end
  250. def __call__(self, *args, **kwargs):
  251. return SliceTransform(self.start, self.end, *args, **kwargs)