
reindent.py 9.3KB

# -*- coding: utf-8 -*-
#
# Copyright (C) 2009-2018 the sqlparse authors and contributors
# <see AUTHORS file>
#
# This module is part of python-sqlparse and is released under
# the BSD License: https://opensource.org/licenses/BSD-3-Clause

from sqlparse import sql, tokens as T
from sqlparse.compat import text_type
from sqlparse.utils import offset, indent


class ReindentFilter(object):
    def __init__(self, width=2, char=' ', wrap_after=0, n='\n',
                 comma_first=False, indent_after_first=False,
                 indent_columns=False):
        self.n = n
        self.width = width
        self.char = char
        self.indent = 1 if indent_after_first else 0
        self.offset = 0
        self.wrap_after = wrap_after
        self.comma_first = comma_first
        self.indent_columns = indent_columns
        self._curr_stmt = None
        self._last_stmt = None
        self._last_func = None
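
    # `width` counts `char` characters per indentation level; `wrap_after`
    # is the column after which identifier lists wrap; `comma_first` places
    # the comma at the start of a wrapped line rather than at the end of the
    # previous one.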

    def _flatten_up_to_token(self, token):
        """Yield all tokens up to, but excluding, the given token."""
        if token.is_group:
            token = next(token.flatten())

        for t in self._curr_stmt.flatten():
            if t == token:
                break
            yield t

    @property
    def leading_ws(self):
        return self.offset + self.indent * self.width

    def _get_offset(self, token):
        raw = u''.join(map(text_type, self._flatten_up_to_token(token)))
        line = (raw or '\n').splitlines()[-1]

        # Now take current offset into account and return relative offset.
        return len(line) - len(self.char * self.leading_ws)
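
    # For example, while processing "select foo" the offset of the token
    # "foo" is len("select ") minus the current leading whitespace, i.e. the
    # column at which aligned continuation lines should start.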

    def nl(self, offset=0):
        return sql.Token(
            T.Whitespace,
            self.n + self.char * max(0, self.leading_ws + offset))
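
    # E.g. with char=' ', width=2, indent=1 and offset=0, leading_ws is 2,
    # so nl() returns a whitespace token whose value is '\n' followed by two
    # spaces; a negative `offset` argument pulls the next line back left.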

    def _next_token(self, tlist, idx=-1):
        split_words = ('FROM', 'STRAIGHT_JOIN$', 'JOIN$', 'AND', 'OR',
                       'GROUP BY', 'ORDER BY', 'UNION', 'VALUES',
                       'SET', 'BETWEEN', 'EXCEPT', 'HAVING', 'LIMIT')
        m_split = T.Keyword, split_words, True
        tidx, token = tlist.token_next_by(m=m_split, idx=idx)

        if token and token.normalized == 'BETWEEN':
            tidx, token = self._next_token(tlist, tidx)

            if token and token.normalized == 'AND':
                tidx, token = self._next_token(tlist, tidx)

        return tidx, token
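
    # The BETWEEN branch skips past the AND that belongs to
    # "x BETWEEN a AND b", so that AND is not mistaken for a boolean
    # connective and given its own line.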

    def _split_kwds(self, tlist):
        tidx, token = self._next_token(tlist)
        while token:
            pidx, prev_ = tlist.token_prev(tidx, skip_ws=False)
            uprev = text_type(prev_)

            if prev_ and prev_.is_whitespace:
                del tlist.tokens[pidx]
                tidx -= 1

            if not (uprev.endswith('\n') or uprev.endswith('\r')):
                tlist.insert_before(tidx, self.nl())
                tidx += 1

            tidx, token = self._next_token(tlist, tidx)

    def _split_statements(self, tlist):
        ttypes = T.Keyword.DML, T.Keyword.DDL
        tidx, token = tlist.token_next_by(t=ttypes)
        while token:
            pidx, prev_ = tlist.token_prev(tidx, skip_ws=False)
            if prev_ and prev_.is_whitespace:
                del tlist.tokens[pidx]
                tidx -= 1
            # only break if it's not the first token
            if prev_:
                tlist.insert_before(tidx, self.nl())
                tidx += 1
            tidx, token = tlist.token_next_by(t=ttypes, idx=tidx)
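
    # This inserts a line break before every DML/DDL keyword except the
    # first, e.g. the UPDATE in "... ; UPDATE ..." starts on a new line.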

    def _process(self, tlist):
        func_name = '_process_{cls}'.format(cls=type(tlist).__name__)
        func = getattr(self, func_name.lower(), self._process_default)
        func(tlist)
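
    # Dispatch by group class name: an sql.Where group is handled by
    # _process_where, an sql.Parenthesis by _process_parenthesis, and so on;
    # anything without a specific handler falls back to _process_default.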

    def _process_where(self, tlist):
        tidx, token = tlist.token_next_by(m=(T.Keyword, 'WHERE'))
        # issue121, errors in statement fixed??
        tlist.insert_before(tidx, self.nl())

        with indent(self):
            self._process_default(tlist)

    def _process_parenthesis(self, tlist):
        ttypes = T.Keyword.DML, T.Keyword.DDL
        _, is_dml_dll = tlist.token_next_by(t=ttypes)
        fidx, first = tlist.token_next_by(m=sql.Parenthesis.M_OPEN)

        with indent(self, 1 if is_dml_dll else 0):
            tlist.tokens.insert(0, self.nl()) if is_dml_dll else None
            with offset(self, self._get_offset(first) + 1):
                self._process_default(tlist, not is_dml_dll)
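
    # Subqueries (a parenthesis containing DML/DDL) get a full indent level
    # and a leading newline; plain parentheses only align their contents one
    # column past the opening "(".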

    def _process_function(self, tlist):
        self._last_func = tlist[0]
        self._process_default(tlist)

    def _process_identifierlist(self, tlist):
        identifiers = list(tlist.get_identifiers())
        if self.indent_columns:
            first = next(identifiers[0].flatten())
            num_offset = 1 if self.char == '\t' else self.width
        else:
            first = next(identifiers.pop(0).flatten())
            num_offset = 1 if self.char == '\t' else self._get_offset(first)

        if not tlist.within(sql.Function) and not tlist.within(sql.Values):
            with offset(self, num_offset):
                position = 0
                for token in identifiers:
                    # Add 1 for the "," separator
                    position += len(token.value) + 1
                    if position > (self.wrap_after - self.offset):
                        adjust = 0
                        if self.comma_first:
                            adjust = -2
                            _, comma = tlist.token_prev(
                                tlist.token_index(token))
                            if comma is None:
                                continue
                            token = comma
                        tlist.insert_before(token, self.nl(offset=adjust))
                        if self.comma_first:
                            _, ws = tlist.token_next(
                                tlist.token_index(token), skip_ws=False)
                            if (ws is not None
                                    and ws.ttype is not T.Text.Whitespace):
                                tlist.insert_after(
                                    token, sql.Token(T.Whitespace, ' '))
                        position = 0
        else:
            # ensure whitespace
            for token in tlist:
                _, next_ws = tlist.token_next(
                    tlist.token_index(token), skip_ws=False)
                if token.value == ',' and not next_ws.is_whitespace:
                    tlist.insert_after(
                        token, sql.Token(T.Whitespace, ' '))

            end_at = self.offset + sum(len(i.value) + 1 for i in identifiers)
            adjusted_offset = 0
            if (self.wrap_after > 0
                    and end_at > (self.wrap_after - self.offset)
                    and self._last_func):
                adjusted_offset = -len(self._last_func.value) - 1

            with offset(self, adjusted_offset), indent(self):
                if adjusted_offset < 0:
                    tlist.insert_before(identifiers[0], self.nl())

                position = 0
                for token in identifiers:
                    # Add 1 for the "," separator
                    position += len(token.value) + 1
                    if (self.wrap_after > 0
                            and position > (self.wrap_after - self.offset)):
                        adjust = 0
                        tlist.insert_before(token, self.nl(offset=adjust))
                        position = 0

        self._process_default(tlist)
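
    # With comma_first=True the wrap point moves onto the comma itself,
    # e.g.:
    #     select foo
    #          , bar
    # instead of the default trailing-comma style "select foo,\n       bar".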

    def _process_case(self, tlist):
        iterable = iter(tlist.get_cases())
        cond, _ = next(iterable)
        first = next(cond[0].flatten())

        with offset(self, self._get_offset(tlist[0])):
            with offset(self, self._get_offset(first)):
                for cond, value in iterable:
                    token = value[0] if cond is None else cond[0]
                    tlist.insert_before(token, self.nl())

                # Line breaks on group level are done. Now add an offset of
                # len "when ", "then ", "else "
                with offset(self, len("WHEN ")):
                    self._process_default(tlist)
            end_idx, end = tlist.token_next_by(m=sql.Case.M_CLOSE)
            if end_idx is not None:
                tlist.insert_before(end_idx, self.nl())
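
    # The resulting layout aligns the later WHEN/ELSE branches under the
    # first condition and breaks before END, roughly:
    #     case when foo = 1 then 2
    #          when foo = 3 then 4
    #          else 5
    #     end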

    def _process_values(self, tlist):
        tlist.insert_before(0, self.nl())
        tidx, token = tlist.token_next_by(i=sql.Parenthesis)
        first_token = token
        while token:
            ptidx, ptoken = tlist.token_next_by(m=(T.Punctuation, ','),
                                                idx=tidx)
            if ptoken:
                if self.comma_first:
                    adjust = -2
                    offset = self._get_offset(first_token) + adjust
                    tlist.insert_before(ptoken, self.nl(offset))
                else:
                    tlist.insert_after(ptoken,
                                       self.nl(self._get_offset(token)))
            tidx, token = tlist.token_next_by(i=sql.Parenthesis, idx=tidx)
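
    # Each row tuple of a multi-row VALUES clause ends up on its own line,
    # aligned with the first tuple, e.g.:
    #     values (1, 2),
    #            (3, 4)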

    def _process_default(self, tlist, stmts=True):
        self._split_statements(tlist) if stmts else None
        self._split_kwds(tlist)
        for sgroup in tlist.get_sublists():
            self._process(sgroup)

    def process(self, stmt):
        self._curr_stmt = stmt
        self._process(stmt)

        if self._last_stmt is not None:
            nl = '\n' if text_type(self._last_stmt).endswith('\n') else '\n\n'
            stmt.tokens.insert(0, sql.Token(T.Whitespace, nl))

        self._last_stmt = stmt
        return stmt
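

# A minimal usage sketch, not part of the original module: in practice this
# filter is enabled through the public API, sqlparse.format(...,
# reindent=True), which places a ReindentFilter on the filter stack. The
# keyword arguments below (reindent, indent_width, comma_first) map onto the
# constructor arguments above.
if __name__ == '__main__':
    import sqlparse

    # Prints the statement with each column and each AND-joined condition on
    # its own line, commas leading the wrapped lines.
    print(sqlparse.format(
        'select foo, bar from baz where foo = 1 and bar = 2',
        reindent=True, indent_width=2, comma_first=True))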