123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274 |
- import datetime
- import json
-
- from psycopg2.extras import DateRange, DateTimeTZRange, NumericRange, Range
-
- from django.contrib.postgres import forms, lookups
- from django.db import models
-
- from .utils import AttributeSetter
-
# Field classes that make up this module's public API.
__all__ = [
    'RangeField',
    'IntegerRangeField',
    'BigIntegerRangeField',
    'DecimalRangeField',
    'DateTimeRangeField',
    'DateRangeField',
    'FloatRangeField',
]
-
-
class RangeField(models.Field):
    """
    Common base class for the range model fields defined in this module.

    Concrete subclasses provide three class attributes:

    * ``base_field``: the model field class used for each bound,
    * ``range_type``: the psycopg2 range class values are converted to,
    * ``form_field``: the default form field class.
    """
    empty_strings_allowed = False

    def __init__(self, *args, **kwargs):
        # Replace the class-level base_field (a field class) with a fresh
        # instance on this object; initializing it here ensures that its
        # model matches the model for self.
        try:
            base_field_class = self.base_field
        except AttributeError:
            # No base_field declared (e.g. a bare RangeField); nothing to do.
            pass
        else:
            self.base_field = base_field_class()
        super().__init__(*args, **kwargs)

    @property
    def model(self):
        """Return the model stored on this instance, if one has been set."""
        instance_attrs = self.__dict__
        try:
            return instance_attrs['model']
        except KeyError:
            raise AttributeError("'%s' object has no attribute 'model'" % self.__class__.__name__)

    @model.setter
    def model(self, model):
        # Store the model on this field first, then mirror it on base_field.
        vars(self)['model'] = model
        self.base_field.model = model

    def get_prep_value(self, value):
        """
        Convert a list/tuple of two bounds to ``range_type``; return ``None``,
        ``Range`` instances and any other value unchanged.
        """
        if value is None or isinstance(value, Range):
            return value
        if isinstance(value, (list, tuple)):
            return self.range_type(value[0], value[1])
        return value

    def to_python(self, value):
        """
        Build a ``range_type`` instance from a JSON string or from a
        list/tuple of two bounds; return any other value unchanged.
        """
        if isinstance(value, str):
            # Assume we're deserializing: the string is a JSON object whose
            # keys become keyword arguments of range_type.
            range_kwargs = json.loads(value)
            for bound in ('lower', 'upper'):
                if bound in range_kwargs:
                    range_kwargs[bound] = self.base_field.to_python(range_kwargs[bound])
            return self.range_type(**range_kwargs)
        if isinstance(value, (list, tuple)):
            return self.range_type(value[0], value[1])
        return value

    def set_attributes_from_name(self, name):
        # Apply the name to this field, then to base_field as well.
        super().set_attributes_from_name(name)
        self.base_field.set_attributes_from_name(name)

    def value_to_string(self, obj):
        """
        Serialize the range stored on ``obj`` to a JSON string.

        Return ``None`` for a ``None`` value and ``{"empty": true}`` for an
        empty range; otherwise emit ``bounds``, ``lower`` and ``upper``, with
        each non-null bound serialized by ``base_field.value_to_string()``.
        """
        range_value = self.value_from_object(obj)
        if range_value is None:
            return None
        if range_value.isempty:
            return json.dumps({"empty": True})
        bound_field = self.base_field
        serialized = {"bounds": range_value._bounds}
        for bound in ('lower', 'upper'):
            bound_value = getattr(range_value, bound)
            if bound_value is None:
                serialized[bound] = None
                continue
            # base_field.value_to_string() expects an object carrying the
            # value under the field's attname, so wrap the bound in one.
            holder = AttributeSetter(bound_field.attname, bound_value)
            serialized[bound] = bound_field.value_to_string(holder)
        return json.dumps(serialized)

    def formfield(self, **kwargs):
        """Return a form field, using ``form_field`` unless ``form_class`` is passed."""
        return super().formfield(**{'form_class': self.form_field, **kwargs})
-
-
class IntegerRangeField(RangeField):
    """Range field with ``models.IntegerField`` bounds, stored as ``int4range``."""
    form_field = forms.IntegerRangeField
    range_type = NumericRange
    base_field = models.IntegerField

    def db_type(self, connection):
        # The column type is fixed; the connection is not consulted.
        return 'int4range'
-
-
class BigIntegerRangeField(RangeField):
    """Range field with ``models.BigIntegerField`` bounds, stored as ``int8range``."""
    form_field = forms.IntegerRangeField
    range_type = NumericRange
    base_field = models.BigIntegerField

    def db_type(self, connection):
        # The column type is fixed; the connection is not consulted.
        return 'int8range'
-
-
class DecimalRangeField(RangeField):
    """Range field with ``models.DecimalField`` bounds, stored as ``numrange``."""
    form_field = forms.DecimalRangeField
    range_type = NumericRange
    base_field = models.DecimalField

    def db_type(self, connection):
        # The column type is fixed; the connection is not consulted.
        return 'numrange'
-
-
class FloatRangeField(RangeField):
    """
    Deprecated range field with ``models.FloatField`` bounds, stored as
    ``numrange``. The deprecation details below hint at DecimalRangeField
    as the replacement.
    """
    base_field = models.FloatField
    range_type = NumericRange
    form_field = forms.FloatRangeField
    # Message, hint and check id reported for the deprecation of this field.
    system_check_deprecated_details = {
        'msg': 'FloatRangeField is deprecated and will be removed in Django 3.1.',
        'hint': 'Use DecimalRangeField instead.',
        'id': 'fields.W902',
    }

    def db_type(self, connection):
        # The column type is fixed; the connection is not consulted.
        return 'numrange'
-
-
class DateTimeRangeField(RangeField):
    """Range field with ``models.DateTimeField`` bounds, stored as ``tstzrange``."""
    form_field = forms.DateTimeRangeField
    range_type = DateTimeTZRange
    base_field = models.DateTimeField

    def db_type(self, connection):
        # The column type is fixed; the connection is not consulted.
        return 'tstzrange'
-
-
class DateRangeField(RangeField):
    """Range field with ``models.DateField`` bounds, stored as ``daterange``."""
    form_field = forms.DateRangeField
    range_type = DateRange
    base_field = models.DateField

    def db_type(self, connection):
        # The column type is fixed; the connection is not consulted.
        return 'daterange'
-
-
# Register the shared DataContains, ContainedBy and Overlap lookups on
# RangeField, in that order.
for _shared_lookup in (lookups.DataContains, lookups.ContainedBy, lookups.Overlap):
    RangeField.register_lookup(_shared_lookup)
# Do not leave the loop variable behind in the module namespace.
del _shared_lookup
-
-
class DateTimeRangeContains(models.Lookup):
    """
    ``contains`` lookup for date and datetime range fields.

    A plain ``date``/``datetime`` right-hand side is wrapped in a typed
    ``Value`` expression, and the rendered rhs is cast to the database type of
    the range's base field unless it already has a matching range type.
    """
    lookup_name = 'contains'

    def process_rhs(self, compiler, connection):
        """Wrap a date/datetime rhs in a resolved ``Value`` before rendering it."""
        raw_rhs = self.rhs
        if isinstance(raw_rhs, datetime.date):
            # datetime.datetime subclasses datetime.date, so test for the
            # more specific type to pick the output field.
            if isinstance(raw_rhs, datetime.datetime):
                output_field = models.DateTimeField()
            else:
                output_field = models.DateField()
            typed_value = models.Value(raw_rhs, output_field=output_field)
            self.rhs = typed_value.resolve_expression(compiler.query)
        return super().process_rhs(compiler, connection)

    def as_sql(self, compiler, connection):
        """Return ``(sql, params)`` for ``lhs @> rhs`` with an optional rhs cast."""
        lhs_sql, lhs_params = self.process_lhs(compiler, connection)
        rhs_sql, rhs_params = self.process_rhs(compiler, connection)
        params = lhs_params + rhs_params
        # Read self.rhs only now: process_rhs() may have replaced it.
        rhs_expression = self.rhs
        cast_sql = ''
        if isinstance(rhs_expression, models.Expression):
            rhs_field = rhs_expression._output_field_or_none
            # Cast the rhs if needed; skip the cast if rhs has a matching
            # range type.
            if rhs_field and not isinstance(rhs_field, self.lhs.output_field.__class__):
                cast_internal_type = self.lhs.output_field.base_field.get_internal_type()
                cast_sql = '::{}'.format(connection.data_types.get(cast_internal_type))
        return '%s @> %s%s' % (lhs_sql, rhs_sql, cast_sql), params
-
-
# Register DateTimeRangeContains on the date and datetime range fields.
for _temporal_range_field in (DateRangeField, DateTimeRangeField):
    _temporal_range_field.register_lookup(DateTimeRangeContains)
# Do not leave the loop variable behind in the module namespace.
del _temporal_range_field
-
-
class RangeContainedBy(models.Lookup):
    """
    ``contained_by`` lookup for scalar fields.

    Renders ``lhs <@ rhs::<range type>``, where the range type is chosen from
    the database column type of the left-hand side field.
    """
    lookup_name = 'contained_by'
    # Maps the lhs column type to the range type the rhs is cast to.
    # 'smallint' is included because small integer fields inherit this lookup
    # from IntegerField; without the entry as_sql() raised KeyError for them.
    type_mapping = {
        'smallint': 'int4range',
        'integer': 'int4range',
        'bigint': 'int8range',
        'double precision': 'numrange',
        'date': 'daterange',
        'timestamp with time zone': 'tstzrange',
    }

    def as_sql(self, qn, connection):
        """
        Return ``(sql, params)`` for the containment test.

        Raises ``KeyError`` if the lhs column type has no entry in
        ``type_mapping``.
        """
        field = self.lhs.output_field
        db_type = field.db_type(connection)
        range_type = self.type_mapping[db_type]
        # The lhs element type must match the range's element type, so cast
        # the column where it differs from what the range type holds.
        if isinstance(field, models.FloatField):
            lhs_cast = '::numeric'
        elif db_type == 'smallint':
            lhs_cast = '::integer'
        else:
            lhs_cast = ''
        lhs, lhs_params = self.process_lhs(qn, connection)
        rhs, rhs_params = self.process_rhs(qn, connection)
        params = lhs_params + rhs_params
        return '%s%s <@ %s::%s' % (lhs, lhs_cast, rhs, range_type), params

    def get_prep_lookup(self):
        """Prepare the rhs with ``RangeField.get_prep_value()``."""
        # NOTE(review): a bare RangeField defines no range_type in this file,
        # so a list/tuple rhs presumably fails here; pass a Range instance.
        return RangeField().get_prep_value(self.rhs)
-
-
# Register RangeContainedBy on the scalar field types it supports, in the
# same order as before.
for _scalar_field in (
    models.DateField,
    models.DateTimeField,
    models.IntegerField,
    models.BigIntegerField,
    models.FloatField,
):
    _scalar_field.register_lookup(RangeContainedBy)
# Do not leave the loop variable behind in the module namespace.
del _scalar_field
-
-
@RangeField.register_lookup
class FullyLessThan(lookups.PostgresSimpleLookup):
    """``fully_lt`` lookup using the ``<<`` operator (PostgreSQL: strictly left of)."""
    operator = '<<'
    lookup_name = 'fully_lt'
-
-
@RangeField.register_lookup
class FullGreaterThan(lookups.PostgresSimpleLookup):
    """``fully_gt`` lookup using the ``>>`` operator (PostgreSQL: strictly right of)."""
    operator = '>>'
    lookup_name = 'fully_gt'
-
-
@RangeField.register_lookup
class NotLessThan(lookups.PostgresSimpleLookup):
    """``not_lt`` lookup using the ``&>`` operator (PostgreSQL: does not extend to the left of)."""
    operator = '&>'
    lookup_name = 'not_lt'
-
-
@RangeField.register_lookup
class NotGreaterThan(lookups.PostgresSimpleLookup):
    """``not_gt`` lookup using the ``&<`` operator (PostgreSQL: does not extend to the right of)."""
    operator = '&<'
    lookup_name = 'not_gt'
-
-
@RangeField.register_lookup
class AdjacentToLookup(lookups.PostgresSimpleLookup):
    """``adjacent_to`` lookup using the ``-|-`` operator (PostgreSQL: is adjacent to)."""
    operator = '-|-'
    lookup_name = 'adjacent_to'
-
-
@RangeField.register_lookup
class RangeStartsWith(models.Transform):
    """``startswith`` transform that applies the ``lower`` function to a range."""
    function = 'lower'
    lookup_name = 'startswith'

    @property
    def output_field(self):
        # The result is a single bound, so it is typed by the range's base field.
        range_field = self.lhs.output_field
        return range_field.base_field
-
-
@RangeField.register_lookup
class RangeEndsWith(models.Transform):
    """``endswith`` transform that applies the ``upper`` function to a range."""
    function = 'upper'
    lookup_name = 'endswith'

    @property
    def output_field(self):
        # The result is a single bound, so it is typed by the range's base field.
        range_field = self.lhs.output_field
        return range_field.base_field
-
-
@RangeField.register_lookup
class IsEmpty(models.Transform):
    """``isempty`` transform that applies the ``isempty`` function; the result is boolean."""
    function = 'isempty'
    lookup_name = 'isempty'
    output_field = models.BooleanField()
|