You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long. 8.8KB

  1. from django.db.models.expressions import Func, Value
  2. from django.db.models.fields import IntegerField
  3. from django.db.models.functions import Coalesce
  4. from django.db.models.lookups import Transform
  class BytesToCharFieldConversionMixin:
      """
      Decode ``bytes`` results into ``str`` for CharField-typed expressions.

      MySQL reports a long (bytes) data type instead of a character type when
      it cannot determine the length of a string result, e.g.
          LPAD(column1, CHAR_LENGTH(column2), ' ')
      comes back as LONGTEXT (bytes) rather than VARCHAR. On backends whose
      ``features.db_functions_convert_bytes_to_str`` flag is set, such values
      are decoded here; every other value is handed to the next
      ``convert_value`` in the MRO untouched.
      """

      def convert_value(self, value, expression, connection):
          should_decode = (
              connection.features.db_functions_convert_bytes_to_str
              and self.output_field.get_internal_type() == 'CharField'
              and isinstance(value, bytes)
          )
          if should_decode:
              return value.decode()
          return super().convert_value(value, expression, connection)
  class Chr(Transform):
      """
      Wrap the expression in SQL ``CHR()`` (numeric code -> character), with
      vendor-specific spellings for MySQL, Oracle and SQLite.
      """
      lookup_name = 'chr'
      function = 'CHR'

      def as_mysql(self, compiler, connection, **extra_context):
          # MySQL rendering: CHAR(<expr> USING utf16).
          return super().as_sql(
              compiler,
              connection,
              function='CHAR',
              template='%(function)s(%(expressions)s USING utf16)',
              **extra_context
          )

      def as_oracle(self, compiler, connection, **extra_context):
          # Oracle rendering keeps the CHR name: CHR(<expr> USING NCHAR_CS).
          return super().as_sql(
              compiler,
              connection,
              template='%(function)s(%(expressions)s USING NCHAR_CS)',
              **extra_context
          )

      def as_sqlite(self, compiler, connection, **extra_context):
          # SQLite rendering: plain CHAR(<expr>).
          return super().as_sql(
              compiler, connection, function='CHAR', **extra_context
          )
  class ConcatPair(Func):
      """
      Join exactly two arguments. ``Concat`` is built on top of this because
      not every database backend accepts more than two concatenation operands.
      """
      function = 'CONCAT'

      def as_sqlite(self, compiler, connection, **extra_context):
          # SQLite joins with the || operator. Each operand is first wrapped
          # in COALESCE(..., '') so that one NULL side doesn't null the result.
          null_safe = self.coalesce()
          return super(ConcatPair, null_safe).as_sql(
              compiler,
              connection,
              template='%(expressions)s',
              arg_joiner=' || ',
              **extra_context
          )

      def as_mysql(self, compiler, connection, **extra_context):
          # CONCAT_WS with an empty separator is used because it ignores NULLs.
          return super().as_sql(
              compiler,
              connection,
              function='CONCAT_WS',
              template="%(function)s('', %(expressions)s)",
              **extra_context
          )

      def coalesce(self):
          """Return a copy whose operands are each wrapped in COALESCE(op, '')."""
          clone = self.copy()
          wrapped = [
              Coalesce(operand, Value(''))
              for operand in clone.get_source_expressions()
          ]
          clone.set_source_expressions(wrapped)
          return clone
  class Concat(Func):
      """
      Concatenate two or more text expressions.

      The arguments are folded into nested ``ConcatPair`` nodes. Backends on
      which a single NULL argument would make the whole result NULL wrap each
      argument in COALESCE so the output stays non-null.
      """
      function = None
      template = "%(expressions)s"

      def __init__(self, *expressions, **extra):
          if len(expressions) < 2:
              raise ValueError('Concat must take at least two expressions')
          super().__init__(self._paired(expressions), **extra)

      def _paired(self, expressions):
          # Fold right-to-left into nested pairs, innermost pair built first:
          #   [a, b, c, d] -> ConcatPair(a, ConcatPair(b, ConcatPair(c, d)))
          nested = ConcatPair(expressions[-2], expressions[-1])
          for expression in reversed(expressions[:-2]):
              nested = ConcatPair(expression, nested)
          return nested
  class Left(Func):
      """Return the leading ``length`` characters of a string expression."""
      function = 'LEFT'
      arity = 2

      def __init__(self, expression, length, **extra):
          """
          expression: the name of a field, or an expression returning a string
          length: the number of characters to return from the start of the string
          """
          if not hasattr(length, 'resolve_expression') and length < 1:
              raise ValueError("'length' must be greater than 0.")
          super().__init__(expression, length, **extra)

      def get_substr(self):
          # Equivalent SUBSTR(expression, 1, length); Oracle and SQLite render
          # this class through that form.
          source = self.source_expressions[0]
          length = self.source_expressions[1]
          return Substr(source, Value(1), length)

      def as_oracle(self, compiler, connection, **extra_context):
          substring = self.get_substr()
          return substring.as_oracle(compiler, connection, **extra_context)

      def as_sqlite(self, compiler, connection, **extra_context):
          substring = self.get_substr()
          return substring.as_sqlite(compiler, connection, **extra_context)
  class Length(Transform):
      """Return the number of characters in the expression."""
      lookup_name = 'length'
      function = 'LENGTH'
      output_field = IntegerField()

      def as_mysql(self, compiler, connection, **extra_context):
          # On MySQL the character count is taken with CHAR_LENGTH().
          return super().as_sql(
              compiler, connection, function='CHAR_LENGTH', **extra_context
          )
  class Lower(Transform):
      """Apply SQL ``LOWER()``; also available as the ``lower`` lookup."""
      lookup_name = 'lower'
      function = 'LOWER'
  class LPad(BytesToCharFieldConversionMixin, Func):
      """
      Pad ``expression`` on the left with ``fill_text`` (a single space by
      default) to ``length`` characters, via SQL ``LPAD``.
      """
      function = 'LPAD'

      def __init__(self, expression, length, fill_text=Value(' '), **extra):
          # Only plain (non-expression, non-None) lengths can be range-checked
          # up front; expressions are evaluated by the database.
          is_negative_literal = (
              not hasattr(length, 'resolve_expression')
              and length is not None
              and length < 0
          )
          if is_negative_literal:
              raise ValueError("'length' must be greater or equal to 0.")
          super().__init__(expression, length, fill_text, **extra)
  class LTrim(Transform):
      """Apply SQL ``LTRIM()``; also available as the ``ltrim`` lookup."""
      lookup_name = 'ltrim'
      function = 'LTRIM'
  class Ord(Transform):
      """
      Integer code of a character: ``ASCII()`` by default, ``ORD()`` on MySQL
      and ``UNICODE()`` on SQLite.
      """
      lookup_name = 'ord'
      function = 'ASCII'
      output_field = IntegerField()

      def as_mysql(self, compiler, connection, **extra_context):
          return super().as_sql(
              compiler, connection, function='ORD', **extra_context
          )

      def as_sqlite(self, compiler, connection, **extra_context):
          return super().as_sql(
              compiler, connection, function='UNICODE', **extra_context
          )
  class Repeat(BytesToCharFieldConversionMixin, Func):
      """Repeat ``expression`` ``number`` times via SQL ``REPEAT``."""
      function = 'REPEAT'

      def __init__(self, expression, number, **extra):
          # Only plain (non-expression, non-None) counts can be range-checked
          # up front; expressions are evaluated by the database.
          is_negative_literal = (
              not hasattr(number, 'resolve_expression')
              and number is not None
              and number < 0
          )
          if is_negative_literal:
              raise ValueError("'number' must be greater or equal to 0.")
          super().__init__(expression, number, **extra)

      def as_oracle(self, compiler, connection, **extra_context):
          # Oracle rendering: pad the expression with copies of itself, i.e.
          # RPAD(expression, LENGTH(expression) * number, expression).
          expression, number = self.source_expressions
          # NOTE(review): source expressions are presumably already wrapped
          # (e.g. a None argument becomes Value(None)), in which case this
          # `is None` branch never fires -- confirm against Func.__init__.
          if number is None:
              padded_length = None
          else:
              padded_length = Length(expression) * number
          padded = RPad(expression, padded_length, expression)
          return padded.as_sql(compiler, connection, **extra_context)
  class Replace(Func):
      """
      SQL ``REPLACE(expression, text, replacement)``; ``replacement`` defaults
      to an empty string.
      """
      function = 'REPLACE'

      def __init__(self, expression, text, replacement=Value(''), **extra):
          arguments = (expression, text, replacement)
          super().__init__(*arguments, **extra)
  class Reverse(Transform):
      """Apply SQL ``REVERSE()``; also available as the ``reverse`` lookup."""
      function = 'REVERSE'
      lookup_name = 'reverse'

      def as_oracle(self, compiler, connection, **extra_context):
          # REVERSE in Oracle is undocumented and doesn't support multi-byte
          # strings. Use a special subquery instead:
          #  - the inner query yields one row per character of the string
          #    (n = LEVEL = position, s = the character at that position);
          #  - the outer LISTAGG concatenates those characters ordered by
          #    descending position, producing the reversed string.
          # The outer SELECT/LISTAGG line is required: without it the
          # parentheses don't balance and nothing reassembles the characters.
          return super().as_sql(
              compiler, connection,
              template=(
                  '(SELECT LISTAGG(s) WITHIN GROUP (ORDER BY n DESC) FROM '
                  '(SELECT LEVEL n, SUBSTR(%(expressions)s, LEVEL, 1) s '
                  'FROM DUAL CONNECT BY LEVEL <= LENGTH(%(expressions)s)) '
                  'GROUP BY %(expressions)s)'
              ),
              **extra_context
          )
  class Right(Left):
      """Return the trailing ``length`` characters of a string expression."""
      function = 'RIGHT'

      def get_substr(self):
          # Rendered as SUBSTR(expression, -length); this relies on SUBSTR
          # treating a negative start position as an offset from the end.
          source = self.source_expressions[0]
          negated_length = self.source_expressions[1] * Value(-1)
          return Substr(source, negated_length)
  class RPad(LPad):
      """Right-hand counterpart of ``LPad``: same arguments, SQL ``RPAD``."""
      function = 'RPAD'
  class RTrim(Transform):
      """Apply SQL ``RTRIM()``; also available as the ``rtrim`` lookup."""
      lookup_name = 'rtrim'
      function = 'RTRIM'
  class StrIndex(Func):
      """
      Locate a substring: return the 1-based position of the first occurrence
      of the second argument inside the first, or 0 when it does not occur.
      """
      function = 'INSTR'
      arity = 2
      output_field = IntegerField()

      def as_postgresql(self, compiler, connection, **extra_context):
          # On PostgreSQL the same lookup is emitted as STRPOS().
          return super().as_sql(
              compiler, connection, function='STRPOS', **extra_context
          )
  class Substr(Func):
      """Extract part of a string starting at position ``pos``."""
      function = 'SUBSTRING'

      def __init__(self, expression, pos, length=None, **extra):
          """
          expression: the name of a field, or an expression returning a string
          pos: an integer > 0, or an expression returning an integer
          length: an optional number of characters to return
          """
          if not hasattr(pos, 'resolve_expression') and pos < 1:
              raise ValueError("'pos' must be greater than 0")
          if length is None:
              arguments = [expression, pos]
          else:
              arguments = [expression, pos, length]
          super().__init__(*arguments, **extra)

      def as_sqlite(self, compiler, connection, **extra_context):
          # SQLite rendering uses the SUBSTR spelling.
          return super().as_sql(
              compiler, connection, function='SUBSTR', **extra_context
          )

      def as_oracle(self, compiler, connection, **extra_context):
          # Oracle rendering uses the SUBSTR spelling.
          return super().as_sql(
              compiler, connection, function='SUBSTR', **extra_context
          )
  class Trim(Transform):
      """Apply SQL ``TRIM()``; also available as the ``trim`` lookup."""
      lookup_name = 'trim'
      function = 'TRIM'
  class Upper(Transform):
      """Apply SQL ``UPPER()``; also available as the ``upper`` lookup."""
      lookup_name = 'upper'
      function = 'UPPER'