Development of an internal social media platform with personalised dashboards for students
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

array.py 10.0KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302
  1. import json
  2. from django.contrib.postgres import lookups
  3. from django.contrib.postgres.forms import SimpleArrayField
  4. from django.contrib.postgres.validators import ArrayMaxLengthValidator
  5. from django.core import checks, exceptions
  6. from django.db.models import Field, IntegerField, Transform
  7. from django.db.models.lookups import Exact, In
  8. from django.utils.inspect import func_supports_parameter
  9. from django.utils.translation import gettext_lazy as _
  10. from ..utils import prefix_validation_error
  11. from .utils import AttributeSetter
  # Only the model field is part of this module's public API.
  __all__ = ['ArrayField']


  class ArrayField(Field):
      """
      Model field holding a PostgreSQL array of values described by ``base_field``.

      ``size``, when truthy, is written into the column type as ``type[size]``
      and checked by an ``ArrayMaxLengthValidator`` during validation.
      """

      empty_strings_allowed = False
      default_error_messages = {
          'item_invalid': _('Item %(nth)s in the array did not validate: '),
          'nested_array_mismatch': _('Nested arrays must have the same length.'),
      }

      def __init__(self, base_field, size=None, **kwargs):
          self.base_field = base_field
          self.size = size
          if self.size:
              # Copy before appending so a list that may be shared at class
              # level is not mutated for other field instances.
              self.default_validators = self.default_validators[:]
              self.default_validators.append(ArrayMaxLengthValidator(self.size))
          # For performance, only add a from_db_value() method if the base field
          # implements it.
          if hasattr(self.base_field, 'from_db_value'):
              self.from_db_value = self._from_db_value
          super().__init__(**kwargs)

      @property
      def model(self):
          """Model stored on this field; raises AttributeError until one is set."""
          try:
              return self.__dict__['model']
          except KeyError:
              # Behave like an ordinary missing attribute (so hasattr() works)
              # without chaining the internal KeyError.
              raise AttributeError(
                  "'%s' object has no attribute 'model'" % self.__class__.__name__
              ) from None

      @model.setter
      def model(self, model):
          self.__dict__['model'] = model
          # Mirror the model onto the base field as well.
          self.base_field.model = model

      def check(self, **kwargs):
          """
          Run the standard field checks plus checks on ``base_field``.

          A related base field is rejected (postgres.E002). Otherwise the base
          field's serious messages are reported together as one error
          (postgres.E001) and its warnings as one warning (postgres.W004), so a
          base-field warning is no longer escalated into an error.
          """
          errors = super().check(**kwargs)
          if self.base_field.remote_field:
              errors.append(
                  checks.Error(
                      'Base field for array cannot be a related field.',
                      obj=self,
                      id='postgres.E002',
                  )
              )
          else:
              # Remove the field name checks as they are not needed here.
              base_checks = self.base_field.check()
              base_errors = [message for message in base_checks if message.is_serious()]
              base_warnings = [
                  message for message in base_checks
                  if not message.is_serious() and message.level >= checks.WARNING
              ]
              if base_errors:
                  errors.append(
                      checks.Error(
                          'Base field for array has errors:\n %s'
                          % self._join_check_messages(base_errors),
                          obj=self,
                          id='postgres.E001',
                      )
                  )
              if base_warnings:
                  errors.append(
                      checks.Warning(
                          'Base field for array has warnings:\n %s'
                          % self._join_check_messages(base_warnings),
                          obj=self,
                          id='postgres.W004',
                      )
                  )
          return errors

      @staticmethod
      def _join_check_messages(messages):
          """Render check messages as ``msg (id)`` entries joined for one report."""
          return '\n '.join('%s (%s)' % (message.msg, message.id) for message in messages)

      def set_attributes_from_name(self, name):
          """Set name-derived attributes on this field and on ``base_field``."""
          super().set_attributes_from_name(name)
          self.base_field.set_attributes_from_name(name)

      @property
      def description(self):
          return 'Array of %s' % self.base_field.description

      def db_type(self, connection):
          """Column type ``<base type>[<size>]``; the size is empty when unset."""
          size = self.size or ''
          return '%s[%s]' % (self.base_field.db_type(connection), size)

      def get_db_prep_value(self, value, connection, prepared=False):
          """Prepare each item of a list/tuple with the base field; pass other values through."""
          if isinstance(value, (list, tuple)):
              return [self.base_field.get_db_prep_value(i, connection, prepared=False) for i in value]
          return value

      def deconstruct(self):
          """
          Return construction arguments for migrations.

          The long module path is shortened to the public
          ``django.contrib.postgres.fields.ArrayField`` path, and ``base_field``
          is passed as a clone.
          """
          name, path, args, kwargs = super().deconstruct()
          if path == 'django.contrib.postgres.fields.array.ArrayField':
              path = 'django.contrib.postgres.fields.ArrayField'
          kwargs.update({
              'base_field': self.base_field.clone(),
              'size': self.size,
          })
          return name, path, args, kwargs

      def to_python(self, value):
          """
          Decode a JSON string (as written by ``value_to_string``) into a list
          of base-field values; any non-string value is returned unchanged.
          """
          if isinstance(value, str):
              # Assume we're deserializing
              vals = json.loads(value)
              value = [self.base_field.to_python(val) for val in vals]
          return value

      def _from_db_value(self, value, expression, connection):
          """
          Convert each database item with ``base_field.from_db_value()``.

          ``__init__`` installs this as ``from_db_value`` only when the base
          field defines that method. None is returned unchanged.
          """
          if value is None:
              return value
          convert = self.base_field.from_db_value
          # Whether the legacy ``context`` argument is accepted does not depend
          # on the item, so decide once rather than once per element.
          if func_supports_parameter(convert, 'context'):  # RemovedInDjango30Warning
              return [convert(item, expression, connection, {}) for item in value]
          return [convert(item, expression, connection) for item in value]

      def value_to_string(self, obj):
          """
          Serialize this field's value on ``obj`` to a JSON list.

          Each non-None item is rendered with ``base_field.value_to_string()``;
          None items are emitted as JSON null.
          """
          base_field = self.base_field
          values = []
          for val in self.value_from_object(obj):
              if val is None:
                  values.append(None)
              else:
                  # value_to_string() takes an object, not a value; AttributeSetter
                  # presumably exposes ``val`` under ``base_field.attname``.
                  item_obj = AttributeSetter(base_field.attname, val)
                  values.append(base_field.value_to_string(item_obj))
          return json.dumps(values)

      def get_transform(self, name):
          """
          Resolve registered transforms first, then numeric index/slice names.

          ``"<n>"`` becomes an element lookup and ``"<start>_<end>"`` a slice.
          Both are given 0-based and converted for PostgreSQL's 1-based arrays.
          Returns None when ``name`` matches nothing.
          """
          transform = super().get_transform(name)
          if transform:
              return transform
          if '_' not in name:
              try:
                  index = int(name)
              except ValueError:
                  pass
              else:
                  index += 1  # postgres uses 1-indexing
                  return IndexTransformFactory(index, self.base_field)
          try:
              start, end = name.split('_')
              start = int(start) + 1
              # No +1 here: PostgreSQL slice bounds are inclusive, so an
              # exclusive 0-based end equals the inclusive 1-based end.
              end = int(end)
          except ValueError:
              pass
          else:
              return SliceTransformFactory(start, end)
          return None

      def validate(self, value, model_instance):
          """
          Validate the array, then every item with ``base_field.validate()``.

          An item failure is re-raised with the ``item_invalid`` prefix, where
          ``nth`` is the item's 0-based index. When the base field is itself an
          ArrayField, all inner arrays must have the same length.
          """
          super().validate(value, model_instance)
          for index, part in enumerate(value):
              try:
                  self.base_field.validate(part, model_instance)
              except exceptions.ValidationError as error:
                  raise prefix_validation_error(
                      error,
                      prefix=self.error_messages['item_invalid'],
                      code='item_invalid',
                      params={'nth': index},
                  )
          if isinstance(self.base_field, ArrayField):
              if len({len(i) for i in value}) > 1:
                  raise exceptions.ValidationError(
                      self.error_messages['nested_array_mismatch'],
                      code='nested_array_mismatch',
                  )

      def run_validators(self, value):
          """Run this field's validators, then the base field's on every item."""
          super().run_validators(value)
          for index, part in enumerate(value):
              try:
                  self.base_field.run_validators(part)
              except exceptions.ValidationError as error:
                  raise prefix_validation_error(
                      error,
                      prefix=self.error_messages['item_invalid'],
                      code='item_invalid',
                      params={'nth': index},
                  )

      def formfield(self, **kwargs):
          """Default to a SimpleArrayField of the base field's form field; kwargs win."""
          defaults = {
              'form_class': SimpleArrayField,
              'base_field': self.base_field.formfield(),
              'max_length': self.size,
          }
          defaults.update(kwargs)
          return super().formfield(**defaults)
  class _ArrayCastRHSMixin:
      """
      Append ``::<array column type>`` to a lookup's compiled SQL.

      Shared by the array lookups below, which previously each carried an
      identical copy of this override. The cast is appended after the
      rendered expression, i.e. after its trailing right-hand operand.
      """

      def as_sql(self, qn, connection):
          sql, params = super().as_sql(qn, connection)
          sql = '%s::%s' % (sql, self.lhs.output_field.db_type(connection))
          return sql, params


  @ArrayField.register_lookup
  class ArrayContains(_ArrayCastRHSMixin, lookups.DataContains):
      """``lookups.DataContains`` for arrays, with the cast appended."""


  @ArrayField.register_lookup
  class ArrayContainedBy(_ArrayCastRHSMixin, lookups.ContainedBy):
      """``lookups.ContainedBy`` for arrays, with the cast appended."""


  @ArrayField.register_lookup
  class ArrayExact(_ArrayCastRHSMixin, Exact):
      """``Exact`` for arrays, with the cast appended."""


  @ArrayField.register_lookup
  class ArrayOverlap(_ArrayCastRHSMixin, lookups.Overlap):
      """``lookups.Overlap`` for arrays, with the cast appended."""
  @ArrayField.register_lookup
  class ArrayLenTransform(Transform):
      """``len`` transform: the length of an array's first dimension as an integer."""

      lookup_name = 'len'
      output_field = IntegerField()

      def as_sql(self, compiler, connection):
          column_sql, column_params = compiler.compile(self.lhs)
          # A NULL column stays NULL; an empty array (for which array_length()
          # yields NULL) is reported as 0 via coalesce().
          template = (
              'CASE WHEN %(lhs)s IS NULL THEN NULL ELSE '
              'coalesce(array_length(%(lhs)s, 1), 0) END'
          )
          sql = template % {'lhs': column_sql}
          return sql, column_params
  @ArrayField.register_lookup
  class ArrayInLookup(In):
      """``In`` lookup for array columns; candidate arrays are turned into tuples."""

      def get_prep_lookup(self):
          prepped = super().get_prep_lookup()
          # Subqueries don't need further preparation.
          if hasattr(self.rhs, '_prepare'):
              return prepped
          # In.process_rhs() expects hashable values, so plain sequences become
          # tuples while expressions pass through untouched.
          return [
              candidate if hasattr(candidate, 'resolve_expression') else tuple(candidate)
              for candidate in prepped
          ]
  class IndexTransform(Transform):
      """Select a single array element, rendered as ``lhs[index]``."""

      def __init__(self, index, base_field, *args, **kwargs):
          super().__init__(*args, **kwargs)
          self.index = index
          self.base_field = base_field

      @property
      def output_field(self):
          # A single element is typed by the array's base field.
          return self.base_field

      def as_sql(self, compiler, connection):
          compiled_lhs, compiled_params = compiler.compile(self.lhs)
          element_sql = '%s[%s]' % (compiled_lhs, self.index)
          return element_sql, compiled_params
  class IndexTransformFactory:
      """Callable that builds IndexTransform objects for a fixed index and base field."""

      def __init__(self, index, base_field):
          self.index, self.base_field = index, base_field

      def __call__(self, *args, **kwargs):
          index, base_field = self.index, self.base_field
          return IndexTransform(index, base_field, *args, **kwargs)
  class SliceTransform(Transform):
      """Select a sub-array, rendered as PostgreSQL slice syntax ``lhs[start:end]``."""

      def __init__(self, start, end, *args, **kwargs):
          super().__init__(*args, **kwargs)
          self.start, self.end = start, end

      def as_sql(self, compiler, connection):
          compiled, compiled_params = compiler.compile(self.lhs)
          sliced = '%s[%s:%s]' % (compiled, self.start, self.end)
          return sliced, compiled_params
  class SliceTransformFactory:
      """Callable that builds SliceTransform objects for fixed slice bounds."""

      def __init__(self, start, end):
          self.start, self.end = start, end

      def __call__(self, *args, **kwargs):
          lower, upper = self.start, self.end
          return SliceTransform(lower, upper, *args, **kwargs)