You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

lookups.py 19KB

5 years ago

  1. import itertools
  2. import math
  3. from copy import copy
  4. from django.core.exceptions import EmptyResultSet
  5. from django.db.models.expressions import Func, Value
  6. from django.db.models.fields import DateTimeField, Field, IntegerField
  7. from django.db.models.query_utils import RegisterLookupMixin
  8. from django.utils.datastructures import OrderedSet
  9. from django.utils.functional import cached_property
  10. class Lookup:
  11. lookup_name = None
  12. prepare_rhs = True
  13. can_use_none_as_rhs = False
  14. def __init__(self, lhs, rhs):
  15. self.lhs, self.rhs = lhs, rhs
  16. self.rhs = self.get_prep_lookup()
  17. if hasattr(self.lhs, 'get_bilateral_transforms'):
  18. bilateral_transforms = self.lhs.get_bilateral_transforms()
  19. else:
  20. bilateral_transforms = []
  21. if bilateral_transforms:
  22. # Warn the user as soon as possible if they are trying to apply
  23. # a bilateral transformation on a nested QuerySet: that won't work.
  24. from django.db.models.sql.query import Query # avoid circular import
  25. if isinstance(rhs, Query):
  26. raise NotImplementedError("Bilateral transformations on nested querysets are not implemented.")
  27. self.bilateral_transforms = bilateral_transforms
  28. def apply_bilateral_transforms(self, value):
  29. for transform in self.bilateral_transforms:
  30. value = transform(value)
  31. return value
  32. def batch_process_rhs(self, compiler, connection, rhs=None):
  33. if rhs is None:
  34. rhs = self.rhs
  35. if self.bilateral_transforms:
  36. sqls, sqls_params = [], []
  37. for p in rhs:
  38. value = Value(p, output_field=self.lhs.output_field)
  39. value = self.apply_bilateral_transforms(value)
  40. value = value.resolve_expression(compiler.query)
  41. sql, sql_params = compiler.compile(value)
  42. sqls.append(sql)
  43. sqls_params.extend(sql_params)
  44. else:
  45. _, params = self.get_db_prep_lookup(rhs, connection)
  46. sqls, sqls_params = ['%s'] * len(params), params
  47. return sqls, sqls_params
  48. def get_source_expressions(self):
  49. if self.rhs_is_direct_value():
  50. return [self.lhs]
  51. return [self.lhs, self.rhs]
  52. def set_source_expressions(self, new_exprs):
  53. if len(new_exprs) == 1:
  54. self.lhs = new_exprs[0]
  55. else:
  56. self.lhs, self.rhs = new_exprs
  57. def get_prep_lookup(self):
  58. if hasattr(self.rhs, '_prepare'):
  59. return self.rhs._prepare(self.lhs.output_field)
  60. if self.prepare_rhs and hasattr(self.lhs.output_field, 'get_prep_value'):
  61. return self.lhs.output_field.get_prep_value(self.rhs)
  62. return self.rhs
  63. def get_db_prep_lookup(self, value, connection):
  64. return ('%s', [value])
  65. def process_lhs(self, compiler, connection, lhs=None):
  66. lhs = lhs or self.lhs
  67. if hasattr(lhs, 'resolve_expression'):
  68. lhs = lhs.resolve_expression(compiler.query)
  69. return compiler.compile(lhs)
  70. def process_rhs(self, compiler, connection):
  71. value = self.rhs
  72. if self.bilateral_transforms:
  73. if self.rhs_is_direct_value():
  74. # Do not call get_db_prep_lookup here as the value will be
  75. # transformed before being used for lookup
  76. value = Value(value, output_field=self.lhs.output_field)
  77. value = self.apply_bilateral_transforms(value)
  78. value = value.resolve_expression(compiler.query)
  79. if hasattr(value, 'as_sql'):
  80. sql, params = compiler.compile(value)
  81. return '(' + sql + ')', params
  82. else:
  83. return self.get_db_prep_lookup(value, connection)
  84. def rhs_is_direct_value(self):
  85. return not hasattr(self.rhs, 'as_sql')
  86. def relabeled_clone(self, relabels):
  87. new = copy(self)
  88. new.lhs = new.lhs.relabeled_clone(relabels)
  89. if hasattr(new.rhs, 'relabeled_clone'):
  90. new.rhs = new.rhs.relabeled_clone(relabels)
  91. return new
  92. def get_group_by_cols(self):
  93. cols = self.lhs.get_group_by_cols()
  94. if hasattr(self.rhs, 'get_group_by_cols'):
  95. cols.extend(self.rhs.get_group_by_cols())
  96. return cols
  97. def as_sql(self, compiler, connection):
  98. raise NotImplementedError
  99. @cached_property
  100. def contains_aggregate(self):
  101. return self.lhs.contains_aggregate or getattr(self.rhs, 'contains_aggregate', False)
  102. @cached_property
  103. def contains_over_clause(self):
  104. return self.lhs.contains_over_clause or getattr(self.rhs, 'contains_over_clause', False)
  105. @property
  106. def is_summary(self):
  107. return self.lhs.is_summary or getattr(self.rhs, 'is_summary', False)
  108. class Transform(RegisterLookupMixin, Func):
  109. """
  110. RegisterLookupMixin() is first so that get_lookup() and get_transform()
  111. first examine self and then check output_field.
  112. """
  113. bilateral = False
  114. arity = 1
  115. @property
  116. def lhs(self):
  117. return self.get_source_expressions()[0]
  118. def get_bilateral_transforms(self):
  119. if hasattr(self.lhs, 'get_bilateral_transforms'):
  120. bilateral_transforms = self.lhs.get_bilateral_transforms()
  121. else:
  122. bilateral_transforms = []
  123. if self.bilateral:
  124. bilateral_transforms.append(self.__class__)
  125. return bilateral_transforms
  126. class BuiltinLookup(Lookup):
  127. def process_lhs(self, compiler, connection, lhs=None):
  128. lhs_sql, params = super().process_lhs(compiler, connection, lhs)
  129. field_internal_type = self.lhs.output_field.get_internal_type()
  130. db_type = self.lhs.output_field.db_type(connection=connection)
  131. lhs_sql = connection.ops.field_cast_sql(
  132. db_type, field_internal_type) % lhs_sql
  133. lhs_sql = connection.ops.lookup_cast(self.lookup_name, field_internal_type) % lhs_sql
  134. return lhs_sql, list(params)
  135. def as_sql(self, compiler, connection):
  136. lhs_sql, params = self.process_lhs(compiler, connection)
  137. rhs_sql, rhs_params = self.process_rhs(compiler, connection)
  138. params.extend(rhs_params)
  139. rhs_sql = self.get_rhs_op(connection, rhs_sql)
  140. return '%s %s' % (lhs_sql, rhs_sql), params
  141. def get_rhs_op(self, connection, rhs):
  142. return connection.operators[self.lookup_name] % rhs
  143. class FieldGetDbPrepValueMixin:
  144. """
  145. Some lookups require Field.get_db_prep_value() to be called on their
  146. inputs.
  147. """
  148. get_db_prep_lookup_value_is_iterable = False
  149. def get_db_prep_lookup(self, value, connection):
  150. # For relational fields, use the output_field of the 'field' attribute.
  151. field = getattr(self.lhs.output_field, 'field', None)
  152. get_db_prep_value = getattr(field, 'get_db_prep_value', None) or self.lhs.output_field.get_db_prep_value
  153. return (
  154. '%s',
  155. [get_db_prep_value(v, connection, prepared=True) for v in value]
  156. if self.get_db_prep_lookup_value_is_iterable else
  157. [get_db_prep_value(value, connection, prepared=True)]
  158. )
  159. class FieldGetDbPrepValueIterableMixin(FieldGetDbPrepValueMixin):
  160. """
  161. Some lookups require Field.get_db_prep_value() to be called on each value
  162. in an iterable.
  163. """
  164. get_db_prep_lookup_value_is_iterable = True
  165. def get_prep_lookup(self):
  166. prepared_values = []
  167. if hasattr(self.rhs, '_prepare'):
  168. # A subquery is like an iterable but its items shouldn't be
  169. # prepared independently.
  170. return self.rhs._prepare(self.lhs.output_field)
  171. for rhs_value in self.rhs:
  172. if hasattr(rhs_value, 'resolve_expression'):
  173. # An expression will be handled by the database but can coexist
  174. # alongside real values.
  175. pass
  176. elif self.prepare_rhs and hasattr(self.lhs.output_field, 'get_prep_value'):
  177. rhs_value = self.lhs.output_field.get_prep_value(rhs_value)
  178. prepared_values.append(rhs_value)
  179. return prepared_values
  180. def process_rhs(self, compiler, connection):
  181. if self.rhs_is_direct_value():
  182. # rhs should be an iterable of values. Use batch_process_rhs()
  183. # to prepare/transform those values.
  184. return self.batch_process_rhs(compiler, connection)
  185. else:
  186. return super().process_rhs(compiler, connection)
  187. def resolve_expression_parameter(self, compiler, connection, sql, param):
  188. params = [param]
  189. if hasattr(param, 'resolve_expression'):
  190. param = param.resolve_expression(compiler.query)
  191. if hasattr(param, 'as_sql'):
  192. sql, params = param.as_sql(compiler, connection)
  193. return sql, params
  194. def batch_process_rhs(self, compiler, connection, rhs=None):
  195. pre_processed = super().batch_process_rhs(compiler, connection, rhs)
  196. # The params list may contain expressions which compile to a
  197. # sql/param pair. Zip them to get sql and param pairs that refer to the
  198. # same argument and attempt to replace them with the result of
  199. # compiling the param step.
  200. sql, params = zip(*(
  201. self.resolve_expression_parameter(compiler, connection, sql, param)
  202. for sql, param in zip(*pre_processed)
  203. ))
  204. params = itertools.chain.from_iterable(params)
  205. return sql, tuple(params)
  206. @Field.register_lookup
  207. class Exact(FieldGetDbPrepValueMixin, BuiltinLookup):
  208. lookup_name = 'exact'
  209. def process_rhs(self, compiler, connection):
  210. from django.db.models.sql.query import Query
  211. if isinstance(self.rhs, Query):
  212. if self.rhs.has_limit_one():
  213. # The subquery must select only the pk.
  214. self.rhs.clear_select_clause()
  215. self.rhs.add_fields(['pk'])
  216. else:
  217. raise ValueError(
  218. 'The QuerySet value for an exact lookup must be limited to '
  219. 'one result using slicing.'
  220. )
  221. return super().process_rhs(compiler, connection)
  222. @Field.register_lookup
  223. class IExact(BuiltinLookup):
  224. lookup_name = 'iexact'
  225. prepare_rhs = False
  226. def process_rhs(self, qn, connection):
  227. rhs, params = super().process_rhs(qn, connection)
  228. if params:
  229. params[0] = connection.ops.prep_for_iexact_query(params[0])
  230. return rhs, params
  231. @Field.register_lookup
  232. class GreaterThan(FieldGetDbPrepValueMixin, BuiltinLookup):
  233. lookup_name = 'gt'
  234. @Field.register_lookup
  235. class GreaterThanOrEqual(FieldGetDbPrepValueMixin, BuiltinLookup):
  236. lookup_name = 'gte'
  237. @Field.register_lookup
  238. class LessThan(FieldGetDbPrepValueMixin, BuiltinLookup):
  239. lookup_name = 'lt'
  240. @Field.register_lookup
  241. class LessThanOrEqual(FieldGetDbPrepValueMixin, BuiltinLookup):
  242. lookup_name = 'lte'
  243. class IntegerFieldFloatRounding:
  244. """
  245. Allow floats to work as query values for IntegerField. Without this, the
  246. decimal portion of the float would always be discarded.
  247. """
  248. def get_prep_lookup(self):
  249. if isinstance(self.rhs, float):
  250. self.rhs = math.ceil(self.rhs)
  251. return super().get_prep_lookup()
  252. @IntegerField.register_lookup
  253. class IntegerGreaterThanOrEqual(IntegerFieldFloatRounding, GreaterThanOrEqual):
  254. pass
  255. @IntegerField.register_lookup
  256. class IntegerLessThan(IntegerFieldFloatRounding, LessThan):
  257. pass
  258. @Field.register_lookup
  259. class In(FieldGetDbPrepValueIterableMixin, BuiltinLookup):
  260. lookup_name = 'in'
  261. def process_rhs(self, compiler, connection):
  262. db_rhs = getattr(self.rhs, '_db', None)
  263. if db_rhs is not None and db_rhs != connection.alias:
  264. raise ValueError(
  265. "Subqueries aren't allowed across different databases. Force "
  266. "the inner query to be evaluated using `list(inner_query)`."
  267. )
  268. if self.rhs_is_direct_value():
  269. try:
  270. rhs = OrderedSet(self.rhs)
  271. except TypeError: # Unhashable items in self.rhs
  272. rhs = self.rhs
  273. if not rhs:
  274. raise EmptyResultSet
  275. # rhs should be an iterable; use batch_process_rhs() to
  276. # prepare/transform those values.
  277. sqls, sqls_params = self.batch_process_rhs(compiler, connection, rhs)
  278. placeholder = '(' + ', '.join(sqls) + ')'
  279. return (placeholder, sqls_params)
  280. else:
  281. if not getattr(self.rhs, 'has_select_fields', True):
  282. self.rhs.clear_select_clause()
  283. self.rhs.add_fields(['pk'])
  284. return super().process_rhs(compiler, connection)
  285. def get_rhs_op(self, connection, rhs):
  286. return 'IN %s' % rhs
  287. def as_sql(self, compiler, connection):
  288. max_in_list_size = connection.ops.max_in_list_size()
  289. if self.rhs_is_direct_value() and max_in_list_size and len(self.rhs) > max_in_list_size:
  290. return self.split_parameter_list_as_sql(compiler, connection)
  291. return super().as_sql(compiler, connection)
  292. def split_parameter_list_as_sql(self, compiler, connection):
  293. # This is a special case for databases which limit the number of
  294. # elements which can appear in an 'IN' clause.
  295. max_in_list_size = connection.ops.max_in_list_size()
  296. lhs, lhs_params = self.process_lhs(compiler, connection)
  297. rhs, rhs_params = self.batch_process_rhs(compiler, connection)
  298. in_clause_elements = ['(']
  299. params = []
  300. for offset in range(0, len(rhs_params), max_in_list_size):
  301. if offset > 0:
  302. in_clause_elements.append(' OR ')
  303. in_clause_elements.append('%s IN (' % lhs)
  304. params.extend(lhs_params)
  305. sqls = rhs[offset: offset + max_in_list_size]
  306. sqls_params = rhs_params[offset: offset + max_in_list_size]
  307. param_group = ', '.join(sqls)
  308. in_clause_elements.append(param_group)
  309. in_clause_elements.append(')')
  310. params.extend(sqls_params)
  311. in_clause_elements.append(')')
  312. return ''.join(in_clause_elements), params
  313. class PatternLookup(BuiltinLookup):
  314. param_pattern = '%%%s%%'
  315. prepare_rhs = False
  316. def get_rhs_op(self, connection, rhs):
  317. # Assume we are in startswith. We need to produce SQL like:
  318. # col LIKE %s, ['thevalue%']
  319. # For python values we can (and should) do that directly in Python,
  320. # but if the value is for example reference to other column, then
  321. # we need to add the % pattern match to the lookup by something like
  322. # col LIKE othercol || '%%'
  323. # So, for Python values we don't need any special pattern, but for
  324. # SQL reference values or SQL transformations we need the correct
  325. # pattern added.
  326. if hasattr(self.rhs, 'as_sql') or self.bilateral_transforms:
  327. pattern = connection.pattern_ops[self.lookup_name].format(connection.pattern_esc)
  328. return pattern.format(rhs)
  329. else:
  330. return super().get_rhs_op(connection, rhs)
  331. def process_rhs(self, qn, connection):
  332. rhs, params = super().process_rhs(qn, connection)
  333. if self.rhs_is_direct_value() and params and not self.bilateral_transforms:
  334. params[0] = self.param_pattern % connection.ops.prep_for_like_query(params[0])
  335. return rhs, params
  336. @Field.register_lookup
  337. class Contains(PatternLookup):
  338. lookup_name = 'contains'
  339. @Field.register_lookup
  340. class IContains(Contains):
  341. lookup_name = 'icontains'
  342. @Field.register_lookup
  343. class StartsWith(PatternLookup):
  344. lookup_name = 'startswith'
  345. param_pattern = '%s%%'
  346. @Field.register_lookup
  347. class IStartsWith(StartsWith):
  348. lookup_name = 'istartswith'
  349. @Field.register_lookup
  350. class EndsWith(PatternLookup):
  351. lookup_name = 'endswith'
  352. param_pattern = '%%%s'
  353. @Field.register_lookup
  354. class IEndsWith(EndsWith):
  355. lookup_name = 'iendswith'
  356. @Field.register_lookup
  357. class Range(FieldGetDbPrepValueIterableMixin, BuiltinLookup):
  358. lookup_name = 'range'
  359. def get_rhs_op(self, connection, rhs):
  360. return "BETWEEN %s AND %s" % (rhs[0], rhs[1])
  361. @Field.register_lookup
  362. class IsNull(BuiltinLookup):
  363. lookup_name = 'isnull'
  364. prepare_rhs = False
  365. def as_sql(self, compiler, connection):
  366. sql, params = compiler.compile(self.lhs)
  367. if self.rhs:
  368. return "%s IS NULL" % sql, params
  369. else:
  370. return "%s IS NOT NULL" % sql, params
  371. @Field.register_lookup
  372. class Regex(BuiltinLookup):
  373. lookup_name = 'regex'
  374. prepare_rhs = False
  375. def as_sql(self, compiler, connection):
  376. if self.lookup_name in connection.operators:
  377. return super().as_sql(compiler, connection)
  378. else:
  379. lhs, lhs_params = self.process_lhs(compiler, connection)
  380. rhs, rhs_params = self.process_rhs(compiler, connection)
  381. sql_template = connection.ops.regex_lookup(self.lookup_name)
  382. return sql_template % (lhs, rhs), lhs_params + rhs_params
  383. @Field.register_lookup
  384. class IRegex(Regex):
  385. lookup_name = 'iregex'
  386. class YearLookup(Lookup):
  387. def year_lookup_bounds(self, connection, year):
  388. output_field = self.lhs.lhs.output_field
  389. if isinstance(output_field, DateTimeField):
  390. bounds = connection.ops.year_lookup_bounds_for_datetime_field(year)
  391. else:
  392. bounds = connection.ops.year_lookup_bounds_for_date_field(year)
  393. return bounds
  394. class YearComparisonLookup(YearLookup):
  395. def as_sql(self, compiler, connection):
  396. # We will need to skip the extract part and instead go
  397. # directly with the originating field, that is self.lhs.lhs.
  398. lhs_sql, params = self.process_lhs(compiler, connection, self.lhs.lhs)
  399. rhs_sql, rhs_params = self.process_rhs(compiler, connection)
  400. rhs_sql = self.get_rhs_op(connection, rhs_sql)
  401. start, finish = self.year_lookup_bounds(connection, rhs_params[0])
  402. params.append(self.get_bound(start, finish))
  403. return '%s %s' % (lhs_sql, rhs_sql), params
  404. def get_rhs_op(self, connection, rhs):
  405. return connection.operators[self.lookup_name] % rhs
  406. def get_bound(self, start, finish):
  407. raise NotImplementedError(
  408. 'subclasses of YearComparisonLookup must provide a get_bound() method'
  409. )
  410. class YearExact(YearLookup, Exact):
  411. lookup_name = 'exact'
  412. def as_sql(self, compiler, connection):
  413. # We will need to skip the extract part and instead go
  414. # directly with the originating field, that is self.lhs.lhs.
  415. lhs_sql, params = self.process_lhs(compiler, connection, self.lhs.lhs)
  416. rhs_sql, rhs_params = self.process_rhs(compiler, connection)
  417. try:
  418. # Check that rhs_params[0] exists (IndexError),
  419. # it isn't None (TypeError), and is a number (ValueError)
  420. int(rhs_params[0])
  421. except (IndexError, TypeError, ValueError):
  422. # Can't determine the bounds before executing the query, so skip
  423. # optimizations by falling back to a standard exact comparison.
  424. return super().as_sql(compiler, connection)
  425. bounds = self.year_lookup_bounds(connection, rhs_params[0])
  426. params.extend(bounds)
  427. return '%s BETWEEN %%s AND %%s' % lhs_sql, params
  428. class YearGt(YearComparisonLookup):
  429. lookup_name = 'gt'
  430. def get_bound(self, start, finish):
  431. return finish
  432. class YearGte(YearComparisonLookup):
  433. lookup_name = 'gte'
  434. def get_bound(self, start, finish):
  435. return start
  436. class YearLt(YearComparisonLookup):
  437. lookup_name = 'lt'
  438. def get_bound(self, start, finish):
  439. return start
  440. class YearLte(YearComparisonLookup):
  441. lookup_name = 'lte'
  442. def get_bound(self, start, finish):
  443. return finish