You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

functions.py 17KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471
  1. import warnings
  2. from decimal import Decimal
  3. from django.contrib.gis.db.models.fields import BaseSpatialField, GeometryField
  4. from django.contrib.gis.db.models.sql import AreaField, DistanceField
  5. from django.contrib.gis.geos import GEOSGeometry
  6. from django.core.exceptions import FieldError
  7. from django.db.models import (
  8. BooleanField, FloatField, IntegerField, TextField, Transform,
  9. )
  10. from django.db.models.expressions import Func, Value
  11. from django.db.models.functions import Cast
  12. from django.db.utils import NotSupportedError
  13. from django.utils.deprecation import RemovedInDjango30Warning
  14. from django.utils.functional import cached_property
# Python types accepted for numeric function parameters; passed as
# ``check_types`` to ``_handle_param`` by Scale and SnapToGrid.
NUMERIC_TYPES = (int, float, Decimal)
  16. class GeoFuncMixin:
  17. function = None
  18. geom_param_pos = (0,)
  19. def __init__(self, *expressions, **extra):
  20. super().__init__(*expressions, **extra)
  21. # Ensure that value expressions are geometric.
  22. for pos in self.geom_param_pos:
  23. expr = self.source_expressions[pos]
  24. if not isinstance(expr, Value):
  25. continue
  26. try:
  27. output_field = expr.output_field
  28. except FieldError:
  29. output_field = None
  30. geom = expr.value
  31. if not isinstance(geom, GEOSGeometry) or output_field and not isinstance(output_field, GeometryField):
  32. raise TypeError("%s function requires a geometric argument in position %d." % (self.name, pos + 1))
  33. if not geom.srid and not output_field:
  34. raise ValueError("SRID is required for all geometries.")
  35. if not output_field:
  36. self.source_expressions[pos] = Value(geom, output_field=GeometryField(srid=geom.srid))
  37. @property
  38. def name(self):
  39. return self.__class__.__name__
  40. @cached_property
  41. def geo_field(self):
  42. return self.source_expressions[self.geom_param_pos[0]].field
  43. def as_sql(self, compiler, connection, function=None, **extra_context):
  44. if not self.function and not function:
  45. function = connection.ops.spatial_function_name(self.name)
  46. return super().as_sql(compiler, connection, function=function, **extra_context)
  47. def resolve_expression(self, *args, **kwargs):
  48. res = super().resolve_expression(*args, **kwargs)
  49. # Ensure that expressions are geometric.
  50. source_fields = res.get_source_fields()
  51. for pos in self.geom_param_pos:
  52. field = source_fields[pos]
  53. if not isinstance(field, GeometryField):
  54. raise TypeError(
  55. "%s function requires a GeometryField in position %s, got %s." % (
  56. self.name, pos + 1, type(field).__name__,
  57. )
  58. )
  59. base_srid = res.geo_field.srid
  60. for pos in self.geom_param_pos[1:]:
  61. expr = res.source_expressions[pos]
  62. expr_srid = expr.output_field.srid
  63. if expr_srid != base_srid:
  64. # Automatic SRID conversion so objects are comparable.
  65. res.source_expressions[pos] = Transform(expr, base_srid).resolve_expression(*args, **kwargs)
  66. return res
  67. def _handle_param(self, value, param_name='', check_types=None):
  68. if not hasattr(value, 'resolve_expression'):
  69. if check_types and not isinstance(value, check_types):
  70. raise TypeError(
  71. "The %s parameter has the wrong type: should be %s." % (
  72. param_name, check_types)
  73. )
  74. return value
  75. class GeoFunc(GeoFuncMixin, Func):
  76. pass
  77. class GeomOutputGeoFunc(GeoFunc):
  78. @cached_property
  79. def output_field(self):
  80. return GeometryField(srid=self.geo_field.srid)
  81. class SQLiteDecimalToFloatMixin:
  82. """
  83. By default, Decimal values are converted to str by the SQLite backend, which
  84. is not acceptable by the GIS functions expecting numeric values.
  85. """
  86. def as_sqlite(self, compiler, connection, **extra_context):
  87. for expr in self.get_source_expressions():
  88. if hasattr(expr, 'value') and isinstance(expr.value, Decimal):
  89. expr.value = float(expr.value)
  90. return super().as_sql(compiler, connection, **extra_context)
  91. class OracleToleranceMixin:
  92. tolerance = 0.05
  93. def as_oracle(self, compiler, connection, **extra_context):
  94. tol = self.extra.get('tolerance', self.tolerance)
  95. return self.as_sql(
  96. compiler, connection,
  97. template="%%(function)s(%%(expressions)s, %s)" % tol,
  98. **extra_context
  99. )
  100. class Area(OracleToleranceMixin, GeoFunc):
  101. arity = 1
  102. @cached_property
  103. def output_field(self):
  104. return AreaField(self.geo_field)
  105. def as_sql(self, compiler, connection, **extra_context):
  106. if not connection.features.supports_area_geodetic and self.geo_field.geodetic(connection):
  107. raise NotSupportedError('Area on geodetic coordinate systems not supported.')
  108. return super().as_sql(compiler, connection, **extra_context)
  109. def as_sqlite(self, compiler, connection, **extra_context):
  110. if self.geo_field.geodetic(connection):
  111. extra_context['template'] = '%(function)s(%(expressions)s, %(spheroid)d)'
  112. extra_context['spheroid'] = True
  113. return self.as_sql(compiler, connection, **extra_context)
  114. class Azimuth(GeoFunc):
  115. output_field = FloatField()
  116. arity = 2
  117. geom_param_pos = (0, 1)
  118. class AsGeoJSON(GeoFunc):
  119. output_field = TextField()
  120. def __init__(self, expression, bbox=False, crs=False, precision=8, **extra):
  121. expressions = [expression]
  122. if precision is not None:
  123. expressions.append(self._handle_param(precision, 'precision', int))
  124. options = 0
  125. if crs and bbox:
  126. options = 3
  127. elif bbox:
  128. options = 1
  129. elif crs:
  130. options = 2
  131. if options:
  132. expressions.append(options)
  133. super().__init__(*expressions, **extra)
  134. class AsGML(GeoFunc):
  135. geom_param_pos = (1,)
  136. output_field = TextField()
  137. def __init__(self, expression, version=2, precision=8, **extra):
  138. expressions = [version, expression]
  139. if precision is not None:
  140. expressions.append(self._handle_param(precision, 'precision', int))
  141. super().__init__(*expressions, **extra)
  142. def as_oracle(self, compiler, connection, **extra_context):
  143. source_expressions = self.get_source_expressions()
  144. version = source_expressions[0]
  145. clone = self.copy()
  146. clone.set_source_expressions([source_expressions[1]])
  147. extra_context['function'] = 'SDO_UTIL.TO_GML311GEOMETRY' if version.value == 3 else 'SDO_UTIL.TO_GMLGEOMETRY'
  148. return super(AsGML, clone).as_sql(compiler, connection, **extra_context)
  149. class AsKML(AsGML):
  150. def as_sqlite(self, compiler, connection, **extra_context):
  151. # No version parameter
  152. clone = self.copy()
  153. clone.set_source_expressions(self.get_source_expressions()[1:])
  154. return clone.as_sql(compiler, connection, **extra_context)
  155. class AsSVG(GeoFunc):
  156. output_field = TextField()
  157. def __init__(self, expression, relative=False, precision=8, **extra):
  158. relative = relative if hasattr(relative, 'resolve_expression') else int(relative)
  159. expressions = [
  160. expression,
  161. relative,
  162. self._handle_param(precision, 'precision', int),
  163. ]
  164. super().__init__(*expressions, **extra)
  165. class BoundingCircle(OracleToleranceMixin, GeoFunc):
  166. def __init__(self, expression, num_seg=48, **extra):
  167. super().__init__(expression, num_seg, **extra)
  168. def as_oracle(self, compiler, connection, **extra_context):
  169. clone = self.copy()
  170. clone.set_source_expressions([self.get_source_expressions()[0]])
  171. return super(BoundingCircle, clone).as_oracle(compiler, connection, **extra_context)
  172. class Centroid(OracleToleranceMixin, GeomOutputGeoFunc):
  173. arity = 1
  174. class Difference(OracleToleranceMixin, GeomOutputGeoFunc):
  175. arity = 2
  176. geom_param_pos = (0, 1)
  177. class DistanceResultMixin:
  178. @cached_property
  179. def output_field(self):
  180. return DistanceField(self.geo_field)
  181. def source_is_geography(self):
  182. return self.geo_field.geography and self.geo_field.srid == 4326
  183. class Distance(DistanceResultMixin, OracleToleranceMixin, GeoFunc):
  184. geom_param_pos = (0, 1)
  185. spheroid = None
  186. def __init__(self, expr1, expr2, spheroid=None, **extra):
  187. expressions = [expr1, expr2]
  188. if spheroid is not None:
  189. self.spheroid = self._handle_param(spheroid, 'spheroid', bool)
  190. super().__init__(*expressions, **extra)
  191. def as_postgresql(self, compiler, connection, **extra_context):
  192. clone = self.copy()
  193. function = None
  194. expr2 = clone.source_expressions[1]
  195. geography = self.source_is_geography()
  196. if expr2.output_field.geography != geography:
  197. if isinstance(expr2, Value):
  198. expr2.output_field.geography = geography
  199. else:
  200. clone.source_expressions[1] = Cast(
  201. expr2,
  202. GeometryField(srid=expr2.output_field.srid, geography=geography),
  203. )
  204. if not geography and self.geo_field.geodetic(connection):
  205. # Geometry fields with geodetic (lon/lat) coordinates need special distance functions
  206. if self.spheroid:
  207. # DistanceSpheroid is more accurate and resource intensive than DistanceSphere
  208. function = connection.ops.spatial_function_name('DistanceSpheroid')
  209. # Replace boolean param by the real spheroid of the base field
  210. clone.source_expressions.append(Value(self.geo_field.spheroid(connection)))
  211. else:
  212. function = connection.ops.spatial_function_name('DistanceSphere')
  213. return super(Distance, clone).as_sql(compiler, connection, function=function, **extra_context)
  214. def as_sqlite(self, compiler, connection, **extra_context):
  215. if self.geo_field.geodetic(connection):
  216. # SpatiaLite returns NULL instead of zero on geodetic coordinates
  217. extra_context['template'] = 'COALESCE(%(function)s(%(expressions)s, %(spheroid)s), 0)'
  218. extra_context['spheroid'] = int(bool(self.spheroid))
  219. return super().as_sql(compiler, connection, **extra_context)
  220. class Envelope(GeomOutputGeoFunc):
  221. arity = 1
  222. class ForcePolygonCW(GeomOutputGeoFunc):
  223. arity = 1
  224. class ForceRHR(GeomOutputGeoFunc):
  225. arity = 1
  226. def __init__(self, *args, **kwargs):
  227. warnings.warn(
  228. 'ForceRHR is deprecated in favor of ForcePolygonCW.',
  229. RemovedInDjango30Warning, stacklevel=2,
  230. )
  231. super().__init__(*args, **kwargs)
  232. class GeoHash(GeoFunc):
  233. output_field = TextField()
  234. def __init__(self, expression, precision=None, **extra):
  235. expressions = [expression]
  236. if precision is not None:
  237. expressions.append(self._handle_param(precision, 'precision', int))
  238. super().__init__(*expressions, **extra)
  239. def as_mysql(self, compiler, connection, **extra_context):
  240. clone = self.copy()
  241. # If no precision is provided, set it to the maximum.
  242. if len(clone.source_expressions) < 2:
  243. clone.source_expressions.append(Value(100))
  244. return clone.as_sql(compiler, connection, **extra_context)
  245. class Intersection(OracleToleranceMixin, GeomOutputGeoFunc):
  246. arity = 2
  247. geom_param_pos = (0, 1)
  248. @BaseSpatialField.register_lookup
  249. class IsValid(OracleToleranceMixin, GeoFuncMixin, Transform):
  250. lookup_name = 'isvalid'
  251. output_field = BooleanField()
  252. def as_oracle(self, compiler, connection, **extra_context):
  253. sql, params = super().as_oracle(compiler, connection, **extra_context)
  254. return "CASE %s WHEN 'TRUE' THEN 1 ELSE 0 END" % sql, params
  255. class Length(DistanceResultMixin, OracleToleranceMixin, GeoFunc):
  256. def __init__(self, expr1, spheroid=True, **extra):
  257. self.spheroid = spheroid
  258. super().__init__(expr1, **extra)
  259. def as_sql(self, compiler, connection, **extra_context):
  260. if self.geo_field.geodetic(connection) and not connection.features.supports_length_geodetic:
  261. raise NotSupportedError("This backend doesn't support Length on geodetic fields")
  262. return super().as_sql(compiler, connection, **extra_context)
  263. def as_postgresql(self, compiler, connection, **extra_context):
  264. clone = self.copy()
  265. function = None
  266. if self.source_is_geography():
  267. clone.source_expressions.append(Value(self.spheroid))
  268. elif self.geo_field.geodetic(connection):
  269. # Geometry fields with geodetic (lon/lat) coordinates need length_spheroid
  270. function = connection.ops.spatial_function_name('LengthSpheroid')
  271. clone.source_expressions.append(Value(self.geo_field.spheroid(connection)))
  272. else:
  273. dim = min(f.dim for f in self.get_source_fields() if f)
  274. if dim > 2:
  275. function = connection.ops.length3d
  276. return super(Length, clone).as_sql(compiler, connection, function=function, **extra_context)
  277. def as_sqlite(self, compiler, connection, **extra_context):
  278. function = None
  279. if self.geo_field.geodetic(connection):
  280. function = 'GeodesicLength' if self.spheroid else 'GreatCircleLength'
  281. return super().as_sql(compiler, connection, function=function, **extra_context)
  282. class LineLocatePoint(GeoFunc):
  283. output_field = FloatField()
  284. arity = 2
  285. geom_param_pos = (0, 1)
  286. class MakeValid(GeoFunc):
  287. pass
  288. class MemSize(GeoFunc):
  289. output_field = IntegerField()
  290. arity = 1
  291. class NumGeometries(GeoFunc):
  292. output_field = IntegerField()
  293. arity = 1
  294. class NumPoints(GeoFunc):
  295. output_field = IntegerField()
  296. arity = 1
  297. class Perimeter(DistanceResultMixin, OracleToleranceMixin, GeoFunc):
  298. arity = 1
  299. def as_postgresql(self, compiler, connection, **extra_context):
  300. function = None
  301. if self.geo_field.geodetic(connection) and not self.source_is_geography():
  302. raise NotSupportedError("ST_Perimeter cannot use a non-projected non-geography field.")
  303. dim = min(f.dim for f in self.get_source_fields())
  304. if dim > 2:
  305. function = connection.ops.perimeter3d
  306. return super().as_sql(compiler, connection, function=function, **extra_context)
  307. def as_sqlite(self, compiler, connection, **extra_context):
  308. if self.geo_field.geodetic(connection):
  309. raise NotSupportedError("Perimeter cannot use a non-projected field.")
  310. return super().as_sql(compiler, connection, **extra_context)
  311. class PointOnSurface(OracleToleranceMixin, GeomOutputGeoFunc):
  312. arity = 1
  313. class Reverse(GeoFunc):
  314. arity = 1
  315. class Scale(SQLiteDecimalToFloatMixin, GeomOutputGeoFunc):
  316. def __init__(self, expression, x, y, z=0.0, **extra):
  317. expressions = [
  318. expression,
  319. self._handle_param(x, 'x', NUMERIC_TYPES),
  320. self._handle_param(y, 'y', NUMERIC_TYPES),
  321. ]
  322. if z != 0.0:
  323. expressions.append(self._handle_param(z, 'z', NUMERIC_TYPES))
  324. super().__init__(*expressions, **extra)
  325. class SnapToGrid(SQLiteDecimalToFloatMixin, GeomOutputGeoFunc):
  326. def __init__(self, expression, *args, **extra):
  327. nargs = len(args)
  328. expressions = [expression]
  329. if nargs in (1, 2):
  330. expressions.extend(
  331. [self._handle_param(arg, '', NUMERIC_TYPES) for arg in args]
  332. )
  333. elif nargs == 4:
  334. # Reverse origin and size param ordering
  335. expressions += [
  336. *(self._handle_param(arg, '', NUMERIC_TYPES) for arg in args[2:]),
  337. *(self._handle_param(arg, '', NUMERIC_TYPES) for arg in args[0:2]),
  338. ]
  339. else:
  340. raise ValueError('Must provide 1, 2, or 4 arguments to `SnapToGrid`.')
  341. super().__init__(*expressions, **extra)
  342. class SymDifference(OracleToleranceMixin, GeomOutputGeoFunc):
  343. arity = 2
  344. geom_param_pos = (0, 1)
  345. class Transform(GeomOutputGeoFunc):
  346. def __init__(self, expression, srid, **extra):
  347. expressions = [
  348. expression,
  349. self._handle_param(srid, 'srid', int),
  350. ]
  351. if 'output_field' not in extra:
  352. extra['output_field'] = GeometryField(srid=srid)
  353. super().__init__(*expressions, **extra)
  354. class Translate(Scale):
  355. def as_sqlite(self, compiler, connection, **extra_context):
  356. clone = self.copy()
  357. if len(self.source_expressions) < 4:
  358. # Always provide the z parameter for ST_Translate
  359. clone.source_expressions.append(Value(0))
  360. return super(Translate, clone).as_sqlite(compiler, connection, **extra_context)
  361. class Union(OracleToleranceMixin, GeomOutputGeoFunc):
  362. arity = 2
  363. geom_param_pos = (0, 1)