|
12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
7127812791280128112821283128412851286128712881289129012911292129312941295129612971298129913001301130213031304130513061307130813091310131113121313131413151316131713181319132013211322132313241325132613271328132913301331133213331334133513361337133813391340134113421343134413451346134713481349135013511352135313541355135613571358 |
- import copy
- import datetime
- import inspect
- from decimal import Decimal
-
- from django.core.exceptions import EmptyResultSet, FieldError
- from django.db import connection
- from django.db.models import fields
- from django.db.models.query_utils import Q
- from django.utils.deconstruct import deconstructible
- from django.utils.functional import cached_property
- from django.utils.hashable import make_hashable
-
-
class SQLiteNumericMixin:
    """
    Mixin for expressions whose result must be wrapped in
    CAST(... AS NUMERIC) on SQLite when their output_field is a
    DecimalField, so that filtering on them behaves correctly.
    """
    def as_sqlite(self, compiler, connection, **extra_context):
        sql, params = self.as_sql(compiler, connection, **extra_context)
        try:
            wants_cast = self.output_field.get_internal_type() == 'DecimalField'
        except FieldError:
            # The output type can't be resolved; emit the SQL unchanged.
            wants_cast = False
        if wants_cast:
            sql = 'CAST(%s AS NUMERIC)' % sql
        return sql, params
-
-
class Combinable:
    """
    Mixin that lets expressions be combined with Python operators, e.g.
    F('foo') + F('bar'). Each operator builds a CombinedExpression from the
    two operands and a connector string.
    """

    # Arithmetic connectors
    ADD = '+'
    SUB = '-'
    MUL = '*'
    DIV = '/'
    POW = '^'
    # '%' is doubled (quoted) because the connector can appear in strings
    # that also undergo parameter substitution.
    MOD = '%%'

    # Bitwise connectors. These are produced by .bitand(), .bitor(), etc.;
    # the Python '&' and '|' operators are reserved for boolean usage.
    BITAND = '&'
    BITOR = '|'
    BITLEFTSHIFT = '<<'
    BITRIGHTSHIFT = '>>'

    def _combine(self, other, connector, reversed):
        """
        Return a CombinedExpression joining self and other with connector.

        Plain Python values are wrapped first: a timedelta becomes a
        DurationValue with a DurationField output, anything else a Value.
        When reversed is true, other is placed on the left-hand side.
        """
        if not hasattr(other, 'resolve_expression'):
            other = (
                DurationValue(other, output_field=fields.DurationField())
                if isinstance(other, datetime.timedelta)
                else Value(other)
            )
        lhs, rhs = (other, self) if reversed else (self, other)
        return CombinedExpression(lhs, connector, rhs)

    #############
    # OPERATORS #
    #############

    # -- arithmetic, expression on the left --------------------------------

    def __neg__(self):
        return self._combine(-1, self.MUL, False)

    def __add__(self, other):
        return self._combine(other, self.ADD, False)

    def __sub__(self, other):
        return self._combine(other, self.SUB, False)

    def __mul__(self, other):
        return self._combine(other, self.MUL, False)

    def __truediv__(self, other):
        return self._combine(other, self.DIV, False)

    def __mod__(self, other):
        return self._combine(other, self.MOD, False)

    def __pow__(self, other):
        return self._combine(other, self.POW, False)

    # -- arithmetic, expression on the right (reflected) -------------------

    def __radd__(self, other):
        return self._combine(other, self.ADD, True)

    def __rsub__(self, other):
        return self._combine(other, self.SUB, True)

    def __rmul__(self, other):
        return self._combine(other, self.MUL, True)

    def __rtruediv__(self, other):
        return self._combine(other, self.DIV, True)

    def __rmod__(self, other):
        return self._combine(other, self.MOD, True)

    def __rpow__(self, other):
        return self._combine(other, self.POW, True)

    # -- bitwise, via explicit methods -------------------------------------

    def bitand(self, other):
        return self._combine(other, self.BITAND, False)

    def bitor(self, other):
        return self._combine(other, self.BITOR, False)

    def bitleftshift(self, other):
        return self._combine(other, self.BITLEFTSHIFT, False)

    def bitrightshift(self, other):
        return self._combine(other, self.BITRIGHTSHIFT, False)

    # -- '&' and '|' are deliberately unsupported --------------------------

    def __and__(self, other):
        raise NotImplementedError(
            "Use .bitand() and .bitor() for bitwise logical operations."
        )

    def __rand__(self, other):
        raise NotImplementedError(
            "Use .bitand() and .bitor() for bitwise logical operations."
        )

    def __or__(self, other):
        raise NotImplementedError(
            "Use .bitand() and .bitor() for bitwise logical operations."
        )

    def __ror__(self, other):
        raise NotImplementedError(
            "Use .bitand() and .bitor() for bitwise logical operations."
        )
-
-
@deconstructible
class BaseExpression:
    """
    Base class for all query expressions.

    Child expressions are exposed through get_source_expressions() and
    replaced through set_source_expressions(); the generic helpers in this
    class (resolve_expression, relabeled_clone, flatten, the contains_*
    flags, get_source_fields) are written in terms of those two methods.
    """

    # aggregate specific fields
    is_summary = False
    # Set to True by the output_field property when _resolve_output_field()
    # returned None, so _output_field_or_none can tell that case apart from
    # any other FieldError.
    _output_field_resolved_to_none = False
    # Can the expression be used in a WHERE clause?
    filterable = True
    # Can the expression be used as a source expression in Window?
    window_compatible = False

    def __init__(self, output_field=None):
        # An explicit output_field is stored on the instance; otherwise the
        # output_field property below infers it lazily from the sources.
        if output_field is not None:
            self.output_field = output_field

    def __getstate__(self):
        # convert_value may cache a lambda (see convert_value below), which
        # cannot be pickled, so the cached value is dropped; it is
        # recomputed on next access.
        state = self.__dict__.copy()
        state.pop('convert_value', None)
        return state

    def get_db_converters(self, connection):
        """
        Return the converters applied to values of this expression: its own
        convert_value (omitted when it is the no-op) followed by the output
        field's converters.
        """
        return (
            []
            if self.convert_value is self._convert_value_noop else
            [self.convert_value]
        ) + self.output_field.get_db_converters(connection)

    def get_source_expressions(self):
        """Return the child expressions; the base implementation has none."""
        return []

    def set_source_expressions(self, exprs):
        """Replace the child expressions; the base implementation accepts none."""
        assert not exprs

    def _parse_expressions(self, *expressions):
        """
        Normalize arguments to expressions: objects that have
        resolve_expression() are kept as-is, strings become F() references,
        and any other value is wrapped in Value().
        """
        return [
            arg if hasattr(arg, 'resolve_expression') else (
                F(arg) if isinstance(arg, str) else Value(arg)
            ) for arg in expressions
        ]

    def as_sql(self, compiler, connection):
        """
        Responsible for returning a (sql, [params]) tuple to be included
        in the current query.

        Different backends can provide their own implementation, by
        providing an `as_{vendor}` method and patching the Expression:

        ```
        def override_as_sql(self, compiler, connection):
            # custom logic
            return super().as_sql(compiler, connection)
        setattr(Expression, 'as_' + connection.vendor, override_as_sql)
        ```

        Arguments:
         * compiler: the query compiler responsible for generating the query.
           Must have a compile method, returning a (sql, [params]) tuple.
           Calling compiler(value) will return a quoted `value`.

         * connection: the database connection used for the current query.

        Return: (sql, params)
          Where `sql` is a string containing ordered sql parameters to be
          replaced with the elements of the list `params`.
        """
        raise NotImplementedError("Subclasses must implement as_sql()")

    @cached_property
    def contains_aggregate(self):
        """Whether any non-None source expression contains an aggregate."""
        return any(expr and expr.contains_aggregate for expr in self.get_source_expressions())

    @cached_property
    def contains_over_clause(self):
        """Whether any non-None source expression contains an OVER clause."""
        return any(expr and expr.contains_over_clause for expr in self.get_source_expressions())

    @cached_property
    def contains_column_references(self):
        """Whether any non-None source expression references a column."""
        return any(expr and expr.contains_column_references for expr in self.get_source_expressions())

    def resolve_expression(self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False):
        """
        Provide the chance to do any preprocessing or validation before being
        added to the query.

        Arguments:
         * query: the backend query implementation
         * allow_joins: boolean allowing or denying use of joins
           in this query
         * reuse: a set of reusable joins for multijoins
         * summarize: a terminal aggregate clause
         * for_save: whether this expression is about to be used in a save
           or update

        Return: an Expression to be added to the query.
        """
        c = self.copy()
        c.is_summary = summarize
        # Resolve each non-None child on the copy. Note that for_save is not
        # forwarded to the children here.
        c.set_source_expressions([
            expr.resolve_expression(query, allow_joins, reuse, summarize)
            if expr else None
            for expr in c.get_source_expressions()
        ])
        return c

    def _prepare(self, field):
        """Hook used by Lookup.get_prep_lookup() to do custom preparation."""
        return self

    @property
    def field(self):
        """Alias for output_field."""
        return self.output_field

    @cached_property
    def output_field(self):
        """
        Return the output type of this expression.

        Raise FieldError if the type cannot be inferred from the sources.
        """
        output_field = self._resolve_output_field()
        if output_field is None:
            self._output_field_resolved_to_none = True
            raise FieldError('Cannot resolve expression type, unknown output_field')
        return output_field

    @cached_property
    def _output_field_or_none(self):
        """
        Return the output field of this expression, or None if
        _resolve_output_field() didn't return an output type.
        """
        try:
            return self.output_field
        except FieldError:
            # Only swallow the "resolved to None" case; any other
            # FieldError (e.g. mixed types) propagates.
            if not self._output_field_resolved_to_none:
                raise

    def _resolve_output_field(self):
        """
        Attempt to infer the output type of the expression. If the output
        fields of all source fields match then, simply infer the same type
        here. This isn't always correct, but it makes sense most of the time.

        Consider the difference between `2 + 2` and `2 / 3`. Inferring
        the type here is a convenience for the common case. The user should
        supply their own output_field with more complex computations.

        If a source's output field resolves to None, exclude it from this check.
        If all sources are None, then an error is raised higher up the stack in
        the output_field property.
        """
        # sources_iter is a single generator: the outer loop takes the first
        # non-None field and the inner any() consumes all remaining ones, so
        # the loop body runs at most once. With no non-None sources the
        # method falls through and returns None.
        sources_iter = (source for source in self.get_source_fields() if source is not None)
        for output_field in sources_iter:
            if any(not isinstance(output_field, source.__class__) for source in sources_iter):
                raise FieldError('Expression contains mixed types. You must set output_field.')
            return output_field

    @staticmethod
    def _convert_value_noop(value, expression, connection):
        """Converter that returns the value unchanged."""
        return value

    @cached_property
    def convert_value(self):
        """
        Expressions provide their own converters because users have the option
        of manually specifying the output_field which may be a different type
        from the one the database returns.
        """
        # Float, *IntegerField and Decimal outputs get a None-preserving
        # cast; every other type uses the no-op converter.
        field = self.output_field
        internal_type = field.get_internal_type()
        if internal_type == 'FloatField':
            return lambda value, expression, connection: None if value is None else float(value)
        elif internal_type.endswith('IntegerField'):
            return lambda value, expression, connection: None if value is None else int(value)
        elif internal_type == 'DecimalField':
            return lambda value, expression, connection: None if value is None else Decimal(value)
        return self._convert_value_noop

    def get_lookup(self, lookup):
        """Delegate lookup resolution to the output field."""
        return self.output_field.get_lookup(lookup)

    def get_transform(self, name):
        """Delegate transform resolution to the output field."""
        return self.output_field.get_transform(name)

    def relabeled_clone(self, change_map):
        """
        Return a copy whose non-None source expressions have each been
        relabeled with change_map.
        """
        clone = self.copy()
        clone.set_source_expressions([
            e.relabeled_clone(change_map) if e is not None else None
            for e in self.get_source_expressions()
        ])
        return clone

    def copy(self):
        """Return a shallow copy of this expression."""
        return copy.copy(self)

    def get_group_by_cols(self):
        """
        Return the expressions to add to GROUP BY: the expression itself
        when it contains no aggregate, otherwise the group-by columns of its
        source expressions.
        """
        if not self.contains_aggregate:
            return [self]
        cols = []
        for source in self.get_source_expressions():
            cols.extend(source.get_group_by_cols())
        return cols

    def get_source_fields(self):
        """
        Return the output fields of the source expressions (None for any
        source whose output type resolved to None).
        """
        return [e._output_field_or_none for e in self.get_source_expressions()]

    def asc(self, **kwargs):
        """Return an ascending OrderBy wrapping this expression."""
        return OrderBy(self, **kwargs)

    def desc(self, **kwargs):
        """Return a descending OrderBy wrapping this expression."""
        return OrderBy(self, descending=True, **kwargs)

    def reverse_ordering(self):
        """Hook for reversed ordering; the base implementation returns self."""
        return self

    def flatten(self):
        """
        Recursively yield this expression and all subexpressions, in
        depth-first order.
        """
        yield self
        for expr in self.get_source_expressions():
            if expr:
                yield from expr.flatten()

    @cached_property
    def identity(self):
        """
        Return a hashable tuple of the class and its bound constructor
        arguments (defaults applied), used by __eq__ and __hash__.

        Field arguments are represented by (model label, field name) when
        both are set, otherwise by the field's type; other values go through
        make_hashable().
        """
        constructor_signature = inspect.signature(self.__init__)
        # NOTE(review): _constructor_args is not set in this class; presumably
        # it is recorded by @deconstructible at construction time — confirm.
        args, kwargs = self._constructor_args
        signature = constructor_signature.bind_partial(*args, **kwargs)
        signature.apply_defaults()
        arguments = signature.arguments.items()
        identity = [self.__class__]
        for arg, value in arguments:
            if isinstance(value, fields.Field):
                if value.name and value.model:
                    value = (value.model._meta.label, value.name)
                else:
                    value = type(value)
            else:
                value = make_hashable(value)
            identity.append((arg, value))
        return tuple(identity)

    def __eq__(self, other):
        # Equality (and hashing, below) is based on identity.
        return isinstance(other, BaseExpression) and other.identity == self.identity

    def __hash__(self):
        return hash(self.identity)
-
-
class Expression(BaseExpression, Combinable):
    """
    A query expression that also supports Python operators (via Combinable)
    for combining it with other expressions.
    """
-
-
class CombinedExpression(SQLiteNumericMixin, Expression):
    """Two expressions joined by a connector, e.g. ``lhs + rhs``."""

    def __init__(self, lhs, connector, rhs, output_field=None):
        super().__init__(output_field=output_field)
        self.connector = connector
        self.lhs = lhs
        self.rhs = rhs

    def __repr__(self):
        return "<{}: {}>".format(self.__class__.__name__, self)

    def __str__(self):
        return "{} {} {}".format(self.lhs, self.connector, self.rhs)

    def get_source_expressions(self):
        return [self.lhs, self.rhs]

    def set_source_expressions(self, exprs):
        self.lhs, self.rhs = exprs

    def as_sql(self, compiler, connection):
        """
        Compile to ``(lhs <connector> rhs)``. Duration arithmetic on backends
        without a native duration type is delegated to DurationExpression,
        and subtraction of two same-typed date/time operands to
        TemporalSubtraction.
        """
        def output_of(side):
            # An operand whose type can't be resolved is treated as untyped.
            try:
                return side.output_field
            except FieldError:
                return None

        def is_duration(output):
            return bool(output) and output.get_internal_type() == 'DurationField'

        lhs_output = output_of(self.lhs)
        rhs_output = output_of(self.rhs)

        if not connection.features.has_native_duration_field and (
                is_duration(lhs_output) or is_duration(rhs_output)):
            return DurationExpression(self.lhs, self.connector, self.rhs).as_sql(compiler, connection)

        if lhs_output and rhs_output and self.connector == self.SUB:
            lhs_type = lhs_output.get_internal_type()
            if (lhs_type in {'DateField', 'DateTimeField', 'TimeField'} and
                    lhs_type == rhs_output.get_internal_type()):
                return TemporalSubtraction(self.lhs, self.rhs).as_sql(compiler, connection)

        sql_parts = []
        sql_params = []
        for side in (self.lhs, self.rhs):
            side_sql, side_params = compiler.compile(side)
            sql_parts.append(side_sql)
            sql_params.extend(side_params)
        combined = connection.ops.combine_expression(self.connector, sql_parts)
        # Parenthesize to keep operator precedence when nested.
        return '(%s)' % combined, sql_params

    def resolve_expression(self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False):
        clone = self.copy()
        clone.is_summary = summarize
        clone.lhs = clone.lhs.resolve_expression(query, allow_joins, reuse, summarize, for_save)
        clone.rhs = clone.rhs.resolve_expression(query, allow_joins, reuse, summarize, for_save)
        return clone
-
-
class DurationExpression(CombinedExpression):
    """
    CombinedExpression variant that CombinedExpression.as_sql delegates to
    when the backend has no native duration field and an operand is a
    DurationField.
    """
    def compile(self, side, compiler, connection):
        """
        Compile one operand. DurationField operands (other than DurationValue
        literals) are passed through format_for_duration_arithmetic().
        """
        if isinstance(side, DurationValue):
            return compiler.compile(side)
        try:
            output = side.output_field
        except FieldError:
            return compiler.compile(side)
        if output.get_internal_type() != 'DurationField':
            return compiler.compile(side)
        sql, params = compiler.compile(side)
        return connection.ops.format_for_duration_arithmetic(sql), params

    def as_sql(self, compiler, connection):
        connection.ops.check_expression_support(self)
        sql_parts = []
        sql_params = []
        for side in (self.lhs, self.rhs):
            side_sql, side_params = self.compile(side, compiler, connection)
            sql_parts.append(side_sql)
            sql_params.extend(side_params)
        combined = connection.ops.combine_duration_expression(self.connector, sql_parts)
        # Parenthesize to keep operator precedence when nested.
        return '(%s)' % combined, sql_params
-
-
class TemporalSubtraction(CombinedExpression):
    """
    Subtraction of two date/time operands of the same type, yielding a
    DurationField. CombinedExpression.as_sql delegates to this class for
    such operands.
    """
    output_field = fields.DurationField()

    def __init__(self, lhs, rhs):
        super().__init__(lhs, self.SUB, rhs)

    def as_sql(self, compiler, connection):
        connection.ops.check_expression_support(self)
        # compiler.compile() takes only the node here, as at every other call
        # site in this module. Passing `connection` as a second positional
        # argument would be taken as the compiler's select-format flag.
        lhs = compiler.compile(self.lhs)
        rhs = compiler.compile(self.rhs)
        return connection.ops.subtract_temporals(self.lhs.output_field.get_internal_type(), lhs, rhs)
-
-
@deconstructible
class F(Combinable):
    """
    A by-name reference to a field or annotation, resolved against the
    query through query.resolve_ref().
    """
    # Can the expression be used in a WHERE clause?
    filterable = True

    def __init__(self, name):
        """
        Arguments:
         * name: the name of the field this expression references
        """
        self.name = name

    def __repr__(self):
        cls_name = self.__class__.__name__
        return "{}({})".format(cls_name, self.name)

    def resolve_expression(self, query=None, allow_joins=True, reuse=None,
                           summarize=False, for_save=False, simple_col=False):
        # for_save is accepted for interface compatibility but not used.
        return query.resolve_ref(self.name, allow_joins, reuse, summarize, simple_col)

    def asc(self, **kwargs):
        return OrderBy(self, **kwargs)

    def desc(self, **kwargs):
        return OrderBy(self, descending=True, **kwargs)

    def __eq__(self, other):
        # Only references of exactly the same class can be equal.
        if self.__class__ != other.__class__:
            return False
        return self.name == other.name

    def __hash__(self):
        return hash(self.name)
-
-
class ResolvedOuterRef(F):
    """
    A reference to a field of an outer query, produced once the inner query
    is used as a subquery.

    It refuses to compile on its own, and relabeling leaves it untouched.
    """
    def as_sql(self, *args, **kwargs):
        message = (
            'This queryset contains a reference to an outer query and may '
            'only be used in a subquery.'
        )
        raise ValueError(message)

    def _prepare(self, output_field=None):
        return self

    def relabeled_clone(self, relabels):
        return self
-
-
class OuterRef(F):
    """
    A reference to a field of the outer query. Resolving it yields a
    ResolvedOuterRef, or unwraps one level when the name is itself an
    OuterRef.
    """
    def resolve_expression(self, query=None, allow_joins=True, reuse=None,
                           summarize=False, for_save=False, simple_col=False):
        inner = self.name
        return inner if isinstance(inner, self.__class__) else ResolvedOuterRef(inner)

    def _prepare(self, output_field=None):
        return self
-
-
class Func(SQLiteNumericMixin, Expression):
    """An SQL function call."""
    function = None
    template = '%(function)s(%(expressions)s)'
    arg_joiner = ', '
    arity = None  # The number of arguments the function accepts.

    def __init__(self, *expressions, output_field=None, **extra):
        """
        Build the call from positional argument expressions (strings become
        F() references, other non-expressions become Value()). Extra keyword
        arguments are kept in self.extra and fed to the template.

        Raise TypeError when `arity` is set and the argument count differs.
        """
        if self.arity is not None and len(expressions) != self.arity:
            noun = "argument" if self.arity == 1 else "arguments"
            raise TypeError(
                "'%s' takes exactly %s %s (%s given)"
                % (self.__class__.__name__, self.arity, noun, len(expressions))
            )
        super().__init__(output_field=output_field)
        self.source_expressions = self._parse_expressions(*expressions)
        self.extra = extra

    def __repr__(self):
        cls_name = self.__class__.__name__
        args = self.arg_joiner.join(str(arg) for arg in self.source_expressions)
        options = {**self.extra, **self._get_repr_options()}
        if not options:
            return "{}({})".format(cls_name, args)
        rendered = ', '.join(
            str(key) + '=' + str(val) for key, val in sorted(options.items())
        )
        return "{}({}, {})".format(cls_name, args, rendered)

    def _get_repr_options(self):
        """Return a dict of extra __init__() options to include in the repr."""
        return {}

    def get_source_expressions(self):
        return self.source_expressions

    def set_source_expressions(self, exprs):
        self.source_expressions = exprs

    def resolve_expression(self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False):
        clone = self.copy()
        clone.is_summary = summarize
        # copy() gave the clone its own list, so it is refilled in place.
        clone.source_expressions[:] = [
            arg.resolve_expression(query, allow_joins, reuse, summarize, for_save)
            for arg in clone.source_expressions
        ]
        return clone

    def as_sql(self, compiler, connection, function=None, template=None, arg_joiner=None, **extra_context):
        connection.ops.check_expression_support(self)
        compiled = [compiler.compile(arg) for arg in self.source_expressions]
        sql_parts = [arg_sql for arg_sql, _ in compiled]
        params = [param for _, arg_params in compiled for param in arg_params]
        data = {**self.extra, **extra_context}
        # Precedence for function/template/arg_joiner: the argument passed to
        # this method, then a value given to __init__() via **extra (now in
        # `data`), then the class attribute.
        if function is None:
            data.setdefault('function', self.function)
        else:
            data['function'] = function
        template = template or data.get('template', self.template)
        arg_joiner = arg_joiner or data.get('arg_joiner', self.arg_joiner)
        joined = arg_joiner.join(sql_parts)
        data['expressions'] = joined
        data['field'] = joined
        return template % data, params

    def copy(self):
        clone = super().copy()
        # Detach the mutable containers so the clone can be changed freely.
        clone.source_expressions = self.source_expressions[:]
        clone.extra = self.extra.copy()
        return clone
-
-
class Value(Expression):
    """Represent a wrapped value as a node within an expression."""
    # Fallback for instances compiled without first going through
    # resolve_expression(), which is where for_save is normally set per
    # instance. Without it, as_sql() raises AttributeError on such instances.
    for_save = False

    def __init__(self, value, output_field=None):
        """
        Arguments:
         * value: the value this expression represents. The value will be
           added into the sql parameter list and properly quoted.

         * output_field: an instance of the model field type that this
           expression will return, such as IntegerField() or CharField().
        """
        super().__init__(output_field=output_field)
        self.value = value

    def __repr__(self):
        return "{}({})".format(self.__class__.__name__, self.value)

    def as_sql(self, compiler, connection):
        """
        Return the value as a query parameter.

        When an output field is known, the value is first prepared with
        get_db_prep_save() (for saves) or get_db_prep_value(). A field
        placeholder is used when the field provides one. A None value is
        rendered as a literal NULL.
        """
        connection.ops.check_expression_support(self)
        val = self.value
        output_field = self._output_field_or_none
        if output_field is not None:
            if self.for_save:
                val = output_field.get_db_prep_save(val, connection=connection)
            else:
                val = output_field.get_db_prep_value(val, connection=connection)
            if hasattr(output_field, 'get_placeholder'):
                return output_field.get_placeholder(val, compiler, connection), [val]
        if val is None:
            # cx_Oracle does not always convert None to the appropriate
            # NULL type (like in case expressions using numbers), so we
            # use a literal SQL NULL
            return 'NULL', []
        return '%s', [val]

    def resolve_expression(self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False):
        c = super().resolve_expression(query, allow_joins, reuse, summarize, for_save)
        c.for_save = for_save
        return c

    def get_group_by_cols(self):
        # A constant contributes nothing to GROUP BY.
        return []
-
-
class DurationValue(Value):
    """
    A timedelta literal. It is rendered with date_interval_sql() on backends
    without a native duration field, and as a plain Value otherwise.
    """
    def as_sql(self, compiler, connection):
        connection.ops.check_expression_support(self)
        if not connection.features.has_native_duration_field:
            return connection.ops.date_interval_sql(self.value), []
        return super().as_sql(compiler, connection)
-
-
class RawSQL(Expression):
    """
    A raw SQL fragment plus its parameters, emitted in parentheses. The
    output field defaults to a generic fields.Field().
    """
    def __init__(self, sql, params, output_field=None):
        self.sql = sql
        self.params = params
        if output_field is None:
            output_field = fields.Field()
        super().__init__(output_field=output_field)

    def __repr__(self):
        return "{}({}, {})".format(self.__class__.__name__, self.sql, self.params)

    def as_sql(self, compiler, connection):
        wrapped = '(%s)' % self.sql
        return wrapped, self.params

    def get_group_by_cols(self):
        return [self]
-
-
class Star(Expression):
    """The SQL ``*`` wildcard, with no parameters."""
    def __repr__(self):
        return "'*'"

    def as_sql(self, compiler, connection):
        no_params = []
        return '*', no_params
-
-
class Random(Expression):
    """The backend's random-number function, typed as a FloatField."""
    output_field = fields.FloatField()

    def __repr__(self):
        return "Random()"

    def as_sql(self, compiler, connection):
        sql = connection.ops.random_function_sql()
        return sql, []
-
-
class Col(Expression):
    """
    A column reference qualified by its table alias (``alias.column``).
    The output field defaults to the target field.
    """

    contains_column_references = True

    def __init__(self, alias, target, output_field=None):
        super().__init__(output_field=target if output_field is None else output_field)
        self.alias = alias
        self.target = target

    def __repr__(self):
        return "{}({}, {})".format(self.__class__.__name__, self.alias, self.target)

    def as_sql(self, compiler, connection):
        quote = compiler.quote_name_unless_alias
        return "%s.%s" % (quote(self.alias), quote(self.target.column)), []

    def relabeled_clone(self, relabels):
        new_alias = relabels.get(self.alias, self.alias)
        return self.__class__(new_alias, self.target, self.output_field)

    def get_group_by_cols(self):
        return [self]

    def get_db_converters(self, connection):
        # When output_field differs from the column's own field, the
        # target's converters run after the output field's.
        same_field = self.target == self.output_field
        converters = self.output_field.get_db_converters(connection)
        if same_field:
            return converters
        return converters + self.target.get_db_converters(connection)
-
-
class SimpleCol(Expression):
    """
    A bare column name with no table name or alias.

    Unlike Col, it omits the table prefix, which would cause a syntax error
    inside check constraints.
    """
    contains_column_references = True

    def __init__(self, target, output_field=None):
        super().__init__(output_field=target if output_field is None else output_field)
        self.target = target

    def __repr__(self):
        return '{}({})'.format(self.__class__.__name__, self.target)

    def as_sql(self, compiler, connection):
        quote = compiler.quote_name_unless_alias
        return quote(self.target.column), []

    def get_group_by_cols(self):
        return [self]

    def get_db_converters(self, connection):
        # Same converter chaining as Col: output field first, then the target
        # when the two differ.
        same_field = self.target == self.output_field
        converters = self.output_field.get_db_converters(connection)
        if same_field:
            return converters
        return converters + self.target.get_db_converters(connection)
-
-
class Ref(Expression):
    """
    Reference to a column alias of the query, such as Ref('sum_cost') in
    qs.annotate(sum_cost=Sum('cost')). Compiles to the quoted alias name.
    """
    def __init__(self, refs, source):
        super().__init__()
        self.refs = refs
        self.source = source

    def __repr__(self):
        return "{}({}, {})".format(self.__class__.__name__, self.refs, self.source)

    def get_source_expressions(self):
        return [self.source]

    def set_source_expressions(self, exprs):
        [self.source] = exprs

    def resolve_expression(self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False):
        # `source` was already resolved; this object only names it.
        return self

    def relabeled_clone(self, relabels):
        return self

    def as_sql(self, compiler, connection):
        quoted = connection.ops.quote_name(self.refs)
        return quoted, []

    def get_group_by_cols(self):
        return [self]
-
-
class ExpressionList(Func):
    """
    Several expressions rendered joined by arg_joiner with no surrounding
    function call. It is useful as an argument to another expression, e.g.
    an ordering clause.
    """
    template = '%(expressions)s'

    def __init__(self, *expressions, **extra):
        if len(expressions) == 0:
            raise ValueError('%s requires at least one expression.' % self.__class__.__name__)
        super().__init__(*expressions, **extra)

    def __str__(self):
        rendered = [str(arg) for arg in self.source_expressions]
        return self.arg_joiner.join(rendered)
-
-
class ExpressionWrapper(Expression):
    """
    An expression that can wrap another expression so that it can provide
    extra context to the inner expression, such as the output_field.
    """

    def __init__(self, expression, output_field):
        super().__init__(output_field=output_field)
        self.expression = expression

    def set_source_expressions(self, exprs):
        self.expression = exprs[0]

    def get_source_expressions(self):
        return [self.expression]

    def as_sql(self, compiler, connection):
        # Compile through the compiler, not by calling as_sql() directly, so
        # that backend-specific as_<vendor>() overrides of the wrapped
        # expression (e.g. SQLiteNumericMixin.as_sqlite) are honoured.
        return compiler.compile(self.expression)

    def __repr__(self):
        return "{}({})".format(self.__class__.__name__, self.expression)
-
-
class When(Expression):
    """A single ``WHEN <condition> THEN <result>`` branch of a Case."""
    template = 'WHEN %(condition)s THEN %(result)s'

    def __init__(self, condition=None, then=None, **lookups):
        """
        Accept either a conditional object (such as a Q) or keyword lookups,
        which are shorthand for Q(**lookups), but not both. `then` is the
        branch result.
        """
        if lookups and condition is None:
            condition = Q(**lookups)
            lookups = None
        if condition is None or not getattr(condition, 'conditional', False) or lookups:
            raise TypeError("__init__() takes either a Q object or lookups as keyword arguments")
        if isinstance(condition, Q) and not condition:
            raise ValueError("An empty Q() can't be used as a When() condition.")
        super().__init__(output_field=None)
        self.condition = condition
        [self.result] = self._parse_expressions(then)

    def __str__(self):
        return "WHEN %r THEN %r" % (self.condition, self.result)

    def __repr__(self):
        return "<%s: %s>" % (self.__class__.__name__, self)

    def get_source_expressions(self):
        return [self.condition, self.result]

    def set_source_expressions(self, exprs):
        self.condition, self.result = exprs

    def get_source_fields(self):
        # Only the result determines the output type; the condition doesn't.
        return [self.result._output_field_or_none]

    def resolve_expression(self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False):
        clone = self.copy()
        clone.is_summary = summarize
        if hasattr(clone.condition, 'resolve_expression'):
            # The condition is always resolved with for_save=False.
            clone.condition = clone.condition.resolve_expression(query, allow_joins, reuse, summarize, False)
        clone.result = clone.result.resolve_expression(query, allow_joins, reuse, summarize, for_save)
        return clone

    def as_sql(self, compiler, connection, template=None, **extra_context):
        connection.ops.check_expression_support(self)
        condition_sql, condition_params = compiler.compile(self.condition)
        result_sql, result_params = compiler.compile(self.result)
        template_params = {**extra_context, 'condition': condition_sql, 'result': result_sql}
        chosen_template = template or self.template
        return chosen_template % template_params, [*condition_params, *result_params]

    def get_group_by_cols(self):
        # This is not a complete expression and cannot be used in GROUP BY.
        return [
            col
            for source in self.get_source_expressions()
            for col in source.get_group_by_cols()
        ]
-
-
class Case(Expression):
    """
    An SQL searched CASE expression:

        CASE
            WHEN n > 0
                THEN 'positive'
            WHEN n < 0
                THEN 'negative'
            ELSE 'zero'
        END

    Positional arguments are the When branches; ``default`` supplies the
    ELSE value. Extra keyword arguments become template context.
    """
    template = 'CASE %(cases)s ELSE %(default)s END'
    case_joiner = ' '

    def __init__(self, *cases, default=None, output_field=None, **extra):
        if any(not isinstance(case, When) for case in cases):
            raise TypeError("Positional arguments must all be When objects.")
        super().__init__(output_field)
        self.cases = list(cases)
        self.default = self._parse_expressions(default)[0]
        self.extra = extra

    def __str__(self):
        return "CASE %s, ELSE %r" % (', '.join(map(str, self.cases)), self.default)

    def __repr__(self):
        return "<%s: %s>" % (self.__class__.__name__, self)

    def get_source_expressions(self):
        return [*self.cases, self.default]

    def set_source_expressions(self, exprs):
        # Every expression but the last is a branch; the last is the default.
        *self.cases, self.default = exprs

    def resolve_expression(self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False):
        clone = self.copy()
        clone.is_summary = summarize
        clone.cases = [
            when.resolve_expression(query, allow_joins, reuse, summarize, for_save)
            for when in clone.cases
        ]
        clone.default = clone.default.resolve_expression(query, allow_joins, reuse, summarize, for_save)
        return clone

    def copy(self):
        clone = super().copy()
        # Give the copy its own list so edits to its branches stay local.
        clone.cases = clone.cases[:]
        return clone

    def as_sql(self, compiler, connection, template=None, case_joiner=None, **extra_context):
        connection.ops.check_expression_support(self)
        if not self.cases:
            # No branches at all: the default alone is the result.
            return compiler.compile(self.default)
        context = {**self.extra, **extra_context}
        compiled_cases = []
        params = []
        for when in self.cases:
            try:
                when_sql, when_params = compiler.compile(when)
            except EmptyResultSet:
                # A branch whose compilation raises EmptyResultSet is omitted.
                continue
            compiled_cases.append(when_sql)
            params.extend(when_params)
        default_sql, default_params = compiler.compile(self.default)
        if not compiled_cases:
            return default_sql, default_params
        context['cases'] = (case_joiner or self.case_joiner).join(compiled_cases)
        context['default'] = default_sql
        params.extend(default_params)
        sql = (template or context.get('template', self.template)) % context
        if self._output_field_or_none is not None:
            sql = connection.ops.unification_cast_sql(self.output_field) % sql
        return sql, params
-
-
class Subquery(Expression):
    """
    An explicit subquery. It may contain OuterRef() references to the outer
    query which will be resolved when it is applied to that query.
    """
    template = '(%(subquery)s)'
    contains_aggregate = False

    def __init__(self, queryset, output_field=None, **extra):
        self.queryset = queryset
        # Extra template context, merged into the params in as_sql().
        self.extra = extra
        super().__init__(output_field)

    def _resolve_output_field(self):
        # With exactly one selected column, that column's field is the
        # output field; otherwise fall back to the generic resolution.
        if len(self.queryset.query.select) == 1:
            return self.queryset.query.select[0].field
        return super()._resolve_output_field()

    def copy(self):
        clone = super().copy()
        # QuerySet.all() returns a copy, so the clone's query can be modified
        # (e.g. by bump_prefix() in resolve_expression()) without touching
        # this instance's query.
        clone.queryset = clone.queryset.all()
        return clone

    def resolve_expression(self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False):
        clone = self.copy()
        clone.is_summary = summarize
        clone.queryset.query.bump_prefix(query)

        def resolve(child):
            """Resolve *child* against the outer query if it is an expression."""
            if not hasattr(child, 'resolve_expression'):
                return child
            resolved = child.resolve_expression(
                query=query, allow_joins=allow_joins, reuse=reuse,
                summarize=summarize, for_save=for_save,
            )
            # Register the resolved alias as external to the subquery's
            # query (to prevent quoting it), unless it is simply the target
            # model's own table name.
            if hasattr(resolved, 'alias') and resolved.alias != resolved.target.model._meta.db_table:
                clone.queryset.query.external_aliases.add(resolved.alias)
            return resolved

        def resolve_all(node):
            """Recursively resolve the ``rhs`` of every node in a WHERE tree."""
            if hasattr(node, 'children'):
                # Plain loop: recursion is done for its side effects only.
                for child in node.children:
                    resolve_all(child)
            if hasattr(node, 'rhs'):
                node.rhs = resolve(node.rhs)

        resolve_all(clone.queryset.query.where)

        annotations = clone.queryset.query.annotations
        for key, value in annotations.items():
            if isinstance(value, Subquery):
                # Replacing the value of an existing key while iterating is
                # safe: the dict's size does not change.
                annotations[key] = resolve(value)

        return clone

    def get_source_expressions(self):
        # The truthy left-hand sides of the top-level WHERE children.
        return [
            x for x in [
                getattr(expr, 'lhs', None)
                for expr in self.queryset.query.where.children
            ] if x
        ]

    def relabeled_clone(self, change_map):
        clone = self.copy()
        clone.queryset.query = clone.queryset.query.relabeled_clone(change_map)
        # New aliases that the subquery does not itself define belong to the
        # outer query.
        clone.queryset.query.external_aliases.update(
            alias for alias in change_map.values()
            if alias not in clone.queryset.query.alias_map
        )
        return clone

    def as_sql(self, compiler, connection, template=None, **extra_context):
        connection.ops.check_expression_support(self)
        template_params = {**self.extra, **extra_context}
        template_params['subquery'], sql_params = self.queryset.query.get_compiler(connection=connection).as_sql()

        template = template or template_params.get('template', self.template)
        sql = template % template_params
        return sql, sql_params

    def _prepare(self, output_field):
        # This method will only be called if this instance is the "rhs" in an
        # expression: the wrapping () must be removed (as the expression that
        # contains this will provide them). SQLite evaluates ((subquery))
        # differently than the other databases.
        if self.template == '(%(subquery)s)':
            clone = self.copy()
            clone.template = '%(subquery)s'
            return clone
        return self
-
-
class Exists(Subquery):
    """
    An ``EXISTS(<subquery>)`` test, rendered as ``NOT EXISTS(...)`` when
    ``negated`` is true. Its output field is a BooleanField.
    """
    template = 'EXISTS(%(subquery)s)'
    output_field = fields.BooleanField()

    def __init__(self, *args, negated=False, **kwargs):
        self.negated = negated
        super().__init__(*args, **kwargs)
        # As a performance optimization, remove ordering since EXISTS doesn't
        # care about it, just whether or not a row matches. Doing it once
        # here (instead of in resolve_expression()) means resolving never
        # mutates this expression in place.
        self.queryset = self.queryset.order_by()

    def __invert__(self):
        # Copy rather than re-construct, so that every attribute (extra
        # context, an explicitly passed output_field, subclass state)
        # survives the negation.
        clone = self.copy()
        clone.negated = not self.negated
        return clone

    def as_sql(self, compiler, connection, template=None, **extra_context):
        sql, params = super().as_sql(compiler, connection, template, **extra_context)
        if self.negated:
            sql = 'NOT {}'.format(sql)
        return sql, params

    def as_oracle(self, compiler, connection, template=None, **extra_context):
        # Oracle doesn't allow EXISTS() in the SELECT list, so wrap the
        # (possibly negated) EXISTS test in a CASE WHEN expression that
        # yields 1 or 0.
        sql, params = self.as_sql(compiler, connection, template, **extra_context)
        sql = 'CASE WHEN {} THEN 1 ELSE 0 END'.format(sql)
        return sql, params
-
-
class OrderBy(BaseExpression):
    """
    An ORDER BY term: an expression plus an ASC/DESC direction and an
    optional NULLS FIRST / NULLS LAST placement (mutually exclusive).
    """
    template = '%(expression)s %(ordering)s'

    def __init__(self, expression, descending=False, nulls_first=False, nulls_last=False):
        if nulls_first and nulls_last:
            raise ValueError('nulls_first and nulls_last are mutually exclusive')
        if not hasattr(expression, 'resolve_expression'):
            raise ValueError('expression must be an expression type')
        self.descending = descending
        self.nulls_first = nulls_first
        self.nulls_last = nulls_last
        self.expression = expression

    def __repr__(self):
        name = self.__class__.__name__
        return "{}({}, descending={})".format(name, self.expression, self.descending)

    def set_source_expressions(self, exprs):
        self.expression = exprs[0]

    def get_source_expressions(self):
        return [self.expression]

    def as_sql(self, compiler, connection, template=None, **extra_context):
        if not template:
            # Generic SQL spelling of the NULLS placement, when requested.
            if self.nulls_last:
                template = self.template + ' NULLS LAST'
            elif self.nulls_first:
                template = self.template + ' NULLS FIRST'
        connection.ops.check_expression_support(self)
        expression_sql, params = compiler.compile(self.expression)
        direction = 'DESC' if self.descending else 'ASC'
        placeholders = {
            'expression': expression_sql,
            'ordering': direction,
            **extra_context,
        }
        template = template or self.template
        # A template may reference the expression several times; repeat the
        # expression's params once per occurrence.
        params *= template.count('%(expression)s')
        return (template % placeholders).rstrip(), params

    def as_sqlite(self, compiler, connection):
        # Emulate NULLS LAST / NULLS FIRST by first sorting on a NULL test.
        if self.nulls_last:
            template = '%(expression)s IS NULL, %(expression)s %(ordering)s'
        elif self.nulls_first:
            template = '%(expression)s IS NOT NULL, %(expression)s %(ordering)s'
        else:
            template = None
        return self.as_sql(compiler, connection, template=template)

    def as_mysql(self, compiler, connection):
        # Emulate NULLS LAST / NULLS FIRST with an IF(ISNULL(...)) sort key.
        if self.nulls_last:
            template = 'IF(ISNULL(%(expression)s),1,0), %(expression)s %(ordering)s '
        elif self.nulls_first:
            template = 'IF(ISNULL(%(expression)s),0,1), %(expression)s %(ordering)s '
        else:
            template = None
        return self.as_sql(compiler, connection, template=template)

    def get_group_by_cols(self):
        return [
            col
            for source in self.get_source_expressions()
            for col in source.get_group_by_cols()
        ]

    def reverse_ordering(self):
        """Flip the direction (and NULLS placement, if any) in place; return self."""
        self.descending = not self.descending
        if self.nulls_first or self.nulls_last:
            self.nulls_first, self.nulls_last = not self.nulls_first, not self.nulls_last
        return self

    def asc(self):
        self.descending = False

    def desc(self):
        self.descending = True
-
-
class Window(Expression):
    """
    A window function call rendered as ``<expression> OVER (<window>)``.

    The window is assembled from an optional PARTITION BY list, an optional
    ORDER BY clause and an optional frame clause. ``expression`` must be
    flagged ``window_compatible``.
    """
    template = '%(expression)s OVER (%(window)s)'
    # Although the main expression may either be an aggregate or an
    # expression with an aggregate function, the GROUP BY that will
    # be introduced in the query as a result is not desired.
    contains_aggregate = False
    contains_over_clause = True
    filterable = False

    def __init__(self, expression, partition_by=None, order_by=None, frame=None, output_field=None):
        self.partition_by = partition_by
        self.order_by = order_by
        self.frame = frame

        if not getattr(expression, 'window_compatible', False):
            raise ValueError(
                "Expression '%s' isn't compatible with OVER clauses." %
                expression.__class__.__name__
            )

        # A single partition expression is accepted and normalized to a list.
        if self.partition_by is not None:
            if not isinstance(self.partition_by, (tuple, list)):
                self.partition_by = (self.partition_by,)
            self.partition_by = ExpressionList(*self.partition_by)

        if self.order_by is not None:
            if isinstance(self.order_by, (list, tuple)):
                self.order_by = ExpressionList(*self.order_by)
            elif not isinstance(self.order_by, BaseExpression):
                raise ValueError(
                    'order_by must be either an Expression or a sequence of '
                    'expressions.'
                )
        super().__init__(output_field=output_field)
        self.source_expression = self._parse_expressions(expression)[0]

    def _resolve_output_field(self):
        return self.source_expression.output_field

    def get_source_expressions(self):
        return [self.source_expression, self.partition_by, self.order_by, self.frame]

    def set_source_expressions(self, exprs):
        self.source_expression, self.partition_by, self.order_by, self.frame = exprs

    def as_sql(self, compiler, connection, template=None):
        connection.ops.check_expression_support(self)
        expr_sql, params = compiler.compile(self.source_expression)
        # Each present clause is appended whole (not char-by-char via
        # list.extend(str)) and the clauses are joined with single spaces.
        window_sql, window_params = [], []

        if self.partition_by is not None:
            sql_expr, sql_params = self.partition_by.as_sql(
                compiler=compiler, connection=connection,
                template='PARTITION BY %(expressions)s',
            )
            window_sql.append(sql_expr)
            window_params.extend(sql_params)

        if self.order_by is not None:
            order_sql, order_params = compiler.compile(self.order_by)
            window_sql.append('ORDER BY ' + order_sql)
            window_params.extend(order_params)

        if self.frame:
            frame_sql, frame_params = compiler.compile(self.frame)
            window_sql.append(frame_sql)
            window_params.extend(frame_params)

        template = template or self.template

        # Build a fresh list: the params returned by compile() may be a
        # tuple, which has no extend().
        return template % {
            'expression': expr_sql,
            'window': ' '.join(window_sql).strip(),
        }, [*params, *window_params]

    def __str__(self):
        # Only non-empty clauses are joined, with single spaces, so that a
        # partition, an ordering and a frame don't run together
        # (previously "PARTITION BY xORDER BY y").
        clauses = [
            'PARTITION BY ' + str(self.partition_by) if self.partition_by else '',
            'ORDER BY ' + str(self.order_by) if self.order_by else '',
            str(self.frame or ''),
        ]
        return '{} OVER ({})'.format(
            str(self.source_expression),
            ' '.join(clause for clause in clauses if clause),
        )

    def __repr__(self):
        return '<%s: %s>' % (self.__class__.__name__, self)

    def get_group_by_cols(self):
        return []
-
-
class WindowFrame(Expression):
    """
    Model the frame clause of a window expression.

    Concrete frame types (RowRange for ROWS, ValueRange for RANGE) are
    subclasses that only supply ``frame_type`` and the backend call that
    renders the bounds; everything else (by no means complete validation
    included) lives here. Both bounds are optional: a missing start renders
    as UNBOUNDED PRECEDING and a missing end as UNBOUNDED FOLLOWING, i.e.
    the last row of the partition.
    """
    template = '%(frame_type)s BETWEEN %(start)s AND %(end)s'

    def __init__(self, start=None, end=None):
        self.start = Value(start)
        self.end = Value(end)

    def set_source_expressions(self, exprs):
        self.start, self.end = exprs

    def get_source_expressions(self):
        return [self.start, self.end]

    def as_sql(self, compiler, connection):
        connection.ops.check_expression_support(self)
        start, end = self.window_frame_start_end(connection, self.start.value, self.end.value)
        sql = self.template % {
            'frame_type': self.frame_type,
            'start': start,
            'end': end,
        }
        return sql, []

    def __repr__(self):
        return '<%s: %s>' % (self.__class__.__name__, self)

    def get_group_by_cols(self):
        return []

    def __str__(self):
        # Rendered with the default connection's operation constants.
        ops = connection.ops
        lower, upper = self.start.value, self.end.value

        if lower is None:
            start = ops.UNBOUNDED_PRECEDING
        elif lower < 0:
            start = '%d %s' % (abs(lower), ops.PRECEDING)
        elif lower == 0:
            start = ops.CURRENT_ROW
        else:
            start = ops.UNBOUNDED_PRECEDING

        if upper is None:
            end = ops.UNBOUNDED_FOLLOWING
        elif upper > 0:
            end = '%d %s' % (upper, ops.FOLLOWING)
        elif upper == 0:
            end = ops.CURRENT_ROW
        else:
            end = ops.UNBOUNDED_FOLLOWING

        return self.template % {
            'frame_type': self.frame_type,
            'start': start,
            'end': end,
        }

    def window_frame_start_end(self, connection, start, end):
        raise NotImplementedError('Subclasses must implement window_frame_start_end().')
-
-
class RowRange(WindowFrame):
    """A ROWS frame; its bounds are rendered by the backend's ops."""
    frame_type = 'ROWS'

    def window_frame_start_end(self, connection, start, end):
        ops = connection.ops
        return ops.window_frame_rows_start_end(start, end)
-
-
class ValueRange(WindowFrame):
    """A RANGE frame; its bounds are rendered by the backend's ops."""
    frame_type = 'RANGE'

    def window_frame_start_end(self, connection, start, end):
        ops = connection.ops
        return ops.window_frame_range_start_end(start, end)
|