You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

ranges.py 8.0KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274
  1. import datetime
  2. import json
  3. from psycopg2.extras import DateRange, DateTimeTZRange, NumericRange, Range
  4. from django.contrib.postgres import forms, lookups
  5. from django.db import models
  6. from .utils import AttributeSetter
  7. __all__ = [
  8. 'RangeField', 'IntegerRangeField', 'BigIntegerRangeField',
  9. 'DecimalRangeField', 'DateTimeRangeField', 'DateRangeField',
  10. 'FloatRangeField',
  11. ]
  12. class RangeField(models.Field):
  13. empty_strings_allowed = False
  14. def __init__(self, *args, **kwargs):
  15. # Initializing base_field here ensures that its model matches the model for self.
  16. if hasattr(self, 'base_field'):
  17. self.base_field = self.base_field()
  18. super().__init__(*args, **kwargs)
  19. @property
  20. def model(self):
  21. try:
  22. return self.__dict__['model']
  23. except KeyError:
  24. raise AttributeError("'%s' object has no attribute 'model'" % self.__class__.__name__)
  25. @model.setter
  26. def model(self, model):
  27. self.__dict__['model'] = model
  28. self.base_field.model = model
  29. def get_prep_value(self, value):
  30. if value is None:
  31. return None
  32. elif isinstance(value, Range):
  33. return value
  34. elif isinstance(value, (list, tuple)):
  35. return self.range_type(value[0], value[1])
  36. return value
  37. def to_python(self, value):
  38. if isinstance(value, str):
  39. # Assume we're deserializing
  40. vals = json.loads(value)
  41. for end in ('lower', 'upper'):
  42. if end in vals:
  43. vals[end] = self.base_field.to_python(vals[end])
  44. value = self.range_type(**vals)
  45. elif isinstance(value, (list, tuple)):
  46. value = self.range_type(value[0], value[1])
  47. return value
  48. def set_attributes_from_name(self, name):
  49. super().set_attributes_from_name(name)
  50. self.base_field.set_attributes_from_name(name)
  51. def value_to_string(self, obj):
  52. value = self.value_from_object(obj)
  53. if value is None:
  54. return None
  55. if value.isempty:
  56. return json.dumps({"empty": True})
  57. base_field = self.base_field
  58. result = {"bounds": value._bounds}
  59. for end in ('lower', 'upper'):
  60. val = getattr(value, end)
  61. if val is None:
  62. result[end] = None
  63. else:
  64. obj = AttributeSetter(base_field.attname, val)
  65. result[end] = base_field.value_to_string(obj)
  66. return json.dumps(result)
  67. def formfield(self, **kwargs):
  68. kwargs.setdefault('form_class', self.form_field)
  69. return super().formfield(**kwargs)
  70. class IntegerRangeField(RangeField):
  71. base_field = models.IntegerField
  72. range_type = NumericRange
  73. form_field = forms.IntegerRangeField
  74. def db_type(self, connection):
  75. return 'int4range'
  76. class BigIntegerRangeField(RangeField):
  77. base_field = models.BigIntegerField
  78. range_type = NumericRange
  79. form_field = forms.IntegerRangeField
  80. def db_type(self, connection):
  81. return 'int8range'
  82. class DecimalRangeField(RangeField):
  83. base_field = models.DecimalField
  84. range_type = NumericRange
  85. form_field = forms.DecimalRangeField
  86. def db_type(self, connection):
  87. return 'numrange'
  88. class FloatRangeField(RangeField):
  89. system_check_deprecated_details = {
  90. 'msg': (
  91. 'FloatRangeField is deprecated and will be removed in Django 3.1.'
  92. ),
  93. 'hint': 'Use DecimalRangeField instead.',
  94. 'id': 'fields.W902',
  95. }
  96. base_field = models.FloatField
  97. range_type = NumericRange
  98. form_field = forms.FloatRangeField
  99. def db_type(self, connection):
  100. return 'numrange'
  101. class DateTimeRangeField(RangeField):
  102. base_field = models.DateTimeField
  103. range_type = DateTimeTZRange
  104. form_field = forms.DateTimeRangeField
  105. def db_type(self, connection):
  106. return 'tstzrange'
  107. class DateRangeField(RangeField):
  108. base_field = models.DateField
  109. range_type = DateRange
  110. form_field = forms.DateRangeField
  111. def db_type(self, connection):
  112. return 'daterange'
  113. RangeField.register_lookup(lookups.DataContains)
  114. RangeField.register_lookup(lookups.ContainedBy)
  115. RangeField.register_lookup(lookups.Overlap)
  116. class DateTimeRangeContains(models.Lookup):
  117. """
  118. Lookup for Date/DateTimeRange containment to cast the rhs to the correct
  119. type.
  120. """
  121. lookup_name = 'contains'
  122. def process_rhs(self, compiler, connection):
  123. # Transform rhs value for db lookup.
  124. if isinstance(self.rhs, datetime.date):
  125. output_field = models.DateTimeField() if isinstance(self.rhs, datetime.datetime) else models.DateField()
  126. value = models.Value(self.rhs, output_field=output_field)
  127. self.rhs = value.resolve_expression(compiler.query)
  128. return super().process_rhs(compiler, connection)
  129. def as_sql(self, compiler, connection):
  130. lhs, lhs_params = self.process_lhs(compiler, connection)
  131. rhs, rhs_params = self.process_rhs(compiler, connection)
  132. params = lhs_params + rhs_params
  133. # Cast the rhs if needed.
  134. cast_sql = ''
  135. if (
  136. isinstance(self.rhs, models.Expression) and
  137. self.rhs._output_field_or_none and
  138. # Skip cast if rhs has a matching range type.
  139. not isinstance(self.rhs._output_field_or_none, self.lhs.output_field.__class__)
  140. ):
  141. cast_internal_type = self.lhs.output_field.base_field.get_internal_type()
  142. cast_sql = '::{}'.format(connection.data_types.get(cast_internal_type))
  143. return '%s @> %s%s' % (lhs, rhs, cast_sql), params
  144. DateRangeField.register_lookup(DateTimeRangeContains)
  145. DateTimeRangeField.register_lookup(DateTimeRangeContains)
  146. class RangeContainedBy(models.Lookup):
  147. lookup_name = 'contained_by'
  148. type_mapping = {
  149. 'integer': 'int4range',
  150. 'bigint': 'int8range',
  151. 'double precision': 'numrange',
  152. 'date': 'daterange',
  153. 'timestamp with time zone': 'tstzrange',
  154. }
  155. def as_sql(self, qn, connection):
  156. field = self.lhs.output_field
  157. if isinstance(field, models.FloatField):
  158. sql = '%s::numeric <@ %s::{}'.format(self.type_mapping[field.db_type(connection)])
  159. else:
  160. sql = '%s <@ %s::{}'.format(self.type_mapping[field.db_type(connection)])
  161. lhs, lhs_params = self.process_lhs(qn, connection)
  162. rhs, rhs_params = self.process_rhs(qn, connection)
  163. params = lhs_params + rhs_params
  164. return sql % (lhs, rhs), params
  165. def get_prep_lookup(self):
  166. return RangeField().get_prep_value(self.rhs)
  167. models.DateField.register_lookup(RangeContainedBy)
  168. models.DateTimeField.register_lookup(RangeContainedBy)
  169. models.IntegerField.register_lookup(RangeContainedBy)
  170. models.BigIntegerField.register_lookup(RangeContainedBy)
  171. models.FloatField.register_lookup(RangeContainedBy)
  172. @RangeField.register_lookup
  173. class FullyLessThan(lookups.PostgresSimpleLookup):
  174. lookup_name = 'fully_lt'
  175. operator = '<<'
  176. @RangeField.register_lookup
  177. class FullGreaterThan(lookups.PostgresSimpleLookup):
  178. lookup_name = 'fully_gt'
  179. operator = '>>'
  180. @RangeField.register_lookup
  181. class NotLessThan(lookups.PostgresSimpleLookup):
  182. lookup_name = 'not_lt'
  183. operator = '&>'
  184. @RangeField.register_lookup
  185. class NotGreaterThan(lookups.PostgresSimpleLookup):
  186. lookup_name = 'not_gt'
  187. operator = '&<'
  188. @RangeField.register_lookup
  189. class AdjacentToLookup(lookups.PostgresSimpleLookup):
  190. lookup_name = 'adjacent_to'
  191. operator = '-|-'
  192. @RangeField.register_lookup
  193. class RangeStartsWith(models.Transform):
  194. lookup_name = 'startswith'
  195. function = 'lower'
  196. @property
  197. def output_field(self):
  198. return self.lhs.output_field.base_field
  199. @RangeField.register_lookup
  200. class RangeEndsWith(models.Transform):
  201. lookup_name = 'endswith'
  202. function = 'upper'
  203. @property
  204. def output_field(self):
  205. return self.lhs.output_field.base_field
  206. @RangeField.register_lookup
  207. class IsEmpty(models.Transform):
  208. lookup_name = 'isempty'
  209. function = 'isempty'
  210. output_field = models.BooleanField()