Development of an internal social media platform with personalised dashboards for students
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

text.py 7.6KB

(Line-number gutter from the file viewer: lines 1–237.)
  1. from django.db.models import Func, IntegerField, Transform, Value, fields
  2. from django.db.models.functions import Coalesce
  3. class BytesToCharFieldConversionMixin:
  4. """
  5. Convert CharField results from bytes to str.
  6. MySQL returns long data types (bytes) instead of chars when it can't
  7. determine the length of the result string. For example:
  8. LPAD(column1, CHAR_LENGTH(column2), ' ')
  9. returns the LONGTEXT (bytes) instead of VARCHAR.
  10. """
  11. def convert_value(self, value, expression, connection):
  12. if connection.features.db_functions_convert_bytes_to_str:
  13. if self.output_field.get_internal_type() == 'CharField' and isinstance(value, bytes):
  14. return value.decode()
  15. return super().convert_value(value, expression, connection)
  16. class Chr(Transform):
  17. function = 'CHR'
  18. lookup_name = 'chr'
  19. def as_mysql(self, compiler, connection):
  20. return super().as_sql(
  21. compiler, connection, function='CHAR', template='%(function)s(%(expressions)s USING utf16)'
  22. )
  23. def as_oracle(self, compiler, connection):
  24. return super().as_sql(compiler, connection, template='%(function)s(%(expressions)s USING NCHAR_CS)')
  25. def as_sqlite(self, compiler, connection, **extra_context):
  26. return super().as_sql(compiler, connection, function='CHAR', **extra_context)
  27. class ConcatPair(Func):
  28. """
  29. Concatenate two arguments together. This is used by `Concat` because not
  30. all backend databases support more than two arguments.
  31. """
  32. function = 'CONCAT'
  33. def as_sqlite(self, compiler, connection):
  34. coalesced = self.coalesce()
  35. return super(ConcatPair, coalesced).as_sql(
  36. compiler, connection, template='%(expressions)s', arg_joiner=' || '
  37. )
  38. def as_mysql(self, compiler, connection):
  39. # Use CONCAT_WS with an empty separator so that NULLs are ignored.
  40. return super().as_sql(
  41. compiler, connection, function='CONCAT_WS', template="%(function)s('', %(expressions)s)"
  42. )
  43. def coalesce(self):
  44. # null on either side results in null for expression, wrap with coalesce
  45. c = self.copy()
  46. expressions = [
  47. Coalesce(expression, Value('')) for expression in c.get_source_expressions()
  48. ]
  49. c.set_source_expressions(expressions)
  50. return c
  51. class Concat(Func):
  52. """
  53. Concatenate text fields together. Backends that result in an entire
  54. null expression when any arguments are null will wrap each argument in
  55. coalesce functions to ensure a non-null result.
  56. """
  57. function = None
  58. template = "%(expressions)s"
  59. def __init__(self, *expressions, **extra):
  60. if len(expressions) < 2:
  61. raise ValueError('Concat must take at least two expressions')
  62. paired = self._paired(expressions)
  63. super().__init__(paired, **extra)
  64. def _paired(self, expressions):
  65. # wrap pairs of expressions in successive concat functions
  66. # exp = [a, b, c, d]
  67. # -> ConcatPair(a, ConcatPair(b, ConcatPair(c, d))))
  68. if len(expressions) == 2:
  69. return ConcatPair(*expressions)
  70. return ConcatPair(expressions[0], self._paired(expressions[1:]))
  71. class Left(Func):
  72. function = 'LEFT'
  73. arity = 2
  74. def __init__(self, expression, length, **extra):
  75. """
  76. expression: the name of a field, or an expression returning a string
  77. length: the number of characters to return from the start of the string
  78. """
  79. if not hasattr(length, 'resolve_expression'):
  80. if length < 1:
  81. raise ValueError("'length' must be greater than 0.")
  82. super().__init__(expression, length, **extra)
  83. def get_substr(self):
  84. return Substr(self.source_expressions[0], Value(1), self.source_expressions[1])
  85. def use_substr(self, compiler, connection, **extra_context):
  86. return self.get_substr().as_oracle(compiler, connection, **extra_context)
  87. as_oracle = use_substr
  88. as_sqlite = use_substr
  89. class Length(Transform):
  90. """Return the number of characters in the expression."""
  91. function = 'LENGTH'
  92. lookup_name = 'length'
  93. output_field = fields.IntegerField()
  94. def as_mysql(self, compiler, connection):
  95. return super().as_sql(compiler, connection, function='CHAR_LENGTH')
  96. class Lower(Transform):
  97. function = 'LOWER'
  98. lookup_name = 'lower'
  99. class LPad(BytesToCharFieldConversionMixin, Func):
  100. function = 'LPAD'
  101. def __init__(self, expression, length, fill_text=Value(' '), **extra):
  102. if not hasattr(length, 'resolve_expression') and length < 0:
  103. raise ValueError("'length' must be greater or equal to 0.")
  104. super().__init__(expression, length, fill_text, **extra)
  105. class LTrim(Transform):
  106. function = 'LTRIM'
  107. lookup_name = 'ltrim'
  108. class Ord(Transform):
  109. function = 'ASCII'
  110. lookup_name = 'ord'
  111. output_field = IntegerField()
  112. def as_mysql(self, compiler, connection, **extra_context):
  113. return super().as_sql(compiler, connection, function='ORD', **extra_context)
  114. def as_sqlite(self, compiler, connection, **extra_context):
  115. return super().as_sql(compiler, connection, function='UNICODE', **extra_context)
  116. class Repeat(BytesToCharFieldConversionMixin, Func):
  117. function = 'REPEAT'
  118. def __init__(self, expression, number, **extra):
  119. if not hasattr(number, 'resolve_expression') and number < 0:
  120. raise ValueError("'number' must be greater or equal to 0.")
  121. super().__init__(expression, number, **extra)
  122. def as_oracle(self, compiler, connection, **extra_context):
  123. expression, number = self.source_expressions
  124. rpad = RPad(expression, Length(expression) * number, expression)
  125. return rpad.as_sql(compiler, connection, **extra_context)
  126. class Replace(Func):
  127. function = 'REPLACE'
  128. def __init__(self, expression, text, replacement=Value(''), **extra):
  129. super().__init__(expression, text, replacement, **extra)
  130. class Right(Left):
  131. function = 'RIGHT'
  132. def get_substr(self):
  133. return Substr(self.source_expressions[0], self.source_expressions[1] * Value(-1))
  134. class RPad(LPad):
  135. function = 'RPAD'
  136. class RTrim(Transform):
  137. function = 'RTRIM'
  138. lookup_name = 'rtrim'
  139. class StrIndex(Func):
  140. """
  141. Return a positive integer corresponding to the 1-indexed position of the
  142. first occurrence of a substring inside another string, or 0 if the
  143. substring is not found.
  144. """
  145. function = 'INSTR'
  146. arity = 2
  147. output_field = fields.IntegerField()
  148. def as_postgresql(self, compiler, connection):
  149. return super().as_sql(compiler, connection, function='STRPOS')
  150. class Substr(Func):
  151. function = 'SUBSTRING'
  152. def __init__(self, expression, pos, length=None, **extra):
  153. """
  154. expression: the name of a field, or an expression returning a string
  155. pos: an integer > 0, or an expression returning an integer
  156. length: an optional number of characters to return
  157. """
  158. if not hasattr(pos, 'resolve_expression'):
  159. if pos < 1:
  160. raise ValueError("'pos' must be greater than 0")
  161. expressions = [expression, pos]
  162. if length is not None:
  163. expressions.append(length)
  164. super().__init__(*expressions, **extra)
  165. def as_sqlite(self, compiler, connection):
  166. return super().as_sql(compiler, connection, function='SUBSTR')
  167. def as_oracle(self, compiler, connection):
  168. return super().as_sql(compiler, connection, function='SUBSTR')
  169. class Trim(Transform):
  170. function = 'TRIM'
  171. lookup_name = 'trim'
  172. class Upper(Transform):
  173. function = 'UPPER'
  174. lookup_name = 'upper'