You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

search.py 8.8KB

5 years ago

  1. from django.db.models import CharField, Field, FloatField, TextField
  2. from django.db.models.expressions import CombinedExpression, Func, Value
  3. from django.db.models.functions import Cast, Coalesce
  4. from django.db.models.lookups import Lookup
  5. class SearchVectorExact(Lookup):
  6. lookup_name = 'exact'
  7. def process_rhs(self, qn, connection):
  8. if not hasattr(self.rhs, 'resolve_expression'):
  9. config = getattr(self.lhs, 'config', None)
  10. self.rhs = SearchQuery(self.rhs, config=config)
  11. rhs, rhs_params = super().process_rhs(qn, connection)
  12. return rhs, rhs_params
  13. def as_sql(self, qn, connection):
  14. lhs, lhs_params = self.process_lhs(qn, connection)
  15. rhs, rhs_params = self.process_rhs(qn, connection)
  16. params = lhs_params + rhs_params
  17. return '%s @@ %s = true' % (lhs, rhs), params
  18. class SearchVectorField(Field):
  19. def db_type(self, connection):
  20. return 'tsvector'
  21. class SearchQueryField(Field):
  22. def db_type(self, connection):
  23. return 'tsquery'
  24. class SearchVectorCombinable:
  25. ADD = '||'
  26. def _combine(self, other, connector, reversed):
  27. if not isinstance(other, SearchVectorCombinable) or not self.config == other.config:
  28. raise TypeError('SearchVector can only be combined with other SearchVectors')
  29. if reversed:
  30. return CombinedSearchVector(other, connector, self, self.config)
  31. return CombinedSearchVector(self, connector, other, self.config)
  32. class SearchVector(SearchVectorCombinable, Func):
  33. function = 'to_tsvector'
  34. arg_joiner = " || ' ' || "
  35. output_field = SearchVectorField()
  36. config = None
  37. def __init__(self, *expressions, **extra):
  38. super().__init__(*expressions, **extra)
  39. self.config = self.extra.get('config', self.config)
  40. weight = self.extra.get('weight')
  41. if weight is not None and not hasattr(weight, 'resolve_expression'):
  42. weight = Value(weight)
  43. self.weight = weight
  44. def resolve_expression(self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False):
  45. resolved = super().resolve_expression(query, allow_joins, reuse, summarize, for_save)
  46. if self.config:
  47. if not hasattr(self.config, 'resolve_expression'):
  48. resolved.config = Value(self.config).resolve_expression(query, allow_joins, reuse, summarize, for_save)
  49. else:
  50. resolved.config = self.config.resolve_expression(query, allow_joins, reuse, summarize, for_save)
  51. return resolved
  52. def as_sql(self, compiler, connection, function=None, template=None):
  53. clone = self.copy()
  54. clone.set_source_expressions([
  55. Coalesce(
  56. expression
  57. if isinstance(expression.output_field, (CharField, TextField))
  58. else Cast(expression, TextField()),
  59. Value('')
  60. ) for expression in clone.get_source_expressions()
  61. ])
  62. config_params = []
  63. if template is None:
  64. if clone.config:
  65. config_sql, config_params = compiler.compile(clone.config)
  66. template = '%(function)s({}::regconfig, %(expressions)s)'.format(config_sql.replace('%', '%%'))
  67. else:
  68. template = clone.template
  69. sql, params = super(SearchVector, clone).as_sql(compiler, connection, function=function, template=template)
  70. extra_params = []
  71. if clone.weight:
  72. weight_sql, extra_params = compiler.compile(clone.weight)
  73. sql = 'setweight({}, {})'.format(sql, weight_sql)
  74. return sql, config_params + params + extra_params
  75. class CombinedSearchVector(SearchVectorCombinable, CombinedExpression):
  76. def __init__(self, lhs, connector, rhs, config, output_field=None):
  77. self.config = config
  78. super().__init__(lhs, connector, rhs, output_field)
  79. class SearchQueryCombinable:
  80. BITAND = '&&'
  81. BITOR = '||'
  82. def _combine(self, other, connector, reversed):
  83. if not isinstance(other, SearchQueryCombinable):
  84. raise TypeError(
  85. 'SearchQuery can only be combined with other SearchQuerys, '
  86. 'got {}.'.format(type(other))
  87. )
  88. if reversed:
  89. return CombinedSearchQuery(other, connector, self, self.config)
  90. return CombinedSearchQuery(self, connector, other, self.config)
  91. # On Combinable, these are not implemented to reduce confusion with Q. In
  92. # this case we are actually (ab)using them to do logical combination so
  93. # it's consistent with other usage in Django.
  94. def __or__(self, other):
  95. return self._combine(other, self.BITOR, False)
  96. def __ror__(self, other):
  97. return self._combine(other, self.BITOR, True)
  98. def __and__(self, other):
  99. return self._combine(other, self.BITAND, False)
  100. def __rand__(self, other):
  101. return self._combine(other, self.BITAND, True)
  102. class SearchQuery(SearchQueryCombinable, Value):
  103. output_field = SearchQueryField()
  104. SEARCH_TYPES = {
  105. 'plain': 'plainto_tsquery',
  106. 'phrase': 'phraseto_tsquery',
  107. 'raw': 'to_tsquery',
  108. }
  109. def __init__(self, value, output_field=None, *, config=None, invert=False, search_type='plain'):
  110. self.config = config
  111. self.invert = invert
  112. if search_type not in self.SEARCH_TYPES:
  113. raise ValueError("Unknown search_type argument '%s'." % search_type)
  114. self.search_type = search_type
  115. super().__init__(value, output_field=output_field)
  116. def resolve_expression(self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False):
  117. resolved = super().resolve_expression(query, allow_joins, reuse, summarize, for_save)
  118. if self.config:
  119. if not hasattr(self.config, 'resolve_expression'):
  120. resolved.config = Value(self.config).resolve_expression(query, allow_joins, reuse, summarize, for_save)
  121. else:
  122. resolved.config = self.config.resolve_expression(query, allow_joins, reuse, summarize, for_save)
  123. return resolved
  124. def as_sql(self, compiler, connection):
  125. params = [self.value]
  126. function = self.SEARCH_TYPES[self.search_type]
  127. if self.config:
  128. config_sql, config_params = compiler.compile(self.config)
  129. template = '{}({}::regconfig, %s)'.format(function, config_sql)
  130. params = config_params + [self.value]
  131. else:
  132. template = '{}(%s)'.format(function)
  133. if self.invert:
  134. template = '!!({})'.format(template)
  135. return template, params
  136. def _combine(self, other, connector, reversed):
  137. combined = super()._combine(other, connector, reversed)
  138. combined.output_field = SearchQueryField()
  139. return combined
  140. def __invert__(self):
  141. return type(self)(self.value, config=self.config, invert=not self.invert)
  142. def __str__(self):
  143. result = super().__str__()
  144. return ('~%s' % result) if self.invert else result
  145. class CombinedSearchQuery(SearchQueryCombinable, CombinedExpression):
  146. def __init__(self, lhs, connector, rhs, config, output_field=None):
  147. self.config = config
  148. super().__init__(lhs, connector, rhs, output_field)
  149. def __str__(self):
  150. return '(%s)' % super().__str__()
  151. class SearchRank(Func):
  152. function = 'ts_rank'
  153. output_field = FloatField()
  154. def __init__(self, vector, query, **extra):
  155. if not hasattr(vector, 'resolve_expression'):
  156. vector = SearchVector(vector)
  157. if not hasattr(query, 'resolve_expression'):
  158. query = SearchQuery(query)
  159. weights = extra.get('weights')
  160. if weights is not None and not hasattr(weights, 'resolve_expression'):
  161. weights = Value(weights)
  162. self.weights = weights
  163. super().__init__(vector, query, **extra)
  164. def as_sql(self, compiler, connection, function=None, template=None):
  165. extra_params = []
  166. extra_context = {}
  167. if template is None and self.extra.get('weights'):
  168. if self.weights:
  169. template = '%(function)s(%(weights)s, %(expressions)s)'
  170. weight_sql, extra_params = compiler.compile(self.weights)
  171. extra_context['weights'] = weight_sql
  172. sql, params = super().as_sql(
  173. compiler, connection,
  174. function=function, template=template, **extra_context
  175. )
  176. return sql, extra_params + params
  177. SearchVectorField.register_lookup(SearchVectorExact)
  178. class TrigramBase(Func):
  179. output_field = FloatField()
  180. def __init__(self, expression, string, **extra):
  181. if not hasattr(string, 'resolve_expression'):
  182. string = Value(string)
  183. super().__init__(expression, string, **extra)
  184. class TrigramSimilarity(TrigramBase):
  185. function = 'SIMILARITY'
  186. class TrigramDistance(TrigramBase):
  187. function = ''
  188. arg_joiner = ' <-> '