You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305
  1. import json
  2. from django.contrib.postgres import lookups
  3. from django.contrib.postgres.forms import SimpleArrayField
  4. from django.contrib.postgres.validators import ArrayMaxLengthValidator
  5. from django.core import checks, exceptions
  6. from django.db.models import Field, IntegerField, Transform
  7. from django.db.models.lookups import Exact, In
  8. from django.utils.inspect import func_supports_parameter
  9. from django.utils.translation import gettext_lazy as _
  10. from ..utils import prefix_validation_error
  11. from .mixins import CheckFieldDefaultMixin
  12. from .utils import AttributeSetter
  13. __all__ = ['ArrayField']
  14. class ArrayField(CheckFieldDefaultMixin, Field):
  15. empty_strings_allowed = False
  16. default_error_messages = {
  17. 'item_invalid': _('Item %(nth)s in the array did not validate:'),
  18. 'nested_array_mismatch': _('Nested arrays must have the same length.'),
  19. }
  20. _default_hint = ('list', '[]')
  21. def __init__(self, base_field, size=None, **kwargs):
  22. self.base_field = base_field
  23. self.size = size
  24. if self.size:
  25. self.default_validators = [*self.default_validators, ArrayMaxLengthValidator(self.size)]
  26. # For performance, only add a from_db_value() method if the base field
  27. # implements it.
  28. if hasattr(self.base_field, 'from_db_value'):
  29. self.from_db_value = self._from_db_value
  30. super().__init__(**kwargs)
  31. @property
  32. def model(self):
  33. try:
  34. return self.__dict__['model']
  35. except KeyError:
  36. raise AttributeError("'%s' object has no attribute 'model'" % self.__class__.__name__)
  37. @model.setter
  38. def model(self, model):
  39. self.__dict__['model'] = model
  40. self.base_field.model = model
  41. def check(self, **kwargs):
  42. errors = super().check(**kwargs)
  43. if self.base_field.remote_field:
  44. errors.append(
  45. checks.Error(
  46. 'Base field for array cannot be a related field.',
  47. obj=self,
  48. id='postgres.E002'
  49. )
  50. )
  51. else:
  52. # Remove the field name checks as they are not needed here.
  53. base_errors = self.base_field.check()
  54. if base_errors:
  55. messages = '\n '.join('%s (%s)' % (error.msg, error.id) for error in base_errors)
  56. errors.append(
  57. checks.Error(
  58. 'Base field for array has errors:\n %s' % messages,
  59. obj=self,
  60. id='postgres.E001'
  61. )
  62. )
  63. return errors
  64. def set_attributes_from_name(self, name):
  65. super().set_attributes_from_name(name)
  66. self.base_field.set_attributes_from_name(name)
  67. @property
  68. def description(self):
  69. return 'Array of %s' % self.base_field.description
  70. def db_type(self, connection):
  71. size = self.size or ''
  72. return '%s[%s]' % (self.base_field.db_type(connection), size)
  73. def get_placeholder(self, value, compiler, connection):
  74. return '%s::{}'.format(self.db_type(connection))
  75. def get_db_prep_value(self, value, connection, prepared=False):
  76. if isinstance(value, (list, tuple)):
  77. return [self.base_field.get_db_prep_value(i, connection, prepared=False) for i in value]
  78. return value
  79. def deconstruct(self):
  80. name, path, args, kwargs = super().deconstruct()
  81. if path == 'django.contrib.postgres.fields.array.ArrayField':
  82. path = 'django.contrib.postgres.fields.ArrayField'
  83. kwargs.update({
  84. 'base_field': self.base_field.clone(),
  85. 'size': self.size,
  86. })
  87. return name, path, args, kwargs
  88. def to_python(self, value):
  89. if isinstance(value, str):
  90. # Assume we're deserializing
  91. vals = json.loads(value)
  92. value = [self.base_field.to_python(val) for val in vals]
  93. return value
  94. def _from_db_value(self, value, expression, connection):
  95. if value is None:
  96. return value
  97. return [
  98. self.base_field.from_db_value(item, expression, connection, {})
  99. if func_supports_parameter(self.base_field.from_db_value, 'context') # RemovedInDjango30Warning
  100. else self.base_field.from_db_value(item, expression, connection)
  101. for item in value
  102. ]
  103. def value_to_string(self, obj):
  104. values = []
  105. vals = self.value_from_object(obj)
  106. base_field = self.base_field
  107. for val in vals:
  108. if val is None:
  109. values.append(None)
  110. else:
  111. obj = AttributeSetter(base_field.attname, val)
  112. values.append(base_field.value_to_string(obj))
  113. return json.dumps(values)
  114. def get_transform(self, name):
  115. transform = super().get_transform(name)
  116. if transform:
  117. return transform
  118. if '_' not in name:
  119. try:
  120. index = int(name)
  121. except ValueError:
  122. pass
  123. else:
  124. index += 1 # postgres uses 1-indexing
  125. return IndexTransformFactory(index, self.base_field)
  126. try:
  127. start, end = name.split('_')
  128. start = int(start) + 1
  129. end = int(end) # don't add one here because postgres slices are weird
  130. except ValueError:
  131. pass
  132. else:
  133. return SliceTransformFactory(start, end)
  134. def validate(self, value, model_instance):
  135. super().validate(value, model_instance)
  136. for index, part in enumerate(value):
  137. try:
  138. self.base_field.validate(part, model_instance)
  139. except exceptions.ValidationError as error:
  140. raise prefix_validation_error(
  141. error,
  142. prefix=self.error_messages['item_invalid'],
  143. code='item_invalid',
  144. params={'nth': index + 1},
  145. )
  146. if isinstance(self.base_field, ArrayField):
  147. if len({len(i) for i in value}) > 1:
  148. raise exceptions.ValidationError(
  149. self.error_messages['nested_array_mismatch'],
  150. code='nested_array_mismatch',
  151. )
  152. def run_validators(self, value):
  153. super().run_validators(value)
  154. for index, part in enumerate(value):
  155. try:
  156. self.base_field.run_validators(part)
  157. except exceptions.ValidationError as error:
  158. raise prefix_validation_error(
  159. error,
  160. prefix=self.error_messages['item_invalid'],
  161. code='item_invalid',
  162. params={'nth': index + 1},
  163. )
  164. def formfield(self, **kwargs):
  165. return super().formfield(**{
  166. 'form_class': SimpleArrayField,
  167. 'base_field': self.base_field.formfield(),
  168. 'max_length': self.size,
  169. **kwargs,
  170. })
  171. @ArrayField.register_lookup
  172. class ArrayContains(lookups.DataContains):
  173. def as_sql(self, qn, connection):
  174. sql, params = super().as_sql(qn, connection)
  175. sql = '%s::%s' % (sql, self.lhs.output_field.db_type(connection))
  176. return sql, params
  177. @ArrayField.register_lookup
  178. class ArrayContainedBy(lookups.ContainedBy):
  179. def as_sql(self, qn, connection):
  180. sql, params = super().as_sql(qn, connection)
  181. sql = '%s::%s' % (sql, self.lhs.output_field.db_type(connection))
  182. return sql, params
  183. @ArrayField.register_lookup
  184. class ArrayExact(Exact):
  185. def as_sql(self, qn, connection):
  186. sql, params = super().as_sql(qn, connection)
  187. sql = '%s::%s' % (sql, self.lhs.output_field.db_type(connection))
  188. return sql, params
  189. @ArrayField.register_lookup
  190. class ArrayOverlap(lookups.Overlap):
  191. def as_sql(self, qn, connection):
  192. sql, params = super().as_sql(qn, connection)
  193. sql = '%s::%s' % (sql, self.lhs.output_field.db_type(connection))
  194. return sql, params
  195. @ArrayField.register_lookup
  196. class ArrayLenTransform(Transform):
  197. lookup_name = 'len'
  198. output_field = IntegerField()
  199. def as_sql(self, compiler, connection):
  200. lhs, params = compiler.compile(self.lhs)
  201. # Distinguish NULL and empty arrays
  202. return (
  203. 'CASE WHEN %(lhs)s IS NULL THEN NULL ELSE '
  204. 'coalesce(array_length(%(lhs)s, 1), 0) END'
  205. ) % {'lhs': lhs}, params
  206. @ArrayField.register_lookup
  207. class ArrayInLookup(In):
  208. def get_prep_lookup(self):
  209. values = super().get_prep_lookup()
  210. if hasattr(self.rhs, '_prepare'):
  211. # Subqueries don't need further preparation.
  212. return values
  213. # In.process_rhs() expects values to be hashable, so convert lists
  214. # to tuples.
  215. prepared_values = []
  216. for value in values:
  217. if hasattr(value, 'resolve_expression'):
  218. prepared_values.append(value)
  219. else:
  220. prepared_values.append(tuple(value))
  221. return prepared_values
  222. class IndexTransform(Transform):
  223. def __init__(self, index, base_field, *args, **kwargs):
  224. super().__init__(*args, **kwargs)
  225. self.index = index
  226. self.base_field = base_field
  227. def as_sql(self, compiler, connection):
  228. lhs, params = compiler.compile(self.lhs)
  229. return '%s[%s]' % (lhs, self.index), params
  230. @property
  231. def output_field(self):
  232. return self.base_field
  233. class IndexTransformFactory:
  234. def __init__(self, index, base_field):
  235. self.index = index
  236. self.base_field = base_field
  237. def __call__(self, *args, **kwargs):
  238. return IndexTransform(self.index, self.base_field, *args, **kwargs)
  239. class SliceTransform(Transform):
  240. def __init__(self, start, end, *args, **kwargs):
  241. super().__init__(*args, **kwargs)
  242. self.start = start
  243. self.end = end
  244. def as_sql(self, compiler, connection):
  245. lhs, params = compiler.compile(self.lhs)
  246. return '%s[%s:%s]' % (lhs, self.start, self.end), params
  247. class SliceTransformFactory:
  248. def __init__(self, start, end):
  249. self.start = start
  250. self.end = end
  251. def __call__(self, *args, **kwargs):
  252. return SliceTransform(self.start, self.end, *args, **kwargs)