You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

filters.py 12KB


  1. """
  2. Provides generic filtering backends that can be used to filter the results
  3. returned by list views.
  4. """
  5. import operator
  6. from functools import reduce
  7. from django.core.exceptions import ImproperlyConfigured
  8. from django.db import models
  9. from django.db.models.constants import LOOKUP_SEP
  10. from django.db.models.sql.constants import ORDER_PATTERN
  11. from django.template import loader
  12. from django.utils.encoding import force_str
  13. from django.utils.translation import gettext_lazy as _
  14. from rest_framework.compat import coreapi, coreschema, distinct
  15. from rest_framework.settings import api_settings
  16. class BaseFilterBackend:
  17. """
  18. A base class from which all filter backend classes should inherit.
  19. """
  20. def filter_queryset(self, request, queryset, view):
  21. """
  22. Return a filtered queryset.
  23. """
  24. raise NotImplementedError(".filter_queryset() must be overridden.")
  25. def get_schema_fields(self, view):
  26. assert coreapi is not None, 'coreapi must be installed to use `get_schema_fields()`'
  27. assert coreschema is not None, 'coreschema must be installed to use `get_schema_fields()`'
  28. return []
  29. def get_schema_operation_parameters(self, view):
  30. return []
  31. class SearchFilter(BaseFilterBackend):
  32. # The URL query parameter used for the search.
  33. search_param = api_settings.SEARCH_PARAM
  34. template = 'rest_framework/filters/search.html'
  35. lookup_prefixes = {
  36. '^': 'istartswith',
  37. '=': 'iexact',
  38. '@': 'search',
  39. '$': 'iregex',
  40. }
  41. search_title = _('Search')
  42. search_description = _('A search term.')
  43. def get_search_fields(self, view, request):
  44. """
  45. Search fields are obtained from the view, but the request is always
  46. passed to this method. Sub-classes can override this method to
  47. dynamically change the search fields based on request content.
  48. """
  49. return getattr(view, 'search_fields', None)
  50. def get_search_terms(self, request):
  51. """
  52. Search terms are set by a ?search=... query parameter,
  53. and may be comma and/or whitespace delimited.
  54. """
  55. params = request.query_params.get(self.search_param, '')
  56. params = params.replace('\x00', '') # strip null characters
  57. params = params.replace(',', ' ')
  58. return params.split()
  59. def construct_search(self, field_name):
  60. lookup = self.lookup_prefixes.get(field_name[0])
  61. if lookup:
  62. field_name = field_name[1:]
  63. else:
  64. lookup = 'icontains'
  65. return LOOKUP_SEP.join([field_name, lookup])
  66. def must_call_distinct(self, queryset, search_fields):
  67. """
  68. Return True if 'distinct()' should be used to query the given lookups.
  69. """
  70. for search_field in search_fields:
  71. opts = queryset.model._meta
  72. if search_field[0] in self.lookup_prefixes:
  73. search_field = search_field[1:]
  74. # Annotated fields do not need to be distinct
  75. if isinstance(queryset, models.QuerySet) and search_field in queryset.query.annotations:
  76. return False
  77. parts = search_field.split(LOOKUP_SEP)
  78. for part in parts:
  79. field = opts.get_field(part)
  80. if hasattr(field, 'get_path_info'):
  81. # This field is a relation, update opts to follow the relation
  82. path_info = field.get_path_info()
  83. opts = path_info[-1].to_opts
  84. if any(path.m2m for path in path_info):
  85. # This field is a m2m relation so we know we need to call distinct
  86. return True
  87. return False
  88. def filter_queryset(self, request, queryset, view):
  89. search_fields = self.get_search_fields(view, request)
  90. search_terms = self.get_search_terms(request)
  91. if not search_fields or not search_terms:
  92. return queryset
  93. orm_lookups = [
  94. self.construct_search(str(search_field))
  95. for search_field in search_fields
  96. ]
  97. base = queryset
  98. conditions = []
  99. for search_term in search_terms:
  100. queries = [
  101. models.Q(**{orm_lookup: search_term})
  102. for orm_lookup in orm_lookups
  103. ]
  104. conditions.append(reduce(operator.or_, queries))
  105. queryset = queryset.filter(reduce(operator.and_, conditions))
  106. if self.must_call_distinct(queryset, search_fields):
  107. # Filtering against a many-to-many field requires us to
  108. # call queryset.distinct() in order to avoid duplicate items
  109. # in the resulting queryset.
  110. # We try to avoid this if possible, for performance reasons.
  111. queryset = distinct(queryset, base)
  112. return queryset
  113. def to_html(self, request, queryset, view):
  114. if not getattr(view, 'search_fields', None):
  115. return ''
  116. term = self.get_search_terms(request)
  117. term = term[0] if term else ''
  118. context = {
  119. 'param': self.search_param,
  120. 'term': term
  121. }
  122. template = loader.get_template(self.template)
  123. return template.render(context)
  124. def get_schema_fields(self, view):
  125. assert coreapi is not None, 'coreapi must be installed to use `get_schema_fields()`'
  126. assert coreschema is not None, 'coreschema must be installed to use `get_schema_fields()`'
  127. return [
  128. coreapi.Field(
  129. name=self.search_param,
  130. required=False,
  131. location='query',
  132. schema=coreschema.String(
  133. title=force_str(self.search_title),
  134. description=force_str(self.search_description)
  135. )
  136. )
  137. ]
  138. def get_schema_operation_parameters(self, view):
  139. return [
  140. {
  141. 'name': self.search_param,
  142. 'required': False,
  143. 'in': 'query',
  144. 'description': force_str(self.search_description),
  145. 'schema': {
  146. 'type': 'string',
  147. },
  148. },
  149. ]
  150. class OrderingFilter(BaseFilterBackend):
  151. # The URL query parameter used for the ordering.
  152. ordering_param = api_settings.ORDERING_PARAM
  153. ordering_fields = None
  154. ordering_title = _('Ordering')
  155. ordering_description = _('Which field to use when ordering the results.')
  156. template = 'rest_framework/filters/ordering.html'
  157. def get_ordering(self, request, queryset, view):
  158. """
  159. Ordering is set by a comma delimited ?ordering=... query parameter.
  160. The `ordering` query parameter can be overridden by setting
  161. the `ordering_param` value on the OrderingFilter or by
  162. specifying an `ORDERING_PARAM` value in the API settings.
  163. """
  164. params = request.query_params.get(self.ordering_param)
  165. if params:
  166. fields = [param.strip() for param in params.split(',')]
  167. ordering = self.remove_invalid_fields(queryset, fields, view, request)
  168. if ordering:
  169. return ordering
  170. # No ordering was included, or all the ordering fields were invalid
  171. return self.get_default_ordering(view)
  172. def get_default_ordering(self, view):
  173. ordering = getattr(view, 'ordering', None)
  174. if isinstance(ordering, str):
  175. return (ordering,)
  176. return ordering
  177. def get_default_valid_fields(self, queryset, view, context={}):
  178. # If `ordering_fields` is not specified, then we determine a default
  179. # based on the serializer class, if one exists on the view.
  180. if hasattr(view, 'get_serializer_class'):
  181. try:
  182. serializer_class = view.get_serializer_class()
  183. except AssertionError:
  184. # Raised by the default implementation if
  185. # no serializer_class was found
  186. serializer_class = None
  187. else:
  188. serializer_class = getattr(view, 'serializer_class', None)
  189. if serializer_class is None:
  190. msg = (
  191. "Cannot use %s on a view which does not have either a "
  192. "'serializer_class', an overriding 'get_serializer_class' "
  193. "or 'ordering_fields' attribute."
  194. )
  195. raise ImproperlyConfigured(msg % self.__class__.__name__)
  196. return [
  197. (field.source.replace('.', '__') or field_name, field.label)
  198. for field_name, field in serializer_class(context=context).fields.items()
  199. if not getattr(field, 'write_only', False) and not field.source == '*'
  200. ]
  201. def get_valid_fields(self, queryset, view, context={}):
  202. valid_fields = getattr(view, 'ordering_fields', self.ordering_fields)
  203. if valid_fields is None:
  204. # Default to allowing filtering on serializer fields
  205. return self.get_default_valid_fields(queryset, view, context)
  206. elif valid_fields == '__all__':
  207. # View explicitly allows filtering on any model field
  208. valid_fields = [
  209. (field.name, field.verbose_name) for field in queryset.model._meta.fields
  210. ]
  211. valid_fields += [
  212. (key, key.title().split('__'))
  213. for key in queryset.query.annotations
  214. ]
  215. else:
  216. valid_fields = [
  217. (item, item) if isinstance(item, str) else item
  218. for item in valid_fields
  219. ]
  220. return valid_fields
  221. def remove_invalid_fields(self, queryset, fields, view, request):
  222. valid_fields = [item[0] for item in self.get_valid_fields(queryset, view, {'request': request})]
  223. return [term for term in fields if term.lstrip('-') in valid_fields and ORDER_PATTERN.match(term)]
  224. def filter_queryset(self, request, queryset, view):
  225. ordering = self.get_ordering(request, queryset, view)
  226. if ordering:
  227. return queryset.order_by(*ordering)
  228. return queryset
  229. def get_template_context(self, request, queryset, view):
  230. current = self.get_ordering(request, queryset, view)
  231. current = None if not current else current[0]
  232. options = []
  233. context = {
  234. 'request': request,
  235. 'current': current,
  236. 'param': self.ordering_param,
  237. }
  238. for key, label in self.get_valid_fields(queryset, view, context):
  239. options.append((key, '%s - %s' % (label, _('ascending'))))
  240. options.append(('-' + key, '%s - %s' % (label, _('descending'))))
  241. context['options'] = options
  242. return context
  243. def to_html(self, request, queryset, view):
  244. template = loader.get_template(self.template)
  245. context = self.get_template_context(request, queryset, view)
  246. return template.render(context)
  247. def get_schema_fields(self, view):
  248. assert coreapi is not None, 'coreapi must be installed to use `get_schema_fields()`'
  249. assert coreschema is not None, 'coreschema must be installed to use `get_schema_fields()`'
  250. return [
  251. coreapi.Field(
  252. name=self.ordering_param,
  253. required=False,
  254. location='query',
  255. schema=coreschema.String(
  256. title=force_str(self.ordering_title),
  257. description=force_str(self.ordering_description)
  258. )
  259. )
  260. ]
  261. def get_schema_operation_parameters(self, view):
  262. return [
  263. {
  264. 'name': self.ordering_param,
  265. 'required': False,
  266. 'in': 'query',
  267. 'description': force_str(self.ordering_description),
  268. 'schema': {
  269. 'type': 'string',
  270. },
  271. },
  272. ]