Development of an internal social media platform with personalised dashboards for students
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

related_lookups.py 6.9KB

(line-number gutter: lines 1–154)
  1. from django.db.models.lookups import (
  2. Exact, GreaterThan, GreaterThanOrEqual, In, IsNull, LessThan,
  3. LessThanOrEqual,
  4. )
  5. class MultiColSource:
  6. contains_aggregate = False
  7. def __init__(self, alias, targets, sources, field):
  8. self.targets, self.sources, self.field, self.alias = targets, sources, field, alias
  9. self.output_field = self.field
  10. def __repr__(self):
  11. return "{}({}, {})".format(
  12. self.__class__.__name__, self.alias, self.field)
  13. def relabeled_clone(self, relabels):
  14. return self.__class__(relabels.get(self.alias, self.alias),
  15. self.targets, self.sources, self.field)
  16. def get_lookup(self, lookup):
  17. return self.output_field.get_lookup(lookup)
  18. def get_normalized_value(value, lhs):
  19. from django.db.models import Model
  20. if isinstance(value, Model):
  21. value_list = []
  22. sources = lhs.output_field.get_path_info()[-1].target_fields
  23. for source in sources:
  24. while not isinstance(value, source.model) and source.remote_field:
  25. source = source.remote_field.model._meta.get_field(source.remote_field.field_name)
  26. try:
  27. value_list.append(getattr(value, source.attname))
  28. except AttributeError:
  29. # A case like Restaurant.objects.filter(place=restaurant_instance),
  30. # where place is a OneToOneField and the primary key of Restaurant.
  31. return (value.pk,)
  32. return tuple(value_list)
  33. if not isinstance(value, tuple):
  34. return (value,)
  35. return value
  36. class RelatedIn(In):
  37. def get_prep_lookup(self):
  38. if not isinstance(self.lhs, MultiColSource) and self.rhs_is_direct_value():
  39. # If we get here, we are dealing with single-column relations.
  40. self.rhs = [get_normalized_value(val, self.lhs)[0] for val in self.rhs]
  41. # We need to run the related field's get_prep_value(). Consider case
  42. # ForeignKey to IntegerField given value 'abc'. The ForeignKey itself
  43. # doesn't have validation for non-integers, so we must run validation
  44. # using the target field.
  45. if hasattr(self.lhs.output_field, 'get_path_info'):
  46. # Run the target field's get_prep_value. We can safely assume there is
  47. # only one as we don't get to the direct value branch otherwise.
  48. target_field = self.lhs.output_field.get_path_info()[-1].target_fields[-1]
  49. self.rhs = [target_field.get_prep_value(v) for v in self.rhs]
  50. return super().get_prep_lookup()
  51. def as_sql(self, compiler, connection):
  52. if isinstance(self.lhs, MultiColSource):
  53. # For multicolumn lookups we need to build a multicolumn where clause.
  54. # This clause is either a SubqueryConstraint (for values that need to be compiled to
  55. # SQL) or an OR-combined list of (col1 = val1 AND col2 = val2 AND ...) clauses.
  56. from django.db.models.sql.where import WhereNode, SubqueryConstraint, AND, OR
  57. root_constraint = WhereNode(connector=OR)
  58. if self.rhs_is_direct_value():
  59. values = [get_normalized_value(value, self.lhs) for value in self.rhs]
  60. for value in values:
  61. value_constraint = WhereNode()
  62. for source, target, val in zip(self.lhs.sources, self.lhs.targets, value):
  63. lookup_class = target.get_lookup('exact')
  64. lookup = lookup_class(target.get_col(self.lhs.alias, source), val)
  65. value_constraint.add(lookup, AND)
  66. root_constraint.add(value_constraint, OR)
  67. else:
  68. root_constraint.add(
  69. SubqueryConstraint(
  70. self.lhs.alias, [target.column for target in self.lhs.targets],
  71. [source.name for source in self.lhs.sources], self.rhs),
  72. AND)
  73. return root_constraint.as_sql(compiler, connection)
  74. else:
  75. if (not getattr(self.rhs, 'has_select_fields', True) and
  76. not getattr(self.lhs.field.target_field, 'primary_key', False)):
  77. self.rhs.clear_select_clause()
  78. if (getattr(self.lhs.output_field, 'primary_key', False) and
  79. self.lhs.output_field.model == self.rhs.model):
  80. # A case like Restaurant.objects.filter(place__in=restaurant_qs),
  81. # where place is a OneToOneField and the primary key of
  82. # Restaurant.
  83. target_field = self.lhs.field.name
  84. else:
  85. target_field = self.lhs.field.target_field.name
  86. self.rhs.add_fields([target_field], True)
  87. return super().as_sql(compiler, connection)
  88. class RelatedLookupMixin:
  89. def get_prep_lookup(self):
  90. if not isinstance(self.lhs, MultiColSource) and self.rhs_is_direct_value():
  91. # If we get here, we are dealing with single-column relations.
  92. self.rhs = get_normalized_value(self.rhs, self.lhs)[0]
  93. # We need to run the related field's get_prep_value(). Consider case
  94. # ForeignKey to IntegerField given value 'abc'. The ForeignKey itself
  95. # doesn't have validation for non-integers, so we must run validation
  96. # using the target field.
  97. if self.prepare_rhs and hasattr(self.lhs.output_field, 'get_path_info'):
  98. # Get the target field. We can safely assume there is only one
  99. # as we don't get to the direct value branch otherwise.
  100. target_field = self.lhs.output_field.get_path_info()[-1].target_fields[-1]
  101. self.rhs = target_field.get_prep_value(self.rhs)
  102. return super().get_prep_lookup()
  103. def as_sql(self, compiler, connection):
  104. if isinstance(self.lhs, MultiColSource):
  105. assert self.rhs_is_direct_value()
  106. self.rhs = get_normalized_value(self.rhs, self.lhs)
  107. from django.db.models.sql.where import WhereNode, AND
  108. root_constraint = WhereNode()
  109. for target, source, val in zip(self.lhs.targets, self.lhs.sources, self.rhs):
  110. lookup_class = target.get_lookup(self.lookup_name)
  111. root_constraint.add(
  112. lookup_class(target.get_col(self.lhs.alias, source), val), AND)
  113. return root_constraint.as_sql(compiler, connection)
  114. return super().as_sql(compiler, connection)
  115. class RelatedExact(RelatedLookupMixin, Exact):
  116. pass
  117. class RelatedLessThan(RelatedLookupMixin, LessThan):
  118. pass
  119. class RelatedGreaterThan(RelatedLookupMixin, GreaterThan):
  120. pass
  121. class RelatedGreaterThanOrEqual(RelatedLookupMixin, GreaterThanOrEqual):
  122. pass
  123. class RelatedLessThanOrEqual(RelatedLookupMixin, LessThanOrEqual):
  124. pass
  125. class RelatedIsNull(RelatedLookupMixin, IsNull):
  126. pass