Development of an internal social media platform with personalised dashboards for students
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

search.py 8.1KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219
  1. from django.db.models import Field, FloatField
  2. from django.db.models.expressions import CombinedExpression, Func, Value
  3. from django.db.models.functions import Coalesce
  4. from django.db.models.lookups import Lookup
  5. class SearchVectorExact(Lookup):
  6. lookup_name = 'exact'
  7. def process_rhs(self, qn, connection):
  8. if not hasattr(self.rhs, 'resolve_expression'):
  9. config = getattr(self.lhs, 'config', None)
  10. self.rhs = SearchQuery(self.rhs, config=config)
  11. rhs, rhs_params = super().process_rhs(qn, connection)
  12. return rhs, rhs_params
  13. def as_sql(self, qn, connection):
  14. lhs, lhs_params = self.process_lhs(qn, connection)
  15. rhs, rhs_params = self.process_rhs(qn, connection)
  16. params = lhs_params + rhs_params
  17. return '%s @@ %s = true' % (lhs, rhs), params
  18. class SearchVectorField(Field):
  19. def db_type(self, connection):
  20. return 'tsvector'
  21. class SearchQueryField(Field):
  22. def db_type(self, connection):
  23. return 'tsquery'
  24. class SearchVectorCombinable:
  25. ADD = '||'
  26. def _combine(self, other, connector, reversed):
  27. if not isinstance(other, SearchVectorCombinable) or not self.config == other.config:
  28. raise TypeError('SearchVector can only be combined with other SearchVectors')
  29. if reversed:
  30. return CombinedSearchVector(other, connector, self, self.config)
  31. return CombinedSearchVector(self, connector, other, self.config)
  32. class SearchVector(SearchVectorCombinable, Func):
  33. function = 'to_tsvector'
  34. arg_joiner = " || ' ' || "
  35. output_field = SearchVectorField()
  36. config = None
  37. def __init__(self, *expressions, **extra):
  38. super().__init__(*expressions, **extra)
  39. self.source_expressions = [
  40. Coalesce(expression, Value('')) for expression in self.source_expressions
  41. ]
  42. self.config = self.extra.get('config', self.config)
  43. weight = self.extra.get('weight')
  44. if weight is not None and not hasattr(weight, 'resolve_expression'):
  45. weight = Value(weight)
  46. self.weight = weight
  47. def resolve_expression(self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False):
  48. resolved = super().resolve_expression(query, allow_joins, reuse, summarize, for_save)
  49. if self.config:
  50. if not hasattr(self.config, 'resolve_expression'):
  51. resolved.config = Value(self.config).resolve_expression(query, allow_joins, reuse, summarize, for_save)
  52. else:
  53. resolved.config = self.config.resolve_expression(query, allow_joins, reuse, summarize, for_save)
  54. return resolved
  55. def as_sql(self, compiler, connection, function=None, template=None):
  56. config_params = []
  57. if template is None:
  58. if self.config:
  59. config_sql, config_params = compiler.compile(self.config)
  60. template = "%(function)s({}::regconfig, %(expressions)s)".format(config_sql.replace('%', '%%'))
  61. else:
  62. template = self.template
  63. sql, params = super().as_sql(compiler, connection, function=function, template=template)
  64. extra_params = []
  65. if self.weight:
  66. weight_sql, extra_params = compiler.compile(self.weight)
  67. sql = 'setweight({}, {})'.format(sql, weight_sql)
  68. return sql, config_params + params + extra_params
  69. class CombinedSearchVector(SearchVectorCombinable, CombinedExpression):
  70. def __init__(self, lhs, connector, rhs, config, output_field=None):
  71. self.config = config
  72. super().__init__(lhs, connector, rhs, output_field)
  73. class SearchQueryCombinable:
  74. BITAND = '&&'
  75. BITOR = '||'
  76. def _combine(self, other, connector, reversed):
  77. if not isinstance(other, SearchQueryCombinable):
  78. raise TypeError(
  79. 'SearchQuery can only be combined with other SearchQuerys, '
  80. 'got {}.'.format(type(other))
  81. )
  82. if not self.config == other.config:
  83. raise TypeError("SearchQuery configs don't match.")
  84. if reversed:
  85. return CombinedSearchQuery(other, connector, self, self.config)
  86. return CombinedSearchQuery(self, connector, other, self.config)
  87. # On Combinable, these are not implemented to reduce confusion with Q. In
  88. # this case we are actually (ab)using them to do logical combination so
  89. # it's consistent with other usage in Django.
  90. def __or__(self, other):
  91. return self._combine(other, self.BITOR, False)
  92. def __ror__(self, other):
  93. return self._combine(other, self.BITOR, True)
  94. def __and__(self, other):
  95. return self._combine(other, self.BITAND, False)
  96. def __rand__(self, other):
  97. return self._combine(other, self.BITAND, True)
  98. class SearchQuery(SearchQueryCombinable, Value):
  99. output_field = SearchQueryField()
  100. def __init__(self, value, output_field=None, *, config=None, invert=False):
  101. self.config = config
  102. self.invert = invert
  103. super().__init__(value, output_field=output_field)
  104. def resolve_expression(self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False):
  105. resolved = super().resolve_expression(query, allow_joins, reuse, summarize, for_save)
  106. if self.config:
  107. if not hasattr(self.config, 'resolve_expression'):
  108. resolved.config = Value(self.config).resolve_expression(query, allow_joins, reuse, summarize, for_save)
  109. else:
  110. resolved.config = self.config.resolve_expression(query, allow_joins, reuse, summarize, for_save)
  111. return resolved
  112. def as_sql(self, compiler, connection):
  113. params = [self.value]
  114. if self.config:
  115. config_sql, config_params = compiler.compile(self.config)
  116. template = 'plainto_tsquery({}::regconfig, %s)'.format(config_sql)
  117. params = config_params + [self.value]
  118. else:
  119. template = 'plainto_tsquery(%s)'
  120. if self.invert:
  121. template = '!!({})'.format(template)
  122. return template, params
  123. def _combine(self, other, connector, reversed):
  124. combined = super()._combine(other, connector, reversed)
  125. combined.output_field = SearchQueryField()
  126. return combined
  127. def __invert__(self):
  128. return type(self)(self.value, config=self.config, invert=not self.invert)
  129. class CombinedSearchQuery(SearchQueryCombinable, CombinedExpression):
  130. def __init__(self, lhs, connector, rhs, config, output_field=None):
  131. self.config = config
  132. super().__init__(lhs, connector, rhs, output_field)
  133. class SearchRank(Func):
  134. function = 'ts_rank'
  135. output_field = FloatField()
  136. def __init__(self, vector, query, **extra):
  137. if not hasattr(vector, 'resolve_expression'):
  138. vector = SearchVector(vector)
  139. if not hasattr(query, 'resolve_expression'):
  140. query = SearchQuery(query)
  141. weights = extra.get('weights')
  142. if weights is not None and not hasattr(weights, 'resolve_expression'):
  143. weights = Value(weights)
  144. self.weights = weights
  145. super().__init__(vector, query, **extra)
  146. def as_sql(self, compiler, connection, function=None, template=None):
  147. extra_params = []
  148. extra_context = {}
  149. if template is None and self.extra.get('weights'):
  150. if self.weights:
  151. template = '%(function)s(%(weights)s, %(expressions)s)'
  152. weight_sql, extra_params = compiler.compile(self.weights)
  153. extra_context['weights'] = weight_sql
  154. sql, params = super().as_sql(
  155. compiler, connection,
  156. function=function, template=template, **extra_context
  157. )
  158. return sql, extra_params + params
  159. SearchVectorField.register_lookup(SearchVectorExact)
  160. class TrigramBase(Func):
  161. output_field = FloatField()
  162. def __init__(self, expression, string, **extra):
  163. if not hasattr(string, 'resolve_expression'):
  164. string = Value(string)
  165. super().__init__(expression, string, **extra)
  166. class TrigramSimilarity(TrigramBase):
  167. function = 'SIMILARITY'
  168. class TrigramDistance(TrigramBase):
  169. function = ''
  170. arg_joiner = ' <-> '