You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

lookups.py 11KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346
  1. import re
  2. from django.contrib.gis.db.models.fields import BaseSpatialField
  3. from django.db.models.expressions import Expression
  4. from django.db.models.lookups import Lookup, Transform
  5. from django.db.models.sql.query import Query
  class RasterBandTransform(Transform):
      """
      Pass-through transform that marks a raster band selection.

      It compiles to exactly the SQL of its lhs; GISLookup reads this
      transform's ``band_index`` attribute to compute ``band_lhs``.
      """

      def as_sql(self, compiler, connection):
          # Nothing is wrapped here; the band index is consumed by the lookup.
          compiled = compiler.compile(self.lhs)
          return compiled
  class GISLookup(Lookup):
      """
      Base class for the spatial lookups registered on BaseSpatialField.

      The right-hand side may be a single value or a list/tuple. For a
      list/tuple, the first element is the comparison value and the remaining
      elements are kept in ``self.rhs_params`` (raster band index, distance,
      'spheroid' directive or relate pattern, depending on the subclass).
      Final SQL generation is delegated to the backend operator returned by
      get_rhs_op().
      """

      sql_template = None
      transform_func = None
      distance = False
      # 1-based raster band indices; they stay None unless
      # process_band_indices() runs.
      band_rhs = None
      band_lhs = None

      def __init__(self, lhs, rhs):
          # Split a list/tuple rhs into the comparison value and extra params.
          rhs, *self.rhs_params = rhs if isinstance(rhs, (list, tuple)) else [rhs]
          super().__init__(lhs, rhs)
          # Extra template values merged over the defaults in as_sql().
          self.template_params = {}
          self.process_rhs_params()

      def process_rhs_params(self):
          """
          Validate the extra rhs params and extract band indices if present.

          Raise ValueError when the rhs tuple is too long for this lookup.
          """
          if self.rhs_params:
              # Check if a band index was passed in the query argument.
              # 'relate' also carries a pattern, hence one more extra param.
              if len(self.rhs_params) == (2 if self.lookup_name == 'relate' else 1):
                  self.process_band_indices()
              elif len(self.rhs_params) > 1:
                  raise ValueError('Tuple too long for lookup %s.' % self.lookup_name)
          elif isinstance(self.lhs, RasterBandTransform):
              self.process_band_indices(only_lhs=True)

      def process_band_indices(self, only_lhs=False):
          """
          Extract the lhs band index from the band transform class and the rhs
          band index from the input tuple.

          With only_lhs=True the rhs band defaults to 1 and rhs_params is left
          untouched; otherwise the first extra rhs param is consumed as the rhs
          band index.
          """
          # PostGIS band indices are 1-based, so the band index needs to be
          # increased to be consistent with the GDALRaster band indices.
          if only_lhs:
              self.band_rhs = 1
              self.band_lhs = self.lhs.band_index + 1
              return
          if isinstance(self.lhs, RasterBandTransform):
              self.band_lhs = self.lhs.band_index + 1
          else:
              self.band_lhs = 1
          self.band_rhs, *self.rhs_params = self.rhs_params

      def get_db_prep_lookup(self, value, connection):
          # get_db_prep_lookup is called by process_rhs from super class;
          # wrap the value in the backend's geometry Adapter.
          return ('%s', [connection.ops.Adapter(value)])

      def process_rhs(self, compiler, connection):
          """
          Compile the rhs and wrap its SQL in the backend geometry placeholder.

          A Query rhs is passed to the parent implementation untouched.
          """
          if isinstance(self.rhs, Query):
              # If rhs is some Query, don't touch it.
              return super().process_rhs(compiler, connection)
          if isinstance(self.rhs, Expression):
              self.rhs = self.rhs.resolve_expression(compiler.query)
          rhs, rhs_params = super().process_rhs(compiler, connection)
          placeholder = connection.ops.get_geom_placeholder(self.lhs.output_field, self.rhs, compiler)
          return placeholder % rhs, rhs_params

      def get_rhs_op(self, connection, rhs):
          # Unlike BuiltinLookup, the GIS get_rhs_op() implementation should return
          # an object (SpatialOperator) with an as_sql() method to allow for more
          # complex computations (where the lhs part can be mixed in).
          return connection.ops.gis_operators[self.lookup_name]

      def as_sql(self, compiler, connection):
          """Return ``(sql, params)`` produced by the backend spatial operator."""
          lhs_sql, lhs_params = self.process_lhs(compiler, connection)
          rhs_sql, rhs_params = self.process_rhs(compiler, connection)
          # Build a fresh list instead of extending lhs_params in place: the
          # lhs params may be a tuple, or a list owned by the compiled lhs
          # expression that must not grow on every compilation.
          sql_params = [*lhs_params, *rhs_params]
          template_params = {'lhs': lhs_sql, 'rhs': rhs_sql, 'value': '%s', **self.template_params}
          rhs_op = self.get_rhs_op(connection, rhs_sql)
          return rhs_op.as_sql(connection, self, template_params, sql_params)
  69. # ------------------
  70. # Geometry operators
  71. # ------------------
  @BaseSpatialField.register_lookup
  class OverlapsLeftLookup(GISLookup):
      """
      Bounding-box lookup ``overlaps_left``: true when A's bounding box
      overlaps B's bounding box or lies to its left.
      """

      lookup_name = 'overlaps_left'
  @BaseSpatialField.register_lookup
  class OverlapsRightLookup(GISLookup):
      """
      Bounding-box lookup ``overlaps_right``: true when A's bounding box
      overlaps B's bounding box or lies to its right.
      """

      lookup_name = 'overlaps_right'
  @BaseSpatialField.register_lookup
  class OverlapsBelowLookup(GISLookup):
      """
      Bounding-box lookup ``overlaps_below``: true when A's bounding box
      overlaps B's bounding box or lies below it.
      """

      lookup_name = 'overlaps_below'
  @BaseSpatialField.register_lookup
  class OverlapsAboveLookup(GISLookup):
      """
      Bounding-box lookup ``overlaps_above``: true when A's bounding box
      overlaps B's bounding box or lies above it.
      """

      lookup_name = 'overlaps_above'
  @BaseSpatialField.register_lookup
  class LeftLookup(GISLookup):
      """
      Bounding-box lookup ``left``: true when A's bounding box lies strictly
      to the left of B's bounding box.
      """

      lookup_name = 'left'
  @BaseSpatialField.register_lookup
  class RightLookup(GISLookup):
      """
      Bounding-box lookup ``right``: true when A's bounding box lies strictly
      to the right of B's bounding box.
      """

      lookup_name = 'right'
  @BaseSpatialField.register_lookup
  class StrictlyBelowLookup(GISLookup):
      """
      Bounding-box lookup ``strictly_below``: true when A's bounding box lies
      strictly below B's bounding box.
      """

      lookup_name = 'strictly_below'
  @BaseSpatialField.register_lookup
  class StrictlyAboveLookup(GISLookup):
      """
      Bounding-box lookup ``strictly_above``: true when A's bounding box lies
      strictly above B's bounding box.
      """

      lookup_name = 'strictly_above'
  @BaseSpatialField.register_lookup
  class SameAsLookup(GISLookup):
      """
      ``same_as`` lookup, the "~=" ("same as") operator.

      It tests actual geometric equality of two features: when A and B are the
      same feature, vertex by vertex, the operator returns true.
      """

      lookup_name = 'same_as'


  # Register the same lookup class under the name 'exact' as well.
  BaseSpatialField.register_lookup(SameAsLookup, 'exact')
  @BaseSpatialField.register_lookup
  class BBContainsLookup(GISLookup):
      """
      The 'bbcontains' operator returns true if A's bounding box completely
      contains B's bounding box.
      """

      lookup_name = 'bbcontains'
  @BaseSpatialField.register_lookup
  class BBOverlapsLookup(GISLookup):
      """
      Bounding-box lookup ``bboverlaps``: true when A's bounding box overlaps
      B's bounding box.
      """

      lookup_name = 'bboverlaps'
  @BaseSpatialField.register_lookup
  class ContainedLookup(GISLookup):
      """
      Bounding-box lookup ``contained``: true when A's bounding box is
      completely contained by B's bounding box.
      """

      lookup_name = 'contained'
  157. # ------------------
  158. # Geometry functions
  159. # ------------------
  @BaseSpatialField.register_lookup
  class ContainsLookup(GISLookup):
      """Spatial function lookup ``contains``; SQL comes from the backend operator of that name."""

      lookup_name = 'contains'
  @BaseSpatialField.register_lookup
  class ContainsProperlyLookup(GISLookup):
      """Spatial function lookup ``contains_properly``; SQL comes from the backend operator of that name."""

      lookup_name = 'contains_properly'
  @BaseSpatialField.register_lookup
  class CoveredByLookup(GISLookup):
      """Spatial function lookup ``coveredby``; SQL comes from the backend operator of that name."""

      lookup_name = 'coveredby'
  @BaseSpatialField.register_lookup
  class CoversLookup(GISLookup):
      """Spatial function lookup ``covers``; SQL comes from the backend operator of that name."""

      lookup_name = 'covers'
  @BaseSpatialField.register_lookup
  class CrossesLookup(GISLookup):
      """Spatial function lookup ``crosses``; SQL comes from the backend operator of that name."""

      lookup_name = 'crosses'
  @BaseSpatialField.register_lookup
  class DisjointLookup(GISLookup):
      """Spatial function lookup ``disjoint``; SQL comes from the backend operator of that name."""

      lookup_name = 'disjoint'
  @BaseSpatialField.register_lookup
  class EqualsLookup(GISLookup):
      """Spatial function lookup ``equals``; SQL comes from the backend operator of that name."""

      lookup_name = 'equals'
  @BaseSpatialField.register_lookup
  class IntersectsLookup(GISLookup):
      """Spatial function lookup ``intersects``; SQL comes from the backend operator of that name."""

      lookup_name = 'intersects'
  @BaseSpatialField.register_lookup
  class OverlapsLookup(GISLookup):
      """Spatial function lookup ``overlaps``; SQL comes from the backend operator of that name."""

      lookup_name = 'overlaps'
  @BaseSpatialField.register_lookup
  class RelateLookup(GISLookup):
      """
      Spatial function lookup ``relate``.

      The relate pattern is the last extra element of the rhs tuple; it is
      ``self.rhs_params[0]`` once any band index has been consumed. The
      pattern is passed as a query parameter bound to the trailing ``%s`` in
      sql_template.
      """

      lookup_name = 'relate'
      sql_template = '%(func)s(%(lhs)s, %(rhs)s, %%s)'
      pattern_regex = re.compile(r'^[012TF\*]{9}$')

      def process_rhs(self, compiler, connection):
          """
          Validate the relate pattern and append it to the rhs params.

          Raises ValueError for an invalid intersection matrix pattern, unless
          the backend operator provides its own check_relate_argument().
          """
          # Check the pattern argument
          pattern = self.rhs_params[0]
          backend_op = connection.ops.gis_operators[self.lookup_name]
          if hasattr(backend_op, 'check_relate_argument'):
              # The backend validates its own relate argument syntax.
              backend_op.check_relate_argument(pattern)
          elif not isinstance(pattern, str) or not self.pattern_regex.fullmatch(pattern):
              # fullmatch() instead of match(): '$' also matches just before a
              # trailing newline, which would let a 9-character pattern
              # followed by a newline pass validation.
              raise ValueError('Invalid intersection matrix pattern "%s".' % pattern)
          sql, params = super().process_rhs(compiler, connection)
          # New list so that tuple params from the parent are handled too.
          return sql, [*params, pattern]
  @BaseSpatialField.register_lookup
  class TouchesLookup(GISLookup):
      """Spatial function lookup ``touches``; SQL comes from the backend operator of that name."""

      lookup_name = 'touches'
  @BaseSpatialField.register_lookup
  class WithinLookup(GISLookup):
      """Spatial function lookup ``within``; SQL comes from the backend operator of that name."""

      lookup_name = 'within'
  class DistanceLookupBase(GISLookup):
      """
      Shared behavior for distance lookups.

      Accepted rhs shapes: ``(geom, distance)``, ``(geom, distance, 'spheroid')``,
      ``(geom, band_index, distance)`` and
      ``(geom, band_index, distance, 'spheroid')``.
      """

      distance = True
      sql_template = '%(func)s(%(lhs)s, %(rhs)s) %(op)s %(value)s'

      def process_rhs_params(self):
          """Validate the rhs tuple length and extract a band index if given."""
          extra = self.rhs_params
          count = len(extra)
          if count < 1 or count > 3:
              raise ValueError("2, 3, or 4-element tuple required for '%s' lookup." % self.lookup_name)
          if count == 3 and extra[2] != 'spheroid':
              raise ValueError("For 4-element tuples the last argument must be the 'spheroid' directive.")
          # Anything other than 'spheroid' in second position is a band index.
          has_band_index = count > 1 and extra[1] != 'spheroid'
          if has_band_index:
              self.process_band_indices()

      def process_distance(self, compiler, connection):
          """
          Return ``(sql, params)`` for the distance value.

          Expressions are resolved and compiled; plain values become a ``%s``
          placeholder whose params come from ``connection.ops.get_distance()``.
          """
          dist_param = self.rhs_params[0]
          if not hasattr(dist_param, 'resolve_expression'):
              return '%s', connection.ops.get_distance(self.lhs.output_field, self.rhs_params, self.lookup_name)
          resolved = dist_param.resolve_expression(compiler.query)
          return compiler.compile(resolved)
  @BaseSpatialField.register_lookup
  class DWithinLookup(DistanceLookupBase):
      """
      Distance lookup ``dwithin``, rendered as ``func(lhs, rhs, <distance>)``.

      The distance SQL comes from process_distance(). It is a ``%s``
      placeholder for plain values and the compiled SQL for expressions.
      """

      lookup_name = 'dwithin'
      # Render %(value)s, not a hard-coded %%s: process_rhs() stores the
      # distance SQL in template_params['value'], and with a fixed %s an
      # expression distance produced a placeholder with no matching param.
      # For plain values the rendered SQL is unchanged ('%s').
      sql_template = '%(func)s(%(lhs)s, %(rhs)s, %(value)s)'

      def process_rhs(self, compiler, connection):
          """Compile the rhs geometry and append the distance params."""
          dist_sql, dist_params = self.process_distance(compiler, connection)
          self.template_params['value'] = dist_sql
          rhs_sql, params = super().process_rhs(compiler, connection)
          # New list so list/tuple param combinations both work.
          return rhs_sql, [*params, *dist_params]
  class DistanceLookupFromFunction(DistanceLookupBase):
      """
      Base for distance comparisons built from a distance expression.

      Subclasses set ``op``; the generated SQL is
      ``<distance expression> <op> <distance value>``.
      """

      def as_sql(self, compiler, connection):
          # A trailing 'spheroid' among two extra params requests spheroid
          # calculation; otherwise None is passed through.
          spheroid = (len(self.rhs_params) == 2 and self.rhs_params[-1] == 'spheroid') or None
          distance_expr = connection.ops.distance_expr_for_lookup(self.lhs, self.rhs, spheroid=spheroid)
          sql, params = compiler.compile(distance_expr.resolve_expression(compiler.query))
          dist_sql, dist_params = self.process_distance(compiler, connection)
          return (
              '%(func)s %(op)s %(dist)s' % {'func': sql, 'op': self.op, 'dist': dist_sql},
              # Build a new list: the two param sequences may be list or tuple,
              # and `list + tuple` raises TypeError.
              [*params, *dist_params],
          )
  @BaseSpatialField.register_lookup
  class DistanceGTLookup(DistanceLookupFromFunction):
      """Distance lookup generating ``<distance expression> > <distance>``."""

      lookup_name = 'distance_gt'
      op = '>'
  @BaseSpatialField.register_lookup
  class DistanceGTELookup(DistanceLookupFromFunction):
      """Distance lookup generating ``<distance expression> >= <distance>``."""

      lookup_name = 'distance_gte'
      op = '>='
  @BaseSpatialField.register_lookup
  class DistanceLTLookup(DistanceLookupFromFunction):
      """Distance lookup generating ``<distance expression> < <distance>``."""

      lookup_name = 'distance_lt'
      op = '<'
  @BaseSpatialField.register_lookup
  class DistanceLTELookup(DistanceLookupFromFunction):
      """Distance lookup generating ``<distance expression> <= <distance>``."""

      lookup_name = 'distance_lte'
      op = '<='