123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305 |
- import json
-
- from django.contrib.postgres import lookups
- from django.contrib.postgres.forms import SimpleArrayField
- from django.contrib.postgres.validators import ArrayMaxLengthValidator
- from django.core import checks, exceptions
- from django.db.models import Field, IntegerField, Transform
- from django.db.models.lookups import Exact, In
- from django.utils.inspect import func_supports_parameter
- from django.utils.translation import gettext_lazy as _
-
- from ..utils import prefix_validation_error
- from .mixins import CheckFieldDefaultMixin
- from .utils import AttributeSetter
-
- __all__ = ['ArrayField']
-
-
class ArrayField(CheckFieldDefaultMixin, Field):
    """Model field holding a list of values, each handled by ``base_field``.

    The column type is the base field's database type followed by ``[size]``
    (``[]`` when no size is given). Per-item conversion, validation and
    serialization are delegated to ``base_field``.
    """

    empty_strings_allowed = False
    default_error_messages = {
        'item_invalid': _('Item %(nth)s in the array did not validate:'),
        'nested_array_mismatch': _('Nested arrays must have the same length.'),
    }
    # NOTE(review): presumably consumed by CheckFieldDefaultMixin to build its
    # hint about mutable defaults -- confirm against the mixin.
    _default_hint = ('list', '[]')

    def __init__(self, base_field, size=None, **kwargs):
        """Store ``base_field`` and ``size`` and forward ``kwargs`` to ``Field``.

        A truthy ``size`` adds an ``ArrayMaxLengthValidator(size)`` to this
        instance's default validators.
        """
        self.base_field = base_field
        self.size = size
        if self.size:
            self.default_validators = [*self.default_validators, ArrayMaxLengthValidator(self.size)]
        # For performance, only add a from_db_value() method if the base field
        # implements it.
        if hasattr(self.base_field, 'from_db_value'):
            self.from_db_value = self._from_db_value
        super().__init__(**kwargs)

    @property
    def model(self):
        """Return the model stored in the instance ``__dict__``.

        Raises ``AttributeError`` when no model has been assigned yet.
        """
        try:
            return self.__dict__['model']
        except KeyError:
            # The KeyError is an implementation detail; hide it from the
            # traceback so callers only see the AttributeError.
            raise AttributeError(
                "'%s' object has no attribute 'model'" % self.__class__.__name__
            ) from None

    @model.setter
    def model(self, model):
        """Assign ``model`` to this field and to the base field."""
        self.__dict__['model'] = model
        self.base_field.model = model

    def check(self, **kwargs):
        """Run the inherited checks, then check the base field.

        Adds ``postgres.E002`` when the base field has a ``remote_field``;
        otherwise folds any base field errors into a single ``postgres.E001``.
        """
        errors = super().check(**kwargs)
        if self.base_field.remote_field:
            errors.append(
                checks.Error(
                    'Base field for array cannot be a related field.',
                    obj=self,
                    id='postgres.E002'
                )
            )
        else:
            # Remove the field name checks as they are not needed here.
            base_errors = self.base_field.check()
            if base_errors:
                messages = '\n    '.join('%s (%s)' % (error.msg, error.id) for error in base_errors)
                errors.append(
                    checks.Error(
                        'Base field for array has errors:\n    %s' % messages,
                        obj=self,
                        id='postgres.E001'
                    )
                )
        return errors

    def set_attributes_from_name(self, name):
        """Apply ``name`` to this field and to the base field."""
        super().set_attributes_from_name(name)
        self.base_field.set_attributes_from_name(name)

    @property
    def description(self):
        """Return ``'Array of <base field description>'``."""
        return 'Array of %s' % self.base_field.description

    def db_type(self, connection):
        """Return the base field's column type suffixed with ``[size]``."""
        size = self.size or ''
        return '%s[%s]' % (self.base_field.db_type(connection), size)

    def get_placeholder(self, value, compiler, connection):
        """Return a ``%s`` placeholder cast to this field's column type."""
        return '%s::{}'.format(self.db_type(connection))

    def get_db_prep_value(self, value, connection, prepared=False):
        """Prepare each item of a list/tuple through the base field.

        Values that are not a list or tuple are returned unchanged.
        """
        if isinstance(value, (list, tuple)):
            return [self.base_field.get_db_prep_value(i, connection, prepared=False) for i in value]
        return value

    def deconstruct(self):
        """Return the deconstructed field, including ``base_field`` and ``size``.

        The module path is shortened to ``django.contrib.postgres.fields``.
        """
        name, path, args, kwargs = super().deconstruct()
        if path == 'django.contrib.postgres.fields.array.ArrayField':
            path = 'django.contrib.postgres.fields.ArrayField'
        kwargs.update({
            'base_field': self.base_field.clone(),
            'size': self.size,
        })
        return name, path, args, kwargs

    def to_python(self, value):
        """Decode a JSON string into a list converted by the base field.

        Non-string values are returned unchanged.
        """
        if isinstance(value, str):
            # Assume we're deserializing
            vals = json.loads(value)
            value = [self.base_field.to_python(val) for val in vals]
        return value

    def _from_db_value(self, value, expression, connection):
        """Convert each database item with the base field's ``from_db_value``.

        ``None`` is returned unchanged. Only bound as ``from_db_value`` by
        ``__init__`` when the base field defines that method.
        """
        if value is None:
            return value
        convert = self.base_field.from_db_value
        # The signature check does not depend on the item, so run it once per
        # call instead of once per array element.
        if func_supports_parameter(convert, 'context'):  # RemovedInDjango30Warning
            return [convert(item, expression, connection, {}) for item in value]
        return [convert(item, expression, connection) for item in value]

    def value_to_string(self, obj):
        """Serialize the array on ``obj`` to a JSON list of item strings.

        ``None`` items are kept as ``None``; other items are serialized by the
        base field's ``value_to_string``.
        """
        values = []
        vals = self.value_from_object(obj)
        base_field = self.base_field

        for val in vals:
            if val is None:
                values.append(None)
            else:
                # The base field reads its value from an object attribute, so
                # wrap the bare item in a holder exposing that attribute.
                item_holder = AttributeSetter(base_field.attname, val)
                values.append(base_field.value_to_string(item_holder))
        return json.dumps(values)

    def get_transform(self, name):
        """Resolve ``name`` to a transform.

        Registered transforms take precedence. Otherwise ``<int>`` yields an
        index transform factory and ``<int>_<int>`` a slice transform factory.
        Falls through to ``None`` when nothing matches.
        """
        transform = super().get_transform(name)
        if transform:
            return transform
        if '_' not in name:
            try:
                index = int(name)
            except ValueError:
                pass
            else:
                index += 1  # postgres uses 1-indexing
                return IndexTransformFactory(index, self.base_field)
        try:
            start, end = name.split('_')
            start = int(start) + 1
            end = int(end)  # don't add one here because postgres slices are weird
        except ValueError:
            pass
        else:
            return SliceTransformFactory(start, end)

    def validate(self, value, model_instance):
        """Validate the array, then each item with the base field.

        Raises ``ValidationError`` with code ``item_invalid`` for the first
        failing item (1-based position in ``nth``), or with code
        ``nested_array_mismatch`` when the base field is an ``ArrayField``
        and the nested arrays differ in length.
        """
        super().validate(value, model_instance)
        for index, part in enumerate(value):
            try:
                self.base_field.validate(part, model_instance)
            except exceptions.ValidationError as error:
                raise prefix_validation_error(
                    error,
                    prefix=self.error_messages['item_invalid'],
                    code='item_invalid',
                    params={'nth': index + 1},
                )
        if isinstance(self.base_field, ArrayField) and len({len(i) for i in value}) > 1:
            raise exceptions.ValidationError(
                self.error_messages['nested_array_mismatch'],
                code='nested_array_mismatch',
            )

    def run_validators(self, value):
        """Run this field's validators, then the base field's on each item.

        Raises ``ValidationError`` with code ``item_invalid`` for the first
        failing item (1-based position in ``nth``).
        """
        super().run_validators(value)
        for index, part in enumerate(value):
            try:
                self.base_field.run_validators(part)
            except exceptions.ValidationError as error:
                raise prefix_validation_error(
                    error,
                    prefix=self.error_messages['item_invalid'],
                    code='item_invalid',
                    params={'nth': index + 1},
                )

    def formfield(self, **kwargs):
        """Return a ``SimpleArrayField`` form field built from the base field.

        Caller-supplied ``kwargs`` override the defaults set here.
        """
        return super().formfield(**{
            'form_class': SimpleArrayField,
            'base_field': self.base_field.formfield(),
            'max_length': self.size,
            **kwargs,
        })
-
-
@ArrayField.register_lookup
class ArrayContains(lookups.DataContains):
    """``DataContains`` lookup whose SQL is cast to the array column type."""

    def as_sql(self, qn, connection):
        """Return the parent's SQL with a ``::<column type>`` cast appended."""
        base_sql, base_params = super().as_sql(qn, connection)
        column_type = self.lhs.output_field.db_type(connection)
        return '{}::{}'.format(base_sql, column_type), base_params
-
-
@ArrayField.register_lookup
class ArrayContainedBy(lookups.ContainedBy):
    """``ContainedBy`` lookup whose SQL is cast to the array column type."""

    def as_sql(self, qn, connection):
        """Return the parent's SQL with a ``::<column type>`` cast appended."""
        base_sql, base_params = super().as_sql(qn, connection)
        column_type = self.lhs.output_field.db_type(connection)
        return '{}::{}'.format(base_sql, column_type), base_params
-
-
@ArrayField.register_lookup
class ArrayExact(Exact):
    """``Exact`` lookup whose SQL is cast to the array column type."""

    def as_sql(self, qn, connection):
        """Return the parent's SQL with a ``::<column type>`` cast appended."""
        base_sql, base_params = super().as_sql(qn, connection)
        column_type = self.lhs.output_field.db_type(connection)
        return '{}::{}'.format(base_sql, column_type), base_params
-
-
@ArrayField.register_lookup
class ArrayOverlap(lookups.Overlap):
    """``Overlap`` lookup whose SQL is cast to the array column type."""

    def as_sql(self, qn, connection):
        """Return the parent's SQL with a ``::<column type>`` cast appended."""
        base_sql, base_params = super().as_sql(qn, connection)
        column_type = self.lhs.output_field.db_type(connection)
        return '{}::{}'.format(base_sql, column_type), base_params
-
-
@ArrayField.register_lookup
class ArrayLenTransform(Transform):
    """``len`` transform returning the first-dimension length of an array.

    Produces NULL for a NULL array and 0 for an empty array.
    """

    lookup_name = 'len'
    output_field = IntegerField()

    def as_sql(self, compiler, connection):
        """Return the CASE expression and its parameters.

        The compiled left-hand side appears twice in the SQL, so its
        parameters are repeated to keep placeholders and parameters aligned.
        """
        lhs, params = compiler.compile(self.lhs)
        # Distinguish NULL and empty arrays
        return (
            'CASE WHEN %(lhs)s IS NULL THEN NULL ELSE '
            'coalesce(array_length(%(lhs)s, 1), 0) END'
        ) % {'lhs': lhs}, params * 2
-
-
@ArrayField.register_lookup
class ArrayInLookup(In):
    """``In`` lookup for array fields that makes list values hashable."""

    def get_prep_lookup(self):
        """Return the prepared right-hand side with plain values as tuples.

        Values exposing ``resolve_expression`` are passed through untouched.
        """
        prepared = super().get_prep_lookup()
        if hasattr(self.rhs, '_prepare'):
            # Subqueries don't need further preparation.
            return prepared
        # In.process_rhs() expects values to be hashable, so convert lists
        # to tuples.
        return [
            item if hasattr(item, 'resolve_expression') else tuple(item)
            for item in prepared
        ]
-
-
class IndexTransform(Transform):
    """Transform selecting one array element with ``lhs[index]``."""

    def __init__(self, index, base_field, *args, **kwargs):
        """Keep ``index`` and ``base_field``; pass the rest to ``Transform``."""
        super().__init__(*args, **kwargs)
        self.index = index
        self.base_field = base_field

    def as_sql(self, compiler, connection):
        """Return the subscripted left-hand SQL and its parameters."""
        inner_sql, inner_params = compiler.compile(self.lhs)
        # The index is written straight into the SQL text, not parameterized.
        subscripted = '{}[{}]'.format(inner_sql, self.index)
        return subscripted, inner_params

    @property
    def output_field(self):
        """The selected element has the type of the array's base field."""
        return self.base_field
-
-
class IndexTransformFactory:
    """Callable that builds ``IndexTransform`` objects for a fixed index."""

    def __init__(self, index, base_field):
        """Remember the index and base field for later transforms."""
        self.index, self.base_field = index, base_field

    def __call__(self, *args, **kwargs):
        """Build an ``IndexTransform``, forwarding any extra arguments."""
        bound_args = (self.index, self.base_field) + args
        return IndexTransform(*bound_args, **kwargs)
-
-
class SliceTransform(Transform):
    """Transform selecting an array slice with ``lhs[start:end]``."""

    def __init__(self, start, end, *args, **kwargs):
        """Keep the slice bounds; pass the rest to ``Transform``."""
        super().__init__(*args, **kwargs)
        self.start = start
        self.end = end

    def as_sql(self, compiler, connection):
        """Return the sliced left-hand SQL and its parameters."""
        inner_sql, inner_params = compiler.compile(self.lhs)
        # The bounds are written straight into the SQL text, not parameterized.
        sliced = '{}[{}:{}]'.format(inner_sql, self.start, self.end)
        return sliced, inner_params
-
-
class SliceTransformFactory:
    """Callable that builds ``SliceTransform`` objects for fixed bounds."""

    def __init__(self, start, end):
        """Remember the slice bounds for later transforms."""
        self.start, self.end = start, end

    def __call__(self, *args, **kwargs):
        """Build a ``SliceTransform``, forwarding any extra arguments."""
        bound_args = (self.start, self.end) + args
        return SliceTransform(*bound_args, **kwargs)
|