You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

search.py 8.8KB

5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239
  1. from django.db.models import CharField, Field, FloatField, TextField
  2. from django.db.models.expressions import CombinedExpression, Func, Value
  3. from django.db.models.functions import Cast, Coalesce
  4. from django.db.models.lookups import Lookup
  5. class SearchVectorExact(Lookup):
  6. lookup_name = 'exact'
  7. def process_rhs(self, qn, connection):
  8. if not hasattr(self.rhs, 'resolve_expression'):
  9. config = getattr(self.lhs, 'config', None)
  10. self.rhs = SearchQuery(self.rhs, config=config)
  11. rhs, rhs_params = super().process_rhs(qn, connection)
  12. return rhs, rhs_params
  13. def as_sql(self, qn, connection):
  14. lhs, lhs_params = self.process_lhs(qn, connection)
  15. rhs, rhs_params = self.process_rhs(qn, connection)
  16. params = lhs_params + rhs_params
  17. return '%s @@ %s = true' % (lhs, rhs), params
  18. class SearchVectorField(Field):
  19. def db_type(self, connection):
  20. return 'tsvector'
  21. class SearchQueryField(Field):
  22. def db_type(self, connection):
  23. return 'tsquery'
  24. class SearchVectorCombinable:
  25. ADD = '||'
  26. def _combine(self, other, connector, reversed):
  27. if not isinstance(other, SearchVectorCombinable) or not self.config == other.config:
  28. raise TypeError('SearchVector can only be combined with other SearchVectors')
  29. if reversed:
  30. return CombinedSearchVector(other, connector, self, self.config)
  31. return CombinedSearchVector(self, connector, other, self.config)
  32. class SearchVector(SearchVectorCombinable, Func):
  33. function = 'to_tsvector'
  34. arg_joiner = " || ' ' || "
  35. output_field = SearchVectorField()
  36. config = None
  37. def __init__(self, *expressions, **extra):
  38. super().__init__(*expressions, **extra)
  39. self.config = self.extra.get('config', self.config)
  40. weight = self.extra.get('weight')
  41. if weight is not None and not hasattr(weight, 'resolve_expression'):
  42. weight = Value(weight)
  43. self.weight = weight
  44. def resolve_expression(self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False):
  45. resolved = super().resolve_expression(query, allow_joins, reuse, summarize, for_save)
  46. if self.config:
  47. if not hasattr(self.config, 'resolve_expression'):
  48. resolved.config = Value(self.config).resolve_expression(query, allow_joins, reuse, summarize, for_save)
  49. else:
  50. resolved.config = self.config.resolve_expression(query, allow_joins, reuse, summarize, for_save)
  51. return resolved
  52. def as_sql(self, compiler, connection, function=None, template=None):
  53. clone = self.copy()
  54. clone.set_source_expressions([
  55. Coalesce(
  56. expression
  57. if isinstance(expression.output_field, (CharField, TextField))
  58. else Cast(expression, TextField()),
  59. Value('')
  60. ) for expression in clone.get_source_expressions()
  61. ])
  62. config_params = []
  63. if template is None:
  64. if clone.config:
  65. config_sql, config_params = compiler.compile(clone.config)
  66. template = '%(function)s({}::regconfig, %(expressions)s)'.format(config_sql.replace('%', '%%'))
  67. else:
  68. template = clone.template
  69. sql, params = super(SearchVector, clone).as_sql(compiler, connection, function=function, template=template)
  70. extra_params = []
  71. if clone.weight:
  72. weight_sql, extra_params = compiler.compile(clone.weight)
  73. sql = 'setweight({}, {})'.format(sql, weight_sql)
  74. return sql, config_params + params + extra_params
  75. class CombinedSearchVector(SearchVectorCombinable, CombinedExpression):
  76. def __init__(self, lhs, connector, rhs, config, output_field=None):
  77. self.config = config
  78. super().__init__(lhs, connector, rhs, output_field)
  79. class SearchQueryCombinable:
  80. BITAND = '&&'
  81. BITOR = '||'
  82. def _combine(self, other, connector, reversed):
  83. if not isinstance(other, SearchQueryCombinable):
  84. raise TypeError(
  85. 'SearchQuery can only be combined with other SearchQuerys, '
  86. 'got {}.'.format(type(other))
  87. )
  88. if reversed:
  89. return CombinedSearchQuery(other, connector, self, self.config)
  90. return CombinedSearchQuery(self, connector, other, self.config)
  91. # On Combinable, these are not implemented to reduce confusion with Q. In
  92. # this case we are actually (ab)using them to do logical combination so
  93. # it's consistent with other usage in Django.
  94. def __or__(self, other):
  95. return self._combine(other, self.BITOR, False)
  96. def __ror__(self, other):
  97. return self._combine(other, self.BITOR, True)
  98. def __and__(self, other):
  99. return self._combine(other, self.BITAND, False)
  100. def __rand__(self, other):
  101. return self._combine(other, self.BITAND, True)
  102. class SearchQuery(SearchQueryCombinable, Value):
  103. output_field = SearchQueryField()
  104. SEARCH_TYPES = {
  105. 'plain': 'plainto_tsquery',
  106. 'phrase': 'phraseto_tsquery',
  107. 'raw': 'to_tsquery',
  108. }
  109. def __init__(self, value, output_field=None, *, config=None, invert=False, search_type='plain'):
  110. self.config = config
  111. self.invert = invert
  112. if search_type not in self.SEARCH_TYPES:
  113. raise ValueError("Unknown search_type argument '%s'." % search_type)
  114. self.search_type = search_type
  115. super().__init__(value, output_field=output_field)
  116. def resolve_expression(self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False):
  117. resolved = super().resolve_expression(query, allow_joins, reuse, summarize, for_save)
  118. if self.config:
  119. if not hasattr(self.config, 'resolve_expression'):
  120. resolved.config = Value(self.config).resolve_expression(query, allow_joins, reuse, summarize, for_save)
  121. else:
  122. resolved.config = self.config.resolve_expression(query, allow_joins, reuse, summarize, for_save)
  123. return resolved
  124. def as_sql(self, compiler, connection):
  125. params = [self.value]
  126. function = self.SEARCH_TYPES[self.search_type]
  127. if self.config:
  128. config_sql, config_params = compiler.compile(self.config)
  129. template = '{}({}::regconfig, %s)'.format(function, config_sql)
  130. params = config_params + [self.value]
  131. else:
  132. template = '{}(%s)'.format(function)
  133. if self.invert:
  134. template = '!!({})'.format(template)
  135. return template, params
  136. def _combine(self, other, connector, reversed):
  137. combined = super()._combine(other, connector, reversed)
  138. combined.output_field = SearchQueryField()
  139. return combined
  140. def __invert__(self):
  141. return type(self)(self.value, config=self.config, invert=not self.invert)
  142. def __str__(self):
  143. result = super().__str__()
  144. return ('~%s' % result) if self.invert else result
  145. class CombinedSearchQuery(SearchQueryCombinable, CombinedExpression):
  146. def __init__(self, lhs, connector, rhs, config, output_field=None):
  147. self.config = config
  148. super().__init__(lhs, connector, rhs, output_field)
  149. def __str__(self):
  150. return '(%s)' % super().__str__()
  151. class SearchRank(Func):
  152. function = 'ts_rank'
  153. output_field = FloatField()
  154. def __init__(self, vector, query, **extra):
  155. if not hasattr(vector, 'resolve_expression'):
  156. vector = SearchVector(vector)
  157. if not hasattr(query, 'resolve_expression'):
  158. query = SearchQuery(query)
  159. weights = extra.get('weights')
  160. if weights is not None and not hasattr(weights, 'resolve_expression'):
  161. weights = Value(weights)
  162. self.weights = weights
  163. super().__init__(vector, query, **extra)
  164. def as_sql(self, compiler, connection, function=None, template=None):
  165. extra_params = []
  166. extra_context = {}
  167. if template is None and self.extra.get('weights'):
  168. if self.weights:
  169. template = '%(function)s(%(weights)s, %(expressions)s)'
  170. weight_sql, extra_params = compiler.compile(self.weights)
  171. extra_context['weights'] = weight_sql
  172. sql, params = super().as_sql(
  173. compiler, connection,
  174. function=function, template=template, **extra_context
  175. )
  176. return sql, extra_params + params
  177. SearchVectorField.register_lookup(SearchVectorExact)
  178. class TrigramBase(Func):
  179. output_field = FloatField()
  180. def __init__(self, expression, string, **extra):
  181. if not hasattr(string, 'resolve_expression'):
  182. string = Value(string)
  183. super().__init__(expression, string, **extra)
  184. class TrigramSimilarity(TrigramBase):
  185. function = 'SIMILARITY'
  186. class TrigramDistance(TrigramBase):
  187. function = ''
  188. arg_joiner = ' <-> '