123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556 |
- import itertools
- import math
- from copy import copy
-
- from django.core.exceptions import EmptyResultSet
- from django.db.models.expressions import Func, Value
- from django.db.models.fields import DateTimeField, Field, IntegerField
- from django.db.models.query_utils import RegisterLookupMixin
- from django.utils.datastructures import OrderedSet
- from django.utils.functional import cached_property
-
-
class Lookup:
    """
    Base class for a query lookup that compares ``lhs`` against ``rhs``.

    Subclasses set ``lookup_name`` and implement ``as_sql()``. The right-hand
    side is prepared once, at construction time, by ``get_prep_lookup()``.
    """
    # Name of the lookup in query syntax; set by subclasses.
    lookup_name = None
    # When True, a direct rhs value is passed through
    # lhs.output_field.get_prep_value() in get_prep_lookup().
    prepare_rhs = True
    # Not read in this class; presumably consulted by the query builder.
    can_use_none_as_rhs = False

    def __init__(self, lhs, rhs):
        self.lhs, self.rhs = lhs, rhs
        self.rhs = self.get_prep_lookup()
        transforms = (
            self.lhs.get_bilateral_transforms()
            if hasattr(self.lhs, 'get_bilateral_transforms')
            else []
        )
        if transforms:
            # Fail as early as possible: a bilateral transformation cannot be
            # applied to a nested QuerySet on the right-hand side.
            from django.db.models.sql.query import Query  # avoid circular import
            if isinstance(rhs, Query):
                raise NotImplementedError("Bilateral transformations on nested querysets are not implemented.")
        self.bilateral_transforms = transforms

    def apply_bilateral_transforms(self, value):
        """Wrap ``value`` in every bilateral transform class, in list order."""
        result = value
        for transform_class in self.bilateral_transforms:
            result = transform_class(result)
        return result

    def batch_process_rhs(self, compiler, connection, rhs=None):
        """
        Compile an iterable rhs (``self.rhs`` by default) into a list of SQL
        fragments and a flat list of their parameters.
        """
        values = self.rhs if rhs is None else rhs
        if not self.bilateral_transforms:
            _, params = self.get_db_prep_lookup(values, connection)
            return ['%s'] * len(params), params
        fragments, fragment_params = [], []
        for item in values:
            # Each value must go through the same transforms as the lhs, so it
            # is wrapped in a Value expression and compiled individually.
            expression = self.apply_bilateral_transforms(
                Value(item, output_field=self.lhs.output_field)
            ).resolve_expression(compiler.query)
            item_sql, item_params = compiler.compile(expression)
            fragments.append(item_sql)
            fragment_params.extend(item_params)
        return fragments, fragment_params

    def get_source_expressions(self):
        """Return ``[lhs]``, plus ``rhs`` when it is an SQL expression."""
        sources = [self.lhs]
        if not self.rhs_is_direct_value():
            sources.append(self.rhs)
        return sources

    def set_source_expressions(self, new_exprs):
        """Inverse of get_source_expressions(): one item sets only ``lhs``."""
        if len(new_exprs) != 1:
            self.lhs, self.rhs = new_exprs
        else:
            self.lhs = new_exprs[0]

    def get_prep_lookup(self):
        """Return the rhs prepared for use against ``lhs.output_field``."""
        if hasattr(self.rhs, '_prepare'):
            return self.rhs._prepare(self.lhs.output_field)
        # Short-circuit before touching output_field when preparation is off.
        if not self.prepare_rhs:
            return self.rhs
        field = self.lhs.output_field
        if hasattr(field, 'get_prep_value'):
            return field.get_prep_value(self.rhs)
        return self.rhs

    def get_db_prep_lookup(self, value, connection):
        """Return a single ``%s`` placeholder with ``value`` as its param."""
        return '%s', [value]

    def process_lhs(self, compiler, connection, lhs=None):
        """Resolve (if possible) and compile ``lhs`` or ``self.lhs``."""
        target = lhs or self.lhs
        if hasattr(target, 'resolve_expression'):
            target = target.resolve_expression(compiler.query)
        return compiler.compile(target)

    def process_rhs(self, compiler, connection):
        """
        Compile the rhs. SQL expressions come back parenthesized; plain values
        go through get_db_prep_lookup().
        """
        value = self.rhs
        if self.bilateral_transforms:
            if self.rhs_is_direct_value():
                # get_db_prep_lookup() is skipped on purpose: the value is
                # transformed in SQL before being used for the lookup.
                value = Value(value, output_field=self.lhs.output_field)
            value = self.apply_bilateral_transforms(value)
            value = value.resolve_expression(compiler.query)
        if not hasattr(value, 'as_sql'):
            return self.get_db_prep_lookup(value, connection)
        sql, params = compiler.compile(value)
        return '(' + sql + ')', params

    def rhs_is_direct_value(self):
        """True when the rhs is a plain value rather than an SQL expression."""
        return not hasattr(self.rhs, 'as_sql')

    def relabeled_clone(self, relabels):
        """Return a shallow copy with lhs (and rhs, if supported) relabeled."""
        clone = copy(self)
        clone.lhs = clone.lhs.relabeled_clone(relabels)
        if hasattr(clone.rhs, 'relabeled_clone'):
            clone.rhs = clone.rhs.relabeled_clone(relabels)
        return clone

    def get_group_by_cols(self):
        """Return lhs group-by columns extended with the rhs ones, if any."""
        group_cols = self.lhs.get_group_by_cols()
        if hasattr(self.rhs, 'get_group_by_cols'):
            group_cols.extend(self.rhs.get_group_by_cols())
        return group_cols

    def as_sql(self, compiler, connection):
        """Render the lookup; must be implemented by subclasses."""
        raise NotImplementedError

    @cached_property
    def contains_aggregate(self):
        return self.lhs.contains_aggregate or getattr(self.rhs, 'contains_aggregate', False)

    @cached_property
    def contains_over_clause(self):
        return self.lhs.contains_over_clause or getattr(self.rhs, 'contains_over_clause', False)

    @property
    def is_summary(self):
        return self.lhs.is_summary or getattr(self.rhs, 'is_summary', False)
-
-
class Transform(RegisterLookupMixin, Func):
    """
    A single-argument Func that can appear in a lookup chain.

    RegisterLookupMixin() is first so that get_lookup() and get_transform()
    first examine self and then check output_field.
    """
    # When True, the transform is also applied to the rhs of the lookup.
    bilateral = False
    arity = 1

    @property
    def lhs(self):
        """The (single) source expression this transform wraps."""
        return self.get_source_expressions()[0]

    def get_bilateral_transforms(self):
        """
        Collect bilateral transform classes from the wrapped expression
        outward, appending this class last when ``bilateral`` is set.
        """
        inner = self.lhs
        collected = (
            inner.get_bilateral_transforms()
            if hasattr(inner, 'get_bilateral_transforms')
            else []
        )
        if self.bilateral:
            collected.append(self.__class__)
        return collected
-
-
class BuiltinLookup(Lookup):
    """
    Lookup rendered as ``<lhs sql> <operator applied to rhs sql>`` using the
    backend's ``connection.operators`` table and lhs casts.
    """

    def process_lhs(self, compiler, connection, lhs=None):
        sql, params = super().process_lhs(compiler, connection, lhs)
        output_field = self.lhs.output_field
        internal_type = output_field.get_internal_type()
        db_type = output_field.db_type(connection=connection)
        # The field cast is applied first; the lookup cast wraps it.
        sql = connection.ops.field_cast_sql(db_type, internal_type) % sql
        sql = connection.ops.lookup_cast(self.lookup_name, internal_type) % sql
        return sql, list(params)

    def as_sql(self, compiler, connection):
        lhs_sql, params = self.process_lhs(compiler, connection)
        rhs_sql, rhs_params = self.process_rhs(compiler, connection)
        params.extend(rhs_params)
        operator_sql = self.get_rhs_op(connection, rhs_sql)
        return '%s %s' % (lhs_sql, operator_sql), params

    def get_rhs_op(self, connection, rhs):
        """Format ``rhs`` into the backend operator for this lookup name."""
        template = connection.operators[self.lookup_name]
        return template % rhs
-
-
class FieldGetDbPrepValueMixin:
    """
    Some lookups require Field.get_db_prep_value() to be called on their
    inputs.
    """
    # When True, ``value`` is an iterable and each item is converted.
    get_db_prep_lookup_value_is_iterable = False

    def get_db_prep_lookup(self, value, connection):
        """
        Return ``('%s', params)``, where ``params`` holds ``value`` (or each of
        its items) passed through get_db_prep_value(..., prepared=True).
        """
        output_field = self.lhs.output_field
        # For relational fields, use the output_field of the 'field' attribute.
        related = getattr(output_field, 'field', None)
        convert = getattr(related, 'get_db_prep_value', None) or output_field.get_db_prep_value
        items = value if self.get_db_prep_lookup_value_is_iterable else [value]
        return ('%s', [convert(item, connection, prepared=True) for item in items])
-
-
class FieldGetDbPrepValueIterableMixin(FieldGetDbPrepValueMixin):
    """
    Some lookups require Field.get_db_prep_value() to be called on each value
    in an iterable.
    """
    get_db_prep_lookup_value_is_iterable = True

    def get_prep_lookup(self):
        """
        Prepare each item of the iterable rhs individually.

        An rhs with ``_prepare`` (a subquery) is prepared as a whole, and items
        that are expressions are passed through untouched.
        """
        if hasattr(self.rhs, '_prepare'):
            # A subquery is like an iterable but its items shouldn't be
            # prepared independently.
            return self.rhs._prepare(self.lhs.output_field)
        prepared_values = []
        for rhs_value in self.rhs:
            if hasattr(rhs_value, 'resolve_expression'):
                # An expression will be handled by the database but can coexist
                # alongside real values.
                pass
            elif self.prepare_rhs and hasattr(self.lhs.output_field, 'get_prep_value'):
                rhs_value = self.lhs.output_field.get_prep_value(rhs_value)
            prepared_values.append(rhs_value)
        return prepared_values

    def process_rhs(self, compiler, connection):
        if self.rhs_is_direct_value():
            # rhs should be an iterable of values. Use batch_process_rhs()
            # to prepare/transform those values.
            return self.batch_process_rhs(compiler, connection)
        else:
            return super().process_rhs(compiler, connection)

    def resolve_expression_parameter(self, compiler, connection, sql, param):
        """
        Return ``(sql, [param])`` for a plain value. If ``param`` is an
        expression, return its compiled SQL and params instead.
        """
        params = [param]
        if hasattr(param, 'resolve_expression'):
            param = param.resolve_expression(compiler.query)
        if hasattr(param, 'as_sql'):
            # Compile through the compiler (as the rest of the lookup code
            # does) instead of calling param.as_sql() directly, so that
            # backend-specific as_<vendor>() implementations are honored.
            sql, params = compiler.compile(param)
        return sql, params

    def batch_process_rhs(self, compiler, connection, rhs=None):
        """
        Compile an iterable rhs into a tuple of SQL fragments and a flat tuple
        of params, compiling any expression items in place.
        """
        pre_processed = super().batch_process_rhs(compiler, connection, rhs)
        # The params list may contain expressions which compile to a
        # sql/param pair. Zip them to get sql and param pairs that refer to the
        # same argument and attempt to replace them with the result of
        # compiling the param step.
        sql, params = zip(*(
            self.resolve_expression_parameter(compiler, connection, sql, param)
            for sql, param in zip(*pre_processed)
        ))
        params = itertools.chain.from_iterable(params)
        return sql, tuple(params)
-
-
@Field.register_lookup
class Exact(FieldGetDbPrepValueMixin, BuiltinLookup):
    """
    The ``exact`` lookup. A QuerySet rhs must be limited to one row; its
    select clause is then narrowed to the primary key.
    """
    lookup_name = 'exact'

    def process_rhs(self, compiler, connection):
        from django.db.models.sql.query import Query
        if isinstance(self.rhs, Query):
            if not self.rhs.has_limit_one():
                raise ValueError(
                    'The QuerySet value for an exact lookup must be limited to '
                    'one result using slicing.'
                )
            # The subquery must select only the pk.
            self.rhs.clear_select_clause()
            self.rhs.add_fields(['pk'])
        return super().process_rhs(compiler, connection)
-
-
@Field.register_lookup
class IExact(BuiltinLookup):
    """The ``iexact`` lookup; the rhs is not run through get_prep_value()."""
    lookup_name = 'iexact'
    prepare_rhs = False

    def process_rhs(self, qn, connection):
        sql, params = super().process_rhs(qn, connection)
        if not params:
            return sql, params
        # Let the backend adapt the first parameter for an iexact comparison.
        params[0] = connection.ops.prep_for_iexact_query(params[0])
        return sql, params
-
-
@Field.register_lookup
class GreaterThan(FieldGetDbPrepValueMixin, BuiltinLookup):
    """``gt``; the SQL operator comes from ``connection.operators['gt']``."""
    lookup_name = 'gt'


@Field.register_lookup
class GreaterThanOrEqual(FieldGetDbPrepValueMixin, BuiltinLookup):
    """``gte``; the SQL operator comes from ``connection.operators['gte']``."""
    lookup_name = 'gte'


@Field.register_lookup
class LessThan(FieldGetDbPrepValueMixin, BuiltinLookup):
    """``lt``; the SQL operator comes from ``connection.operators['lt']``."""
    lookup_name = 'lt'


@Field.register_lookup
class LessThanOrEqual(FieldGetDbPrepValueMixin, BuiltinLookup):
    """``lte``; the SQL operator comes from ``connection.operators['lte']``."""
    lookup_name = 'lte'
-
-
class IntegerFieldFloatRounding:
    """
    Allow floats to work as query values for IntegerField. Without this, the
    decimal portion of the float would always be discarded.

    A float rhs is rounded up (math.ceil) before normal preparation. For
    integers this keeps ``gte`` (x >= 1.5 <=> x >= 2) and ``lt``
    (x < 1.5 <=> x < 2) comparisons exact.
    """

    def get_prep_lookup(self):
        current = self.rhs
        if isinstance(current, float):
            # Stored back on self so the parent preparation sees the int.
            self.rhs = math.ceil(current)
        return super().get_prep_lookup()
-
-
@IntegerField.register_lookup
class IntegerGreaterThanOrEqual(IntegerFieldFloatRounding, GreaterThanOrEqual):
    """``gte`` on IntegerField; a float rhs is rounded up before preparation."""


@IntegerField.register_lookup
class IntegerLessThan(IntegerFieldFloatRounding, LessThan):
    """``lt`` on IntegerField; a float rhs is rounded up before preparation."""
-
-
@Field.register_lookup
class In(FieldGetDbPrepValueIterableMixin, BuiltinLookup):
    """
    The ``in`` lookup: ``lhs IN (...)`` for a list of values, or
    ``lhs IN (subquery)`` for a query rhs.
    """
    lookup_name = 'in'

    def process_rhs(self, compiler, connection):
        db_rhs = getattr(self.rhs, '_db', None)
        if db_rhs is not None and db_rhs != connection.alias:
            raise ValueError(
                "Subqueries aren't allowed across different databases. Force "
                "the inner query to be evaluated using `list(inner_query)`."
            )

        if self.rhs_is_direct_value():
            # Drop duplicate values while keeping their order.
            try:
                rhs = OrderedSet(self.rhs)
            except TypeError:  # Unhashable items in self.rhs
                rhs = self.rhs

            if not rhs:
                raise EmptyResultSet

            # rhs should be an iterable; use batch_process_rhs() to
            # prepare/transform those values.
            sqls, sqls_params = self.batch_process_rhs(compiler, connection, rhs)
            placeholder = '(' + ', '.join(sqls) + ')'
            return (placeholder, sqls_params)
        else:
            # A subquery without explicit select fields selects only the pk.
            if not getattr(self.rhs, 'has_select_fields', True):
                self.rhs.clear_select_clause()
                self.rhs.add_fields(['pk'])
            return super().process_rhs(compiler, connection)

    def get_rhs_op(self, connection, rhs):
        return 'IN %s' % rhs

    def as_sql(self, compiler, connection):
        max_in_list_size = connection.ops.max_in_list_size()
        if self.rhs_is_direct_value() and max_in_list_size and len(self.rhs) > max_in_list_size:
            return self.split_parameter_list_as_sql(compiler, connection)
        return super().as_sql(compiler, connection)

    def split_parameter_list_as_sql(self, compiler, connection):
        """
        Render ``(lhs IN (...) OR lhs IN (...) ...)`` with at most
        ``max_in_list_size`` values per IN clause.

        This is a special case for databases which limit the number of
        elements which can appear in an 'IN' clause.
        """
        max_in_list_size = connection.ops.max_in_list_size()
        lhs, lhs_params = self.process_lhs(compiler, connection)
        values = list(self.rhs)
        clauses = []
        params = []
        for offset in range(0, len(values), max_in_list_size):
            # Compile each chunk of values on its own. An expression in the
            # list may compile to zero or several params. Slicing one flat
            # placeholder list and one flat param list by the same offsets
            # would then pair placeholders with the wrong params.
            sqls, sqls_params = self.batch_process_rhs(
                compiler, connection, values[offset:offset + max_in_list_size],
            )
            clauses.append('%s IN (%s)' % (lhs, ', '.join(sqls)))
            params.extend(lhs_params)
            params.extend(sqls_params)
        return '(' + ' OR '.join(clauses) + ')', params
-
-
class PatternLookup(BuiltinLookup):
    """
    Base for LIKE-style lookups.

    A plain Python rhs is wrapped with ``param_pattern`` in Python. An rhs
    that is SQL (an expression, or one rewritten by bilateral transforms)
    gets the pattern from ``connection.pattern_ops`` instead.
    """
    param_pattern = '%%%s%%'
    prepare_rhs = False

    def get_rhs_op(self, connection, rhs):
        # Assume we are in startswith. We need to produce SQL like:
        #     col LIKE %s, ['thevalue%']
        # For python values we can (and should) do that directly in Python,
        # but if the value is for example reference to other column, then
        # we need to add the % pattern match to the lookup by something like
        #     col LIKE othercol || '%%'
        # So, for Python values we don't need any special pattern, but for
        # SQL reference values or SQL transformations we need the correct
        # pattern added.
        rhs_is_sql = hasattr(self.rhs, 'as_sql') or self.bilateral_transforms
        if not rhs_is_sql:
            return super().get_rhs_op(connection, rhs)
        template = connection.pattern_ops[self.lookup_name].format(connection.pattern_esc)
        return template.format(rhs)

    def process_rhs(self, qn, connection):
        sql, params = super().process_rhs(qn, connection)
        wrap_in_python = self.rhs_is_direct_value() and params and not self.bilateral_transforms
        if wrap_in_python:
            escaped = connection.ops.prep_for_like_query(params[0])
            params[0] = self.param_pattern % escaped
        return sql, params
-
-
@Field.register_lookup
class Contains(PatternLookup):
    """``contains``: a Python rhs becomes ``%value%``."""
    lookup_name = 'contains'


@Field.register_lookup
class IContains(Contains):
    """``icontains``: same pattern as ``contains``."""
    lookup_name = 'icontains'


@Field.register_lookup
class StartsWith(PatternLookup):
    """``startswith``: a Python rhs becomes ``value%``."""
    lookup_name = 'startswith'
    param_pattern = '%s%%'


@Field.register_lookup
class IStartsWith(StartsWith):
    """``istartswith``: same pattern as ``startswith``."""
    lookup_name = 'istartswith'


@Field.register_lookup
class EndsWith(PatternLookup):
    """``endswith``: a Python rhs becomes ``%value``."""
    lookup_name = 'endswith'
    param_pattern = '%%%s'


@Field.register_lookup
class IEndsWith(EndsWith):
    """``iendswith``: same pattern as ``endswith``."""
    lookup_name = 'iendswith'
-
-
@Field.register_lookup
class Range(FieldGetDbPrepValueIterableMixin, BuiltinLookup):
    """``range``: ``lhs BETWEEN <first> AND <second>`` over a 2-item rhs."""
    lookup_name = 'range'

    def get_rhs_op(self, connection, rhs):
        low, high = rhs[0], rhs[1]
        return "BETWEEN %s AND %s" % (low, high)
-
-
@Field.register_lookup
class IsNull(BuiltinLookup):
    """``isnull``: a truthy rhs renders IS NULL, a falsy one IS NOT NULL."""
    lookup_name = 'isnull'
    prepare_rhs = False

    def as_sql(self, compiler, connection):
        sql, params = compiler.compile(self.lhs)
        template = "%s IS NULL" if self.rhs else "%s IS NOT NULL"
        return template % sql, params
-
-
@Field.register_lookup
class Regex(BuiltinLookup):
    """
    Regular-expression lookup.

    Uses ``connection.operators[lookup_name]`` when the backend defines it.
    Otherwise it formats the template from ``connection.ops.regex_lookup()``
    with the lhs and rhs SQL.
    """
    lookup_name = 'regex'
    prepare_rhs = False

    def as_sql(self, compiler, connection):
        if self.lookup_name in connection.operators:
            return super().as_sql(compiler, connection)
        lhs, lhs_params = self.process_lhs(compiler, connection)
        rhs, rhs_params = self.process_rhs(compiler, connection)
        sql_template = connection.ops.regex_lookup(self.lookup_name)
        # process_rhs() may return params as a tuple (e.g. from a compiled
        # expression or subquery), so build a list explicitly rather than
        # using list + tuple, which raises TypeError.
        return sql_template % (lhs, rhs), list(lhs_params) + list(rhs_params)
-
-
@Field.register_lookup
class IRegex(Regex):
    """``iregex``: same rendering as Regex, keyed on ``'iregex'``."""
    lookup_name = 'iregex'
-
-
class YearLookup(Lookup):
    """
    Base for lookups that compare a year against backend-computed bounds.

    Assumes ``self.lhs`` is a transform whose own ``lhs`` is the original
    date/datetime expression.
    """

    def year_lookup_bounds(self, connection, year):
        """Return the backend bounds for ``year`` (datetime or date variant)."""
        source_field = self.lhs.lhs.output_field
        ops = connection.ops
        if isinstance(source_field, DateTimeField):
            return ops.year_lookup_bounds_for_datetime_field(year)
        return ops.year_lookup_bounds_for_date_field(year)
-
-
class YearComparisonLookup(YearLookup):
    """
    Base for year comparisons (gt/gte/lt/lte).

    When the rhs is a direct integer year, the underlying column is compared
    against one of the year's bounds (chosen by ``get_bound()``). Otherwise
    the transformed lhs is compared against the rhs as-is.
    """

    def as_sql(self, compiler, connection):
        rhs_sql, rhs_params = self.process_rhs(compiler, connection)
        year = None
        if self.rhs_is_direct_value():
            try:
                # Check that rhs_params[0] exists (IndexError),
                # it isn't None (TypeError), and is a number (ValueError)
                year = int(rhs_params[0])
            except (IndexError, TypeError, ValueError):
                year = None
        if year is None:
            # The bounds can't be determined before executing the query (e.g.
            # the rhs is an expression such as F()). Compare the transformed
            # lhs against the rhs instead of indexing a missing param.
            lhs_sql, params = self.process_lhs(compiler, connection)
            params = list(params)
            params.extend(rhs_params)
            return '%s %s' % (lhs_sql, self.get_rhs_op(connection, rhs_sql)), params
        # We will need to skip the extract part and instead go
        # directly with the originating field, that is self.lhs.lhs.
        lhs_sql, params = self.process_lhs(compiler, connection, self.lhs.lhs)
        params = list(params)
        start, finish = self.year_lookup_bounds(connection, year)
        # The single rhs placeholder is filled with the chosen bound.
        params.append(self.get_bound(start, finish))
        return '%s %s' % (lhs_sql, self.get_rhs_op(connection, rhs_sql)), params

    def get_rhs_op(self, connection, rhs):
        return connection.operators[self.lookup_name] % rhs

    def get_bound(self, start, finish):
        """Return which of the year's bounds to compare against."""
        raise NotImplementedError(
            'subclasses of YearComparisonLookup must provide a get_bound() method'
        )
-
-
class YearExact(YearLookup, Exact):
    """
    Year equality as ``<column> BETWEEN start AND finish`` when the rhs year
    is known. Otherwise it falls back to the standard exact comparison.
    """
    lookup_name = 'exact'

    def as_sql(self, compiler, connection):
        # Skip the extract part and use the originating field, self.lhs.lhs.
        column_sql, params = self.process_lhs(compiler, connection, self.lhs.lhs)
        _, year_params = self.process_rhs(compiler, connection)
        try:
            # The first param must exist (IndexError), not be None
            # (TypeError) and be numeric (ValueError).
            int(year_params[0])
        except (IndexError, TypeError, ValueError):
            # Can't determine the bounds before executing the query, so skip
            # optimizations by falling back to a standard exact comparison.
            return super().as_sql(compiler, connection)
        params.extend(self.year_lookup_bounds(connection, year_params[0]))
        return '%s BETWEEN %%s AND %%s' % column_sql, params
-
-
class YearGt(YearComparisonLookup):
    """``gt`` on a year: compared against the year's ``finish`` bound."""
    lookup_name = 'gt'

    def get_bound(self, start, finish):
        bound = finish
        return bound


class YearGte(YearComparisonLookup):
    """``gte`` on a year: compared against the year's ``start`` bound."""
    lookup_name = 'gte'

    def get_bound(self, start, finish):
        bound = start
        return bound


class YearLt(YearComparisonLookup):
    """``lt`` on a year: compared against the year's ``start`` bound."""
    lookup_name = 'lt'

    def get_bound(self, start, finish):
        bound = start
        return bound


class YearLte(YearComparisonLookup):
    """``lte`` on a year: compared against the year's ``finish`` bound."""
    lookup_name = 'lte'

    def get_bound(self, start, finish):
        bound = finish
        return bound
|