Development of an internal social media platform with personalised dashboards for students
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

aggregates.py 7.0KB

Line numbers 1–185
  1. """
  2. Classes to represent the definitions of aggregate functions.
  3. """
  4. from django.core.exceptions import FieldError
  5. from django.db.models.expressions import Case, Func, Star, When
  6. from django.db.models.fields import DecimalField, FloatField, IntegerField
  7. __all__ = [
  8. 'Aggregate', 'Avg', 'Count', 'Max', 'Min', 'StdDev', 'Sum', 'Variance',
  9. ]
  10. class Aggregate(Func):
  11. contains_aggregate = True
  12. name = None
  13. filter_template = '%s FILTER (WHERE %%(filter)s)'
  14. window_compatible = True
  15. def __init__(self, *args, filter=None, **kwargs):
  16. self.filter = filter
  17. super().__init__(*args, **kwargs)
  18. def get_source_fields(self):
  19. # Don't return the filter expression since it's not a source field.
  20. return [e._output_field_or_none for e in super().get_source_expressions()]
  21. def get_source_expressions(self):
  22. source_expressions = super().get_source_expressions()
  23. if self.filter:
  24. source_expressions += [self.filter]
  25. return source_expressions
  26. def set_source_expressions(self, exprs):
  27. self.filter = self.filter and exprs.pop()
  28. return super().set_source_expressions(exprs)
  29. def resolve_expression(self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False):
  30. # Aggregates are not allowed in UPDATE queries, so ignore for_save
  31. c = super().resolve_expression(query, allow_joins, reuse, summarize)
  32. c.filter = c.filter and c.filter.resolve_expression(query, allow_joins, reuse, summarize)
  33. if not summarize:
  34. # Call Aggregate.get_source_expressions() to avoid
  35. # returning self.filter and including that in this loop.
  36. expressions = super(Aggregate, c).get_source_expressions()
  37. for index, expr in enumerate(expressions):
  38. if expr.contains_aggregate:
  39. before_resolved = self.get_source_expressions()[index]
  40. name = before_resolved.name if hasattr(before_resolved, 'name') else repr(before_resolved)
  41. raise FieldError("Cannot compute %s('%s'): '%s' is an aggregate" % (c.name, name, name))
  42. return c
  43. @property
  44. def default_alias(self):
  45. expressions = self.get_source_expressions()
  46. if len(expressions) == 1 and hasattr(expressions[0], 'name'):
  47. return '%s__%s' % (expressions[0].name, self.name.lower())
  48. raise TypeError("Complex expressions require an alias")
  49. def get_group_by_cols(self):
  50. return []
  51. def as_sql(self, compiler, connection, **extra_context):
  52. if self.filter:
  53. if connection.features.supports_aggregate_filter_clause:
  54. filter_sql, filter_params = self.filter.as_sql(compiler, connection)
  55. template = self.filter_template % extra_context.get('template', self.template)
  56. sql, params = super().as_sql(compiler, connection, template=template, filter=filter_sql)
  57. return sql, params + filter_params
  58. else:
  59. copy = self.copy()
  60. copy.filter = None
  61. source_expressions = copy.get_source_expressions()
  62. condition = When(self.filter, then=source_expressions[0])
  63. copy.set_source_expressions([Case(condition)] + source_expressions[1:])
  64. return super(Aggregate, copy).as_sql(compiler, connection, **extra_context)
  65. return super().as_sql(compiler, connection, **extra_context)
  66. def _get_repr_options(self):
  67. options = super()._get_repr_options()
  68. if self.filter:
  69. options.update({'filter': self.filter})
  70. return options
  71. class Avg(Aggregate):
  72. function = 'AVG'
  73. name = 'Avg'
  74. def _resolve_output_field(self):
  75. source_field = self.get_source_fields()[0]
  76. if isinstance(source_field, (IntegerField, DecimalField)):
  77. return FloatField()
  78. return super()._resolve_output_field()
  79. def as_mysql(self, compiler, connection):
  80. sql, params = super().as_sql(compiler, connection)
  81. if self.output_field.get_internal_type() == 'DurationField':
  82. sql = 'CAST(%s as SIGNED)' % sql
  83. return sql, params
  84. def as_oracle(self, compiler, connection):
  85. if self.output_field.get_internal_type() == 'DurationField':
  86. expression = self.get_source_expressions()[0]
  87. from django.db.backends.oracle.functions import IntervalToSeconds, SecondsToInterval
  88. return compiler.compile(
  89. SecondsToInterval(Avg(IntervalToSeconds(expression), filter=self.filter))
  90. )
  91. return super().as_sql(compiler, connection)
  92. class Count(Aggregate):
  93. function = 'COUNT'
  94. name = 'Count'
  95. template = '%(function)s(%(distinct)s%(expressions)s)'
  96. output_field = IntegerField()
  97. def __init__(self, expression, distinct=False, filter=None, **extra):
  98. if expression == '*':
  99. expression = Star()
  100. if isinstance(expression, Star) and filter is not None:
  101. raise ValueError('Star cannot be used with filter. Please specify a field.')
  102. super().__init__(
  103. expression, distinct='DISTINCT ' if distinct else '',
  104. filter=filter, **extra
  105. )
  106. def _get_repr_options(self):
  107. return {**super()._get_repr_options(), 'distinct': self.extra['distinct'] != ''}
  108. def convert_value(self, value, expression, connection):
  109. return 0 if value is None else value
  110. class Max(Aggregate):
  111. function = 'MAX'
  112. name = 'Max'
  113. class Min(Aggregate):
  114. function = 'MIN'
  115. name = 'Min'
  116. class StdDev(Aggregate):
  117. name = 'StdDev'
  118. output_field = FloatField()
  119. def __init__(self, expression, sample=False, **extra):
  120. self.function = 'STDDEV_SAMP' if sample else 'STDDEV_POP'
  121. super().__init__(expression, **extra)
  122. def _get_repr_options(self):
  123. return {**super()._get_repr_options(), 'sample': self.function == 'STDDEV_SAMP'}
  124. class Sum(Aggregate):
  125. function = 'SUM'
  126. name = 'Sum'
  127. def as_mysql(self, compiler, connection):
  128. sql, params = super().as_sql(compiler, connection)
  129. if self.output_field.get_internal_type() == 'DurationField':
  130. sql = 'CAST(%s as SIGNED)' % sql
  131. return sql, params
  132. def as_oracle(self, compiler, connection):
  133. if self.output_field.get_internal_type() == 'DurationField':
  134. expression = self.get_source_expressions()[0]
  135. from django.db.backends.oracle.functions import IntervalToSeconds, SecondsToInterval
  136. return compiler.compile(
  137. SecondsToInterval(Sum(IntervalToSeconds(expression)))
  138. )
  139. return super().as_sql(compiler, connection)
  140. class Variance(Aggregate):
  141. name = 'Variance'
  142. output_field = FloatField()
  143. def __init__(self, expression, sample=False, **extra):
  144. self.function = 'VAR_SAMP' if sample else 'VAR_POP'
  145. super().__init__(expression, **extra)
  146. def _get_repr_options(self):
  147. return {**super()._get_repr_options(), 'sample': self.function == 'VAR_SAMP'}