Development of an internal social media platform with personalised dashboards for students
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

ranges.py 7.4KB

(line-number gutter: lines 1–252)
  import datetime
  import json

  from psycopg2.extras import DateRange, DateTimeTZRange, NumericRange, Range

  from django.contrib.postgres import forms, lookups
  from django.db import models

  from .utils import AttributeSetter
  # Public model-field classes exported by this module.
  __all__ = [
      'RangeField',
      'IntegerRangeField',
      'BigIntegerRangeField',
      'FloatRangeField',
      'DateTimeRangeField',
      'DateRangeField',
  ]
  class RangeField(models.Field):
      """
      Base model field for PostgreSQL range columns.

      Concrete subclasses provide three class attributes:

      * ``base_field`` -- the model field class describing a single bound;
        it is replaced by an instance of itself in ``__init__``.
      * ``range_type`` -- the psycopg2 ``Range`` subclass used for values.
      * ``form_field`` -- the form field class used by ``formfield()``.
      """
      empty_strings_allowed = False

      def __init__(self, *args, **kwargs):
          # Initializing base_field here ensures that its model matches the model for self.
          if hasattr(self, 'base_field'):
              self.base_field = self.base_field()
          super().__init__(*args, **kwargs)

      @property
      def model(self):
          """
          Return the model this field is attached to.

          Raises AttributeError (not KeyError) before a model has been
          assigned, so hasattr() and getattr() with a default behave normally.
          """
          try:
              return self.__dict__['model']
          except KeyError:
              # The KeyError is an implementation detail of the storage; drop it
              # from the exception chain so callers only see the AttributeError.
              raise AttributeError(
                  "'%s' object has no attribute 'model'" % self.__class__.__name__
              ) from None

      @model.setter
      def model(self, model):
          # Propagate the model to the bound field so both stay in sync.
          self.__dict__['model'] = model
          self.base_field.model = model

      def get_prep_value(self, value):
          """
          Prepare *value* for the database.

          A list or tuple is converted to ``range_type(value[0], value[1])``;
          every other value (including None and ``Range`` instances) is
          returned unchanged.
          """
          if isinstance(value, (list, tuple)):
              return self.range_type(value[0], value[1])
          return value

      def to_python(self, value):
          """
          Convert *value* to a ``range_type`` instance where possible.

          A string is treated as the JSON produced by ``value_to_string()``:
          its keys become ``range_type`` keyword arguments, with the
          ``lower``/``upper`` bounds first converted by
          ``base_field.to_python()``. A list or tuple becomes
          ``range_type(value[0], value[1])``. Other values are returned
          unchanged.
          """
          if isinstance(value, str):
              # Assume we're deserializing
              vals = json.loads(value)
              for end in ('lower', 'upper'):
                  if end in vals:
                      vals[end] = self.base_field.to_python(vals[end])
              value = self.range_type(**vals)
          elif isinstance(value, (list, tuple)):
              value = self.range_type(value[0], value[1])
          return value

      def set_attributes_from_name(self, name):
          # Give the bound field the same name/attname as this field.
          super().set_attributes_from_name(name)
          self.base_field.set_attributes_from_name(name)

      def value_to_string(self, obj):
          """
          Serialize this field's value on *obj* to a JSON string.

          Returns None when the value is None, ``{"empty": true}`` for an
          empty range, and otherwise an object with ``bounds``, ``lower`` and
          ``upper`` keys, each non-None bound serialized by ``base_field``.
          """
          value = self.value_from_object(obj)
          if value is None:
              return None
          if value.isempty:
              return json.dumps({"empty": True})
          base_field = self.base_field
          result = {"bounds": value._bounds}
          for end in ('lower', 'upper'):
              val = getattr(value, end)
              if val is None:
                  result[end] = None
              else:
                  # base_field.value_to_string() expects an object, not a raw
                  # value; AttributeSetter (from .utils) presumably exposes
                  # ``val`` under ``base_field.attname`` -- see .utils.
                  holder = AttributeSetter(base_field.attname, val)
                  result[end] = base_field.value_to_string(holder)
          return json.dumps(result)

      def formfield(self, **kwargs):
          # Default to the subclass's form_field unless the caller overrides it.
          kwargs.setdefault('form_class', self.form_field)
          return super().formfield(**kwargs)
  class IntegerRangeField(RangeField):
      """Range of ``IntegerField`` values, stored in an ``int4range`` column."""
      form_field = forms.IntegerRangeField
      range_type = NumericRange
      base_field = models.IntegerField

      def db_type(self, connection):
          # Fixed column type; the connection is not consulted.
          return 'int4range'
  class BigIntegerRangeField(RangeField):
      """Range of ``BigIntegerField`` values, stored in an ``int8range`` column."""
      form_field = forms.IntegerRangeField
      range_type = NumericRange
      base_field = models.BigIntegerField

      def db_type(self, connection):
          # Fixed column type; the connection is not consulted.
          return 'int8range'
  class FloatRangeField(RangeField):
      """
      Range of ``FloatField`` values, stored in a ``numrange`` column.

      NOTE(review): numrange's element type is ``numeric``, so bounds read
      back from the database may not be Python floats -- confirm.
      """
      form_field = forms.FloatRangeField
      range_type = NumericRange
      base_field = models.FloatField

      def db_type(self, connection):
          # Fixed column type; the connection is not consulted.
          return 'numrange'
  class DateTimeRangeField(RangeField):
      """Range of ``DateTimeField`` values, stored in a ``tstzrange`` column."""
      form_field = forms.DateTimeRangeField
      range_type = DateTimeTZRange
      base_field = models.DateTimeField

      def db_type(self, connection):
          # Fixed column type; the connection is not consulted.
          return 'tstzrange'
  class DateRangeField(RangeField):
      """Range of ``DateField`` values, stored in a ``daterange`` column."""
      form_field = forms.DateRangeField
      range_type = DateRange
      base_field = models.DateField

      def db_type(self, connection):
          # Fixed column type; the connection is not consulted.
          return 'daterange'
  # Register the shared contrib.postgres lookups on the RangeField base class,
  # in the same order as before; the loop name is removed afterwards so the
  # module namespace is unchanged.
  for _range_lookup in (lookups.DataContains, lookups.ContainedBy, lookups.Overlap):
      RangeField.register_lookup(_range_lookup)
  del _range_lookup
  class DateTimeRangeContains(models.Lookup):
      """
      Lookup for Date/DateTimeRange containment to cast the rhs to the correct
      type.

      A plain ``date``/``datetime`` rhs is wrapped in ``Value`` so it compiles
      as an expression; ``as_sql`` then appends a cast to the range's base
      type. An rhs expression whose output field is already the same kind of
      range field as the lhs (e.g. a reference to another range column) is
      left uncast.
      """
      lookup_name = 'contains'

      def process_rhs(self, compiler, connection):
          # Transform rhs value for db lookup.
          if isinstance(self.rhs, datetime.date):
              # datetime.datetime is a subclass of datetime.date, so test it first.
              if isinstance(self.rhs, datetime.datetime):
                  output_field = models.DateTimeField()
              else:
                  output_field = models.DateField()
              value = models.Value(self.rhs, output_field=output_field)
              self.rhs = value.resolve_expression(compiler.query)
          return super().process_rhs(compiler, connection)

      def as_sql(self, compiler, connection):
          """Compile to ``lhs @> rhs``, casting rhs to the base type when needed."""
          lhs, lhs_params = self.process_lhs(compiler, connection)
          rhs, rhs_params = self.process_rhs(compiler, connection)
          params = lhs_params + rhs_params
          # Cast the rhs if needed.
          cast_sql = ''
          rhs_field = None
          if isinstance(self.rhs, models.Expression):
              rhs_field = self.rhs._output_field_or_none
          # Skip the cast when the rhs already yields the same range type as the
          # lhs: PostgreSQL cannot cast a range value to its element type.
          if rhs_field and not isinstance(rhs_field, self.lhs.output_field.__class__):
              cast_internal_type = self.lhs.output_field.base_field.get_internal_type()
              cast_sql = '::{}'.format(connection.data_types.get(cast_internal_type))
          return '%s @> %s%s' % (lhs, rhs, cast_sql), params
  # Register the date-aware ``contains`` lookup on the two date-based range
  # fields (same order as before); drop the loop name afterwards.
  for _date_range_field in (DateRangeField, DateTimeRangeField):
      _date_range_field.register_lookup(DateTimeRangeContains)
  del _date_range_field
  class RangeContainedBy(models.Lookup):
      """
      ``<scalar field>__contained_by=<range>`` lookup.

      Compiles to ``lhs <@ rhs::<rangetype>``, where the range type is picked
      from the lhs column's ``db_type()`` via ``type_mapping``.
      """
      lookup_name = 'contained_by'
      # Maps the lhs column's db_type() to the range type the rhs is cast to.
      type_mapping = {
          # smallint has no range type of its own; the lhs is widened to
          # integer in as_sql() so it matches int4range.
          'smallint': 'int4range',
          'integer': 'int4range',
          'bigint': 'int8range',
          'double precision': 'numrange',
          'date': 'daterange',
          'timestamp with time zone': 'tstzrange',
      }

      def as_sql(self, qn, connection):
          field = self.lhs.output_field
          db_type = field.db_type(connection)
          if isinstance(field, models.FloatField):
              # numrange's element type is numeric, so cast the float column.
              lhs_cast = '::numeric'
          elif db_type == 'smallint':
              # Lookups registered on IntegerField are inherited by
              # smallint-backed fields; widen to integer to match int4range.
              lhs_cast = '::integer'
          else:
              lhs_cast = ''
          sql = '%s{} <@ %s::{}'.format(lhs_cast, self.type_mapping[db_type])
          lhs, lhs_params = self.process_lhs(qn, connection)
          rhs, rhs_params = self.process_rhs(qn, connection)
          params = lhs_params + rhs_params
          return sql % (lhs, rhs), params

      def get_prep_lookup(self):
          """
          Normalize the rhs into a range value.

          A ``(lower, upper)`` list/tuple is turned into the psycopg2 range
          class matching the lhs field. A bare ``RangeField()`` has no
          ``range_type``, so it cannot do this conversion itself. Everything
          else goes through ``RangeField.get_prep_value`` as before.
          """
          rhs = self.rhs
          if isinstance(rhs, (list, tuple)):
              field = self.lhs.output_field
              # DateTimeField subclasses DateField, so check it first.
              if isinstance(field, models.DateTimeField):
                  range_type = DateTimeTZRange
              elif isinstance(field, models.DateField):
                  range_type = DateRange
              else:
                  range_type = NumericRange
              return range_type(rhs[0], rhs[1])
          return RangeField().get_prep_value(rhs)
  # Make ``__contained_by`` available on the scalar field types listed below
  # (registered in the same order as before); drop the loop name afterwards.
  for _scalar_field in (
      models.DateField,
      models.DateTimeField,
      models.IntegerField,
      models.BigIntegerField,
      models.FloatField,
  ):
      _scalar_field.register_lookup(RangeContainedBy)
  del _scalar_field
  @RangeField.register_lookup
  class FullyLessThan(lookups.PostgresSimpleLookup):
      """``range_field__fully_lt=...`` -- PostgreSQL ``<<`` (strictly left of)."""
      operator = '<<'
      lookup_name = 'fully_lt'
  @RangeField.register_lookup
  class FullGreaterThan(lookups.PostgresSimpleLookup):
      """``range_field__fully_gt=...`` -- PostgreSQL ``>>`` (strictly right of)."""
      operator = '>>'
      lookup_name = 'fully_gt'
  @RangeField.register_lookup
  class NotLessThan(lookups.PostgresSimpleLookup):
      """``range_field__not_lt=...`` -- PostgreSQL ``&>`` (does not extend to the left of)."""
      operator = '&>'
      lookup_name = 'not_lt'
  @RangeField.register_lookup
  class NotGreaterThan(lookups.PostgresSimpleLookup):
      """``range_field__not_gt=...`` -- PostgreSQL ``&<`` (does not extend to the right of)."""
      operator = '&<'
      lookup_name = 'not_gt'
  @RangeField.register_lookup
  class AdjacentToLookup(lookups.PostgresSimpleLookup):
      """``range_field__adjacent_to=...`` -- PostgreSQL ``-|-`` (is adjacent to)."""
      operator = '-|-'
      lookup_name = 'adjacent_to'
  @RangeField.register_lookup
  class RangeStartsWith(models.Transform):
      """``range_field__startswith`` -- the range's lower bound via SQL ``lower()``."""
      function = 'lower'
      lookup_name = 'startswith'

      @property
      def output_field(self):
          # The bound is typed like the range field's element (base) field.
          range_field = self.lhs.output_field
          return range_field.base_field
  @RangeField.register_lookup
  class RangeEndsWith(models.Transform):
      """``range_field__endswith`` -- the range's upper bound via SQL ``upper()``."""
      function = 'upper'
      lookup_name = 'endswith'

      @property
      def output_field(self):
          # The bound is typed like the range field's element (base) field.
          range_field = self.lhs.output_field
          return range_field.base_field
  @RangeField.register_lookup
  class IsEmpty(models.Transform):
      """``range_field__isempty`` -- boolean result of SQL ``isempty()``."""
      function = 'isempty'
      lookup_name = 'isempty'
      output_field = models.BooleanField()