|
|
- import itertools
- import math
-
- from django.core.exceptions import EmptyResultSet
- from django.db.models.expressions import Case, Expression, Func, Value, When
- from django.db.models.fields import (
- BooleanField,
- CharField,
- DateTimeField,
- Field,
- IntegerField,
- UUIDField,
- )
- from django.db.models.query_utils import RegisterLookupMixin
- from django.utils.datastructures import OrderedSet
- from django.utils.functional import cached_property
- from django.utils.hashable import make_hashable
-
-
class Lookup(Expression):
    """
    Boolean expression that relates a left-hand side (``lhs``) to a
    right-hand side (``rhs``) value or expression.

    This base class prepares both sides, tracks bilateral transforms found
    on the lhs, and offers helpers for compiling each side to SQL.
    """

    lookup_name = None
    # When true, a plain rhs is passed through the lhs output field's
    # get_prep_value() (if it has one) at construction time.
    prepare_rhs = True
    can_use_none_as_rhs = False

    def __init__(self, lhs, rhs):
        """Store and prepare both sides and collect bilateral transforms."""
        self.lhs = lhs
        self.rhs = rhs
        # The rhs is prepared first, while self.lhs still holds the raw lhs.
        self.rhs = self.get_prep_lookup()
        self.lhs = self.get_prep_lhs()
        transforms = (
            self.lhs.get_bilateral_transforms()
            if hasattr(self.lhs, "get_bilateral_transforms")
            else []
        )
        if transforms:
            # Fail early: a bilateral transform cannot be applied to a
            # nested query, so reject that combination at construction.
            from django.db.models.sql.query import Query  # avoid circular import

            if isinstance(rhs, Query):
                raise NotImplementedError(
                    "Bilateral transformations on nested querysets are not implemented."
                )
        self.bilateral_transforms = transforms

    def apply_bilateral_transforms(self, value):
        """Wrap ``value`` in each bilateral transform, in order."""
        wrapped = value
        for transform_class in self.bilateral_transforms:
            wrapped = transform_class(wrapped)
        return wrapped

    def __repr__(self):
        return "{}({!r}, {!r})".format(self.__class__.__name__, self.lhs, self.rhs)

    def batch_process_rhs(self, compiler, connection, rhs=None):
        """
        Compile an iterable rhs item by item.

        Return a pair ``(sqls, params)``. ``rhs`` defaults to ``self.rhs``.
        """
        values = self.rhs if rhs is None else rhs
        if not self.bilateral_transforms:
            # No transforms: one "%s" placeholder per prepared parameter.
            _, params = self.get_db_prep_lookup(values, connection)
            return ["%s"] * len(params), params
        sqls, sqls_params = [], []
        for item in values:
            expr = Value(item, output_field=self.lhs.output_field)
            expr = self.apply_bilateral_transforms(expr)
            expr = expr.resolve_expression(compiler.query)
            item_sql, item_params = compiler.compile(expr)
            sqls.append(item_sql)
            sqls_params.extend(item_params)
        return sqls, sqls_params

    def get_source_expressions(self):
        """Return ``[lhs]``, plus the rhs when it is itself an expression."""
        sources = [self.lhs]
        if not self.rhs_is_direct_value():
            sources.append(self.rhs)
        return sources

    def set_source_expressions(self, new_exprs):
        """Assign the lhs (one item) or both sides (otherwise) from ``new_exprs``."""
        if len(new_exprs) != 1:
            self.lhs, self.rhs = new_exprs
        else:
            self.lhs = new_exprs[0]

    def get_prep_lookup(self):
        """Return the rhs prepared for use in the lookup."""
        rhs = self.rhs
        if not self.prepare_rhs or hasattr(rhs, "resolve_expression"):
            return rhs
        if not hasattr(self.lhs, "output_field"):
            # The lhs is not an expression with a field to prepare against.
            return Value(rhs) if self.rhs_is_direct_value() else rhs
        field = self.lhs.output_field
        if hasattr(field, "get_prep_value"):
            return field.get_prep_value(rhs)
        return rhs

    def get_prep_lhs(self):
        """Return the lhs as an expression, wrapping plain values in Value()."""
        return self.lhs if hasattr(self.lhs, "resolve_expression") else Value(self.lhs)

    def get_db_prep_lookup(self, value, connection):
        """Return a ``(placeholder_sql, params)`` pair for a plain value."""
        return "%s", [value]

    def process_lhs(self, compiler, connection, lhs=None):
        """Compile the lhs (or the given ``lhs`` override) to ``(sql, params)``."""
        target = lhs or self.lhs
        if hasattr(target, "resolve_expression"):
            target = target.resolve_expression(compiler.query)
        sql, params = compiler.compile(target)
        if isinstance(target, Lookup):
            # A nested lookup is parenthesized to respect operator precedence.
            sql = "(%s)" % sql
        return sql, params

    def process_rhs(self, compiler, connection):
        """Compile the rhs to ``(sql, params)``."""
        value = self.rhs
        if self.bilateral_transforms:
            if self.rhs_is_direct_value():
                # get_db_prep_lookup() is skipped here because the value is
                # transformed before it is used in the lookup.
                value = Value(value, output_field=self.lhs.output_field)
            value = self.apply_bilateral_transforms(value)
            value = value.resolve_expression(compiler.query)
        if not hasattr(value, "as_sql"):
            return self.get_db_prep_lookup(value, connection)
        sql, params = compiler.compile(value)
        # Parenthesize the expression to respect operator precedence, but
        # never double wrap: that can be misinterpreted on some backends
        # (e.g. subqueries on SQLite).
        if sql and sql[0] != "(":
            sql = f"({sql})"
        return sql, params

    def rhs_is_direct_value(self):
        """Return True when the rhs has no ``as_sql`` attribute."""
        return not hasattr(self.rhs, "as_sql")

    def get_group_by_cols(self, alias=None):
        """Collect the GROUP BY columns of every source expression."""
        # ``alias`` is accepted but not used here.
        return [
            col
            for source in self.get_source_expressions()
            for col in source.get_group_by_cols()
        ]

    def as_oracle(self, compiler, connection):
        """
        Compile for Oracle, which rejects comparing EXISTS() and filters to
        another expression unless they are wrapped in a CASE WHEN.
        """
        needs_wrapping = connection.ops.conditional_expression_supported_in_where_clause
        operands = []
        rewritten = False
        for operand in (self.lhs, self.rhs):
            if needs_wrapping(operand):
                operand = Case(When(operand, then=True), default=False)
                rewritten = True
            operands.append(operand)
        if rewritten:
            return type(self)(*operands).as_sql(compiler, connection)
        return self.as_sql(compiler, connection)

    @cached_property
    def output_field(self):
        return BooleanField()

    @property
    def identity(self):
        """Tuple of (class, lhs, rhs) used for equality and hashing."""
        return (self.__class__, self.lhs, self.rhs)

    def __eq__(self, other):
        if isinstance(other, Lookup):
            return self.identity == other.identity
        return NotImplemented

    def __hash__(self):
        return hash(make_hashable(self.identity))

    def resolve_expression(
        self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False
    ):
        """Return a copy whose lhs (and rhs, when resolvable) are resolved."""
        clone = self.copy()
        clone.is_summary = summarize
        clone.lhs = self.lhs.resolve_expression(
            query, allow_joins, reuse, summarize, for_save
        )
        if hasattr(self.rhs, "resolve_expression"):
            clone.rhs = self.rhs.resolve_expression(
                query, allow_joins, reuse, summarize, for_save
            )
        return clone

    def select_format(self, compiler, sql, params):
        """
        Wrap the filter in a CASE WHEN expression for backends (e.g. Oracle)
        that don't support boolean expressions in a SELECT or GROUP BY list.
        """
        if compiler.connection.features.supports_boolean_expr_in_select_clause:
            return sql, params
        return f"CASE WHEN {sql} THEN 1 ELSE 0 END", params
-
-
class Transform(RegisterLookupMixin, Func):
    """
    Single-argument function usable inside lookup chains.

    RegisterLookupMixin comes first in the bases so that get_lookup() and
    get_transform() examine self before checking output_field.
    """

    bilateral = False
    arity = 1

    @property
    def lhs(self):
        """The first (and, given ``arity``, only) source expression."""
        return self.get_source_expressions()[0]

    def get_bilateral_transforms(self):
        """
        Return the bilateral transform classes collected from the wrapped
        expression, followed by this class when ``bilateral`` is set.
        """
        inner = self.lhs
        transforms = (
            inner.get_bilateral_transforms()
            if hasattr(inner, "get_bilateral_transforms")
            else []
        )
        if self.bilateral:
            transforms.append(self.__class__)
        return transforms
-
-
class BuiltinLookup(Lookup):
    """Lookup whose rhs operator comes from ``connection.operators``."""

    def process_lhs(self, compiler, connection, lhs=None):
        """Compile the lhs and apply the backend's field and lookup casts."""
        sql, params = super().process_lhs(compiler, connection, lhs)
        output_field = self.lhs.output_field
        internal_type = output_field.get_internal_type()
        db_type = output_field.db_type(connection=connection)
        ops = connection.ops
        # Both helpers return "%s"-style templates that wrap the lhs SQL.
        sql = ops.field_cast_sql(db_type, internal_type) % sql
        sql = ops.lookup_cast(self.lookup_name, internal_type) % sql
        # The params are returned as a list; as_sql() extends it in place.
        return sql, list(params)

    def as_sql(self, compiler, connection):
        """Return ``("<lhs> <operator and rhs>", params)``."""
        lhs_sql, params = self.process_lhs(compiler, connection)
        rhs_sql, rhs_params = self.process_rhs(compiler, connection)
        params.extend(rhs_params)
        rhs_sql = self.get_rhs_op(connection, rhs_sql)
        return f"{lhs_sql} {rhs_sql}", params

    def get_rhs_op(self, connection, rhs):
        """Interpolate ``rhs`` into the backend operator for this lookup."""
        operator_template = connection.operators[self.lookup_name]
        return operator_template % rhs
-
-
class FieldGetDbPrepValueMixin:
    """
    Mixin for lookups whose inputs must be passed through
    Field.get_db_prep_value().
    """

    # When true, the value is treated as an iterable and each item is
    # prepared individually.
    get_db_prep_lookup_value_is_iterable = False

    def get_db_prep_lookup(self, value, connection):
        """Return ``("%s", [prepared values])`` for ``value``."""
        output_field = self.lhs.output_field
        # Relational fields expose 'target_field'; its get_db_prep_value()
        # wins over the output field's own when available.
        target_field = getattr(output_field, "target_field", None)
        prepare = (
            getattr(target_field, "get_db_prep_value", None)
            or output_field.get_db_prep_value
        )
        if self.get_db_prep_lookup_value_is_iterable:
            prepared = [prepare(item, connection, prepared=True) for item in value]
        else:
            prepared = [prepare(value, connection, prepared=True)]
        return "%s", prepared
-
-
class FieldGetDbPrepValueIterableMixin(FieldGetDbPrepValueMixin):
    """
    Mixin for lookups that need Field.get_db_prep_value() called on every
    value of an iterable rhs.
    """

    get_db_prep_lookup_value_is_iterable = True

    def get_prep_lookup(self):
        """Return the rhs with each plain item prepared by the lhs field."""
        if hasattr(self.rhs, "resolve_expression"):
            return self.rhs

        def prepare(item):
            # Expressions are left for the database to handle; they may sit
            # alongside real values in the same iterable.
            if hasattr(item, "resolve_expression"):
                return item
            if self.prepare_rhs and hasattr(self.lhs.output_field, "get_prep_value"):
                return self.lhs.output_field.get_prep_value(item)
            return item

        return [prepare(item) for item in self.rhs]

    def process_rhs(self, compiler, connection):
        """Compile the rhs, batching when it is an iterable of values."""
        if not self.rhs_is_direct_value():
            return super().process_rhs(compiler, connection)
        # A direct rhs should be an iterable of values; batch_process_rhs()
        # prepares/transforms them.
        return self.batch_process_rhs(compiler, connection)

    def resolve_expression_parameter(self, compiler, connection, sql, param):
        """
        Return ``(sql, params)`` for one item: the compiled form when the
        item is an expression, otherwise the given ``sql`` and ``[param]``.
        """
        params = [param]
        candidate = param
        if hasattr(candidate, "resolve_expression"):
            candidate = candidate.resolve_expression(compiler.query)
        if hasattr(candidate, "as_sql"):
            sql, params = compiler.compile(candidate)
        return sql, params

    def batch_process_rhs(self, compiler, connection, rhs=None):
        """Batch-compile the rhs, expanding any expressions among its items."""
        placeholders, raw_params = super().batch_process_rhs(compiler, connection, rhs)
        # The params may contain expressions that compile to their own
        # sql/params pair. Walk placeholder and param together so each
        # replacement refers to the same argument.
        resolved = [
            self.resolve_expression_parameter(compiler, connection, placeholder, param)
            for placeholder, param in zip(placeholders, raw_params)
        ]
        sqls, param_groups = zip(*resolved)
        return sqls, tuple(itertools.chain.from_iterable(param_groups))
-
-
class PostgresOperatorLookup(FieldGetDbPrepValueMixin, Lookup):
    """Lookup defined by operators on PostgreSQL."""

    # Operator text placed between the two sides; set by subclasses.
    postgres_operator = None

    def as_postgresql(self, compiler, connection):
        """Return ``("<lhs> <operator> <rhs>", params)`` with tuple params."""
        lhs_sql, lhs_params = self.process_lhs(compiler, connection)
        rhs_sql, rhs_params = self.process_rhs(compiler, connection)
        sql = f"{lhs_sql} {self.postgres_operator} {rhs_sql}"
        return sql, tuple(lhs_params) + tuple(rhs_params)
-
-
@Field.register_lookup
class Exact(FieldGetDbPrepValueMixin, BuiltinLookup):
    """Lookup registered on Field under the name ``exact``."""

    lookup_name = "exact"

    def get_prep_lookup(self):
        """
        Prepare the rhs. A Query rhs must be limited to one result; when it
        has no select fields it is narrowed to the primary key.
        """
        from django.db.models.sql.query import Query  # avoid circular import

        query = self.rhs
        if isinstance(query, Query):
            if not query.has_limit_one():
                raise ValueError(
                    "The QuerySet value for an exact lookup must be limited to "
                    "one result using slicing."
                )
            if not query.has_select_fields:
                query.clear_select_clause()
                query.add_fields(["pk"])
        return super().get_prep_lookup()

    def as_sql(self, compiler, connection):
        """
        Compile the lookup. A boolean rhs against a conditional lhs becomes
        "WHERE boolean_field" / "WHERE NOT boolean_field" instead of
        "WHERE boolean_field = True" when the backend allows it.
        """
        bare_condition = (
            isinstance(self.rhs, bool)
            and getattr(self.lhs, "conditional", False)
            and connection.ops.conditional_expression_supported_in_where_clause(
                self.lhs
            )
        )
        if not bare_condition:
            return super().as_sql(compiler, connection)
        lhs_sql, params = self.process_lhs(compiler, connection)
        if self.rhs:
            template = "%s"
        else:
            template = "NOT %s"
        return template % lhs_sql, params
-
-
@Field.register_lookup
class IExact(BuiltinLookup):
    """Lookup registered on Field under the name ``iexact``."""

    lookup_name = "iexact"
    prepare_rhs = False

    def process_rhs(self, qn, connection):
        """
        Compile the rhs and pass its first parameter through the backend's
        prep_for_iexact_query().

        Return ``(sql, params)``. When there are parameters, ``params`` is a
        list.
        """
        rhs, params = super().process_rhs(qn, connection)
        if params:
            # A compiled rhs expression may return its parameters as an
            # immutable sequence (e.g. a tuple); copy into a list so that the
            # first one can be replaced without raising TypeError.
            params = list(params)
            params[0] = connection.ops.prep_for_iexact_query(params[0])
        return rhs, params
-
-
@Field.register_lookup
class GreaterThan(FieldGetDbPrepValueMixin, BuiltinLookup):
    """Lookup registered on Field under the name ``gt``."""

    lookup_name = "gt"
-
-
@Field.register_lookup
class GreaterThanOrEqual(FieldGetDbPrepValueMixin, BuiltinLookup):
    """Lookup registered on Field under the name ``gte``."""

    lookup_name = "gte"
-
-
@Field.register_lookup
class LessThan(FieldGetDbPrepValueMixin, BuiltinLookup):
    """Lookup registered on Field under the name ``lt``."""

    lookup_name = "lt"
-
-
@Field.register_lookup
class LessThanOrEqual(FieldGetDbPrepValueMixin, BuiltinLookup):
    """Lookup registered on Field under the name ``lte``."""

    lookup_name = "lte"
-
-
class IntegerFieldFloatRounding:
    """
    Let float query values work against an IntegerField by rounding them up
    first; otherwise the decimal portion would always be discarded.
    """

    def get_prep_lookup(self):
        """Replace a float rhs with its ceiling, then defer to the parent."""
        rhs = self.rhs
        if isinstance(rhs, float):
            self.rhs = math.ceil(rhs)
        return super().get_prep_lookup()
-
-
@IntegerField.register_lookup
class IntegerGreaterThanOrEqual(IntegerFieldFloatRounding, GreaterThanOrEqual):
    """``gte`` lookup for IntegerField with float rhs rounding."""
-
-
@IntegerField.register_lookup
class IntegerLessThan(IntegerFieldFloatRounding, LessThan):
    """``lt`` lookup for IntegerField with float rhs rounding."""
-
-
@Field.register_lookup
class In(FieldGetDbPrepValueIterableMixin, BuiltinLookup):
    """Lookup registered on Field under the name ``in``."""

    lookup_name = "in"

    def get_prep_lookup(self):
        """
        Prepare the rhs. A Query rhs has its ordering cleared and, when it
        has no select fields, is narrowed to the primary key.
        """
        from django.db.models.sql.query import Query  # avoid circular import

        query = self.rhs
        if isinstance(query, Query):
            query.clear_ordering(clear_default=True)
            if not query.has_select_fields:
                query.clear_select_clause()
                query.add_fields(["pk"])
        return super().get_prep_lookup()

    def process_rhs(self, compiler, connection):
        """
        Compile the rhs to ``(sql, params)``.

        Raise ValueError when the rhs is bound to another database alias and
        EmptyResultSet when a direct rhs has no non-None values.
        """
        rhs_db = getattr(self.rhs, "_db", None)
        if rhs_db is not None and rhs_db != connection.alias:
            raise ValueError(
                "Subqueries aren't allowed across different databases. Force "
                "the inner query to be evaluated using `list(inner_query)`."
            )

        if not self.rhs_is_direct_value():
            return super().process_rhs(compiler, connection)

        # None is dropped from the values as NULL is never equal to anything.
        try:
            values = OrderedSet(self.rhs)
            values.discard(None)
        except TypeError:  # Unhashable items in self.rhs
            values = [item for item in self.rhs if item is not None]

        if not values:
            raise EmptyResultSet

        # The rhs should be an iterable; batch_process_rhs() prepares or
        # transforms its values.
        sqls, sqls_params = self.batch_process_rhs(compiler, connection, values)
        return "(%s)" % ", ".join(sqls), sqls_params

    def get_rhs_op(self, connection, rhs):
        """Return the ``IN <rhs>`` fragment."""
        return f"IN {rhs}"

    def as_sql(self, compiler, connection):
        """Compile the lookup, splitting the list when the backend caps it."""
        limit = connection.ops.max_in_list_size()
        oversized = (
            self.rhs_is_direct_value() and limit and len(self.rhs) > limit
        )
        if oversized:
            return self.split_parameter_list_as_sql(compiler, connection)
        return super().as_sql(compiler, connection)

    def split_parameter_list_as_sql(self, compiler, connection):
        """
        Build ``(lhs IN (...) OR lhs IN (...) ...)`` for databases that limit
        the number of elements allowed in a single 'IN' clause.
        """
        chunk_size = connection.ops.max_in_list_size()
        lhs, lhs_params = self.process_lhs(compiler, connection)
        rhs, rhs_params = self.batch_process_rhs(compiler, connection)
        clauses = []
        params = []
        for start in range(0, len(rhs_params), chunk_size):
            stop = start + chunk_size
            clauses.append("%s IN (%s)" % (lhs, ", ".join(rhs[start:stop])))
            # Every clause repeats the lhs, so its params are repeated too.
            params.extend(lhs_params)
            params.extend(rhs_params[start:stop])
        return "(" + " OR ".join(clauses) + ")", params
-
-
class PatternLookup(BuiltinLookup):
    """Base for LIKE-style lookups that wrap the value in ``%`` wildcards."""

    # %-format template applied to a plain Python rhs value.
    param_pattern = "%%%s%%"
    prepare_rhs = False

    def get_rhs_op(self, connection, rhs):
        """
        Return the operator fragment for ``rhs``.

        Taking startswith as the example, the target SQL is:
            col LIKE %s, ['thevalue%']
        For Python values the wildcard is added to the parameter in Python,
        so no special pattern is needed. When the value is, for example, a
        reference to another column, the wildcard has to be added in SQL:
            col LIKE othercol || '%%'
        so SQL references and SQL transformations use the backend's pattern.
        """
        uses_sql_value = hasattr(self.rhs, "as_sql") or self.bilateral_transforms
        if not uses_sql_value:
            return super().get_rhs_op(connection, rhs)
        pattern = connection.pattern_ops[self.lookup_name].format(
            connection.pattern_esc
        )
        return pattern.format(rhs)

    def process_rhs(self, qn, connection):
        """Compile the rhs, adding wildcards to a plain first parameter."""
        rhs, params = super().process_rhs(qn, connection)
        if not self.rhs_is_direct_value() or not params or self.bilateral_transforms:
            return rhs, params
        escaped = connection.ops.prep_for_like_query(params[0])
        params[0] = self.param_pattern % escaped
        return rhs, params
-
-
@Field.register_lookup
class Contains(PatternLookup):
    """Lookup registered on Field under the name ``contains``."""

    lookup_name = "contains"
-
-
@Field.register_lookup
class IContains(Contains):
    """Lookup registered on Field under the name ``icontains``."""

    lookup_name = "icontains"
-
-
@Field.register_lookup
class StartsWith(PatternLookup):
    """Lookup registered on Field under the name ``startswith``."""

    lookup_name = "startswith"
    # Only a trailing wildcard is added to a plain value.
    param_pattern = "%s%%"
-
-
@Field.register_lookup
class IStartsWith(StartsWith):
    """Lookup registered on Field under the name ``istartswith``."""

    lookup_name = "istartswith"
-
-
@Field.register_lookup
class EndsWith(PatternLookup):
    """Lookup registered on Field under the name ``endswith``."""

    lookup_name = "endswith"
    # Only a leading wildcard is added to a plain value.
    param_pattern = "%%%s"
-
-
@Field.register_lookup
class IEndsWith(EndsWith):
    """Lookup registered on Field under the name ``iendswith``."""

    lookup_name = "iendswith"
-
-
@Field.register_lookup
class Range(FieldGetDbPrepValueIterableMixin, BuiltinLookup):
    """Lookup registered on Field under the name ``range``."""

    lookup_name = "range"

    def get_rhs_op(self, connection, rhs):
        """Return ``BETWEEN <first> AND <second>`` from the first two rhs items."""
        lower, upper = rhs[0], rhs[1]
        return f"BETWEEN {lower} AND {upper}"
-
-
@Field.register_lookup
class IsNull(BuiltinLookup):
    """Lookup registered on Field under the name ``isnull``."""

    lookup_name = "isnull"
    prepare_rhs = False

    def as_sql(self, compiler, connection):
        """
        Return ``<lhs> IS NULL`` for a True rhs and ``<lhs> IS NOT NULL`` for
        a False one. Raise ValueError for any non-boolean rhs.
        """
        if not isinstance(self.rhs, bool):
            raise ValueError(
                "The QuerySet value for an isnull lookup must be True or False."
            )
        sql, params = compiler.compile(self.lhs)
        template = "%s IS NULL" if self.rhs else "%s IS NOT NULL"
        return template % sql, params
-
-
@Field.register_lookup
class Regex(BuiltinLookup):
    """Lookup registered on Field under the name ``regex``."""

    lookup_name = "regex"
    prepare_rhs = False

    def as_sql(self, compiler, connection):
        """
        Compile the lookup to ``(sql, params)``.

        Backends that list this lookup in ``connection.operators`` use the
        regular operator path; the others supply a two-slot template through
        ``connection.ops.regex_lookup()``.
        """
        if self.lookup_name in connection.operators:
            return super().as_sql(compiler, connection)
        lhs, lhs_params = self.process_lhs(compiler, connection)
        rhs, rhs_params = self.process_rhs(compiler, connection)
        sql_template = connection.ops.regex_lookup(self.lookup_name)
        # Unpack both sides into a new list rather than using "+": the lhs
        # params are a list while a compiled rhs expression may return a
        # tuple, and list + tuple raises TypeError.
        return sql_template % (lhs, rhs), [*lhs_params, *rhs_params]
-
-
@Field.register_lookup
class IRegex(Regex):
    """Lookup registered on Field under the name ``iregex``."""

    lookup_name = "iregex"
-
-
class YearLookup(Lookup):
    """Base for lookups that compare a year extracted from a date field."""

    def year_lookup_bounds(self, connection, year):
        """Return the backend's bounds for ``year`` on the originating field."""
        from django.db.models.functions import ExtractIsoYear

        iso_year = isinstance(self.lhs, ExtractIsoYear)
        # self.lhs is the extract expression; self.lhs.lhs is what it wraps.
        source_field = self.lhs.lhs.output_field
        ops = connection.ops
        if isinstance(source_field, DateTimeField):
            return ops.year_lookup_bounds_for_datetime_field(year, iso_year=iso_year)
        return ops.year_lookup_bounds_for_date_field(year, iso_year=iso_year)

    def as_sql(self, compiler, connection):
        """
        Compile the lookup. With a direct rhs value the extract operation is
        avoided so that indexes can be used.
        """
        if not self.rhs_is_direct_value():
            return super().as_sql(compiler, connection)
        # Skip the extract part by compiling the originating field directly,
        # that is self.lhs.lhs.
        lhs_sql, params = self.process_lhs(compiler, connection, self.lhs.lhs)
        # Only the rhs SQL is kept; the bound params replace the rhs params.
        rhs_sql, _ = self.process_rhs(compiler, connection)
        rhs_sql = self.get_direct_rhs_sql(connection, rhs_sql)
        start, finish = self.year_lookup_bounds(connection, self.rhs)
        params.extend(self.get_bound_params(start, finish))
        return f"{lhs_sql} {rhs_sql}", params

    def get_direct_rhs_sql(self, connection, rhs):
        """Interpolate ``rhs`` into the backend operator for this lookup."""
        operator_template = connection.operators[self.lookup_name]
        return operator_template % rhs

    def get_bound_params(self, start, finish):
        """Return the bound parameters to use; subclasses must override."""
        raise NotImplementedError(
            "subclasses of YearLookup must provide a get_bound_params() method"
        )
-
-
class YearExact(YearLookup, Exact):
    """Year ``exact`` lookup expressed as a BETWEEN over both bounds."""

    def get_direct_rhs_sql(self, connection, rhs):
        """Return a fixed two-placeholder BETWEEN fragment; ``rhs`` is unused."""
        return "BETWEEN %s AND %s"

    def get_bound_params(self, start, finish):
        """Use both bounds."""
        return start, finish
-
-
class YearGt(YearLookup, GreaterThan):
    """Year ``gt`` lookup compared against the finish bound."""

    def get_bound_params(self, start, finish):
        """Use only the finish bound."""
        return (finish,)
-
-
class YearGte(YearLookup, GreaterThanOrEqual):
    """Year ``gte`` lookup compared against the start bound."""

    def get_bound_params(self, start, finish):
        """Use only the start bound."""
        return (start,)
-
-
class YearLt(YearLookup, LessThan):
    """Year ``lt`` lookup compared against the start bound."""

    def get_bound_params(self, start, finish):
        """Use only the start bound."""
        return (start,)
-
-
class YearLte(YearLookup, LessThanOrEqual):
    """Year ``lte`` lookup compared against the finish bound."""

    def get_bound_params(self, start, finish):
        """Use only the finish bound."""
        return (finish,)
-
-
class UUIDTextMixin:
    """
    Remove hyphens from the value when filtering a UUIDField on backends
    that have no native UUID datatype.
    """

    def process_rhs(self, qn, connection):
        """Compile the rhs, wrapping it in a hyphen-stripping Replace() if needed."""
        if not connection.features.has_native_uuid_field:
            from django.db.models.functions import Replace

            # Note that self.rhs is reassigned here, not just a local copy.
            if self.rhs_is_direct_value():
                self.rhs = Value(self.rhs)
            self.rhs = Replace(
                self.rhs, Value("-"), Value(""), output_field=CharField()
            )
        rhs_sql, rhs_params = super().process_rhs(qn, connection)
        return rhs_sql, rhs_params
-
-
@UUIDField.register_lookup
class UUIDIExact(UUIDTextMixin, IExact):
    """``iexact`` lookup for UUIDField with hyphen stripping."""
-
-
@UUIDField.register_lookup
class UUIDContains(UUIDTextMixin, Contains):
    """``contains`` lookup for UUIDField with hyphen stripping."""
-
-
@UUIDField.register_lookup
class UUIDIContains(UUIDTextMixin, IContains):
    """``icontains`` lookup for UUIDField with hyphen stripping."""
-
-
@UUIDField.register_lookup
class UUIDStartsWith(UUIDTextMixin, StartsWith):
    """``startswith`` lookup for UUIDField with hyphen stripping."""
-
-
@UUIDField.register_lookup
class UUIDIStartsWith(UUIDTextMixin, IStartsWith):
    """``istartswith`` lookup for UUIDField with hyphen stripping."""
-
-
@UUIDField.register_lookup
class UUIDEndsWith(UUIDTextMixin, EndsWith):
    """``endswith`` lookup for UUIDField with hyphen stripping."""
-
-
@UUIDField.register_lookup
class UUIDIEndsWith(UUIDTextMixin, IEndsWith):
    """``iendswith`` lookup for UUIDField with hyphen stripping."""
|