Development of an internal social media platform with personalised dashboards for students
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

reindent.py 6.8KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183
  1. # -*- coding: utf-8 -*-
  2. #
  3. # Copyright (C) 2016 Andi Albrecht, albrecht.andi@gmail.com
  4. #
  5. # This module is part of python-sqlparse and is released under
  6. # the BSD License: https://opensource.org/licenses/BSD-3-Clause
  7. from sqlparse import sql, tokens as T
  8. from sqlparse.compat import text_type
  9. from sqlparse.utils import offset, indent
class ReindentFilter(object):
    """Token-stream filter that re-indents SQL statements.

    Walks a parsed statement and inserts newline + padding whitespace
    tokens before splitting keywords (FROM, JOIN, AND, ...), before
    DML/DDL keywords, inside parentheses, identifier lists and CASE
    expressions, so the rendered SQL is laid out over multiple lines.
    """

    def __init__(self, width=2, char=' ', wrap_after=0, n='\n',
                 comma_first=False):
        # n: newline text used whenever a fresh whitespace token is built.
        self.n = n
        # width: number of `char` repetitions per indentation level.
        self.width = width
        # char: the padding character (space by default, may be '\t').
        self.char = char
        # Current indentation level and extra column offset; both are
        # adjusted temporarily via the `indent`/`offset` context managers
        # imported from sqlparse.utils.
        self.indent = 0
        self.offset = 0
        # Identifier lists wrap only once a line grows past this column
        # (0 effectively wraps after every identifier).
        self.wrap_after = wrap_after
        # If True, commas are placed at the start of the following line
        # rather than at the end of the previous one.
        self.comma_first = comma_first
        # Statement currently being processed / previously processed
        # (the latter is used to separate consecutive statements).
        self._curr_stmt = None
        self._last_stmt = None

    def _flatten_up_to_token(self, token):
        """Yields all tokens up to token but excluding current."""
        if token.is_group:
            # For a group, stop at its first leaf token.
            token = next(token.flatten())

        for t in self._curr_stmt.flatten():
            if t == token:
                break
            yield t

    @property
    def leading_ws(self):
        # Total leading-whitespace width for the current nesting state.
        return self.offset + self.indent * self.width

    def _get_offset(self, token):
        # Render everything before `token` and measure the last line to
        # find the column the token starts at.
        raw = u''.join(map(text_type, self._flatten_up_to_token(token)))
        line = (raw or '\n').splitlines()[-1]

        # Now take current offset into account and return relative offset.
        return len(line) - len(self.char * self.leading_ws)

    def nl(self, offset=0):
        """Return a whitespace token: newline plus current indentation.

        `offset` is an extra (possibly negative) column adjustment; the
        padding length is clamped at zero so it never raises.
        """
        return sql.Token(
            T.Whitespace,
            self.n + self.char * max(0, self.leading_ws + offset))

    def _next_token(self, tlist, idx=-1):
        """Find the next keyword after `idx` that should start a new line.

        Entries ending in '$' are regex-anchored matches (see the `regex`
        flag passed as the third element of the match tuple).
        """
        split_words = ('FROM', 'STRAIGHT_JOIN$', 'JOIN$', 'AND', 'OR',
                       'GROUP', 'ORDER', 'UNION', 'VALUES',
                       'SET', 'BETWEEN', 'EXCEPT', 'HAVING', 'LIMIT')
        m_split = T.Keyword, split_words, True
        tidx, token = tlist.token_next_by(m=m_split, idx=idx)

        # Skip the AND that belongs to a BETWEEN ... AND ... expression:
        # a line break there would split the range predicate.
        if token and token.normalized == 'BETWEEN':
            tidx, token = self._next_token(tlist, tidx)

            if token and token.normalized == 'AND':
                tidx, token = self._next_token(tlist, tidx)

        return tidx, token

    def _split_kwds(self, tlist):
        """Insert a newline before every splitting keyword in `tlist`."""
        tidx, token = self._next_token(tlist)
        while token:
            pidx, prev_ = tlist.token_prev(tidx, skip_ws=False)
            uprev = text_type(prev_)

            # Drop whitespace immediately before the keyword; the newline
            # token inserted below replaces it.
            if prev_ and prev_.is_whitespace:
                del tlist.tokens[pidx]
                tidx -= 1

            # Only add a newline if one isn't already there.
            if not (uprev.endswith('\n') or uprev.endswith('\r')):
                tlist.insert_before(tidx, self.nl())
                tidx += 1

            tidx, token = self._next_token(tlist, tidx)

    def _split_statements(self, tlist):
        """Insert a newline before each DML/DDL keyword in `tlist`."""
        ttypes = T.Keyword.DML, T.Keyword.DDL
        tidx, token = tlist.token_next_by(t=ttypes)
        while token:
            pidx, prev_ = tlist.token_prev(tidx, skip_ws=False)
            if prev_ and prev_.is_whitespace:
                del tlist.tokens[pidx]
                tidx -= 1
            # only break if it's not the first token
            if prev_:
                tlist.insert_before(tidx, self.nl())
                tidx += 1
            tidx, token = tlist.token_next_by(t=ttypes, idx=tidx)

    def _process(self, tlist):
        # Dispatch on the group's class name, e.g. sql.Where ->
        # _process_where; fall back to _process_default.
        func_name = '_process_{cls}'.format(cls=type(tlist).__name__)
        func = getattr(self, func_name.lower(), self._process_default)
        func(tlist)

    def _process_where(self, tlist):
        tidx, token = tlist.token_next_by(m=(T.Keyword, 'WHERE'))
        # issue121, errors in statement fixed??
        # NOTE(review): if no WHERE keyword was found, tidx is None here
        # and insert_before would be called with None — presumably a
        # Where group always contains the keyword; verify upstream.
        tlist.insert_before(tidx, self.nl())

        # Everything inside WHERE is indented one extra level.
        with indent(self):
            self._process_default(tlist)

    def _process_parenthesis(self, tlist):
        ttypes = T.Keyword.DML, T.Keyword.DDL
        _, is_dml_dll = tlist.token_next_by(t=ttypes)
        fidx, first = tlist.token_next_by(m=sql.Parenthesis.M_OPEN)

        # A parenthesis holding a subquery gets its own line and one
        # indentation level; a plain expression keeps the current level.
        with indent(self, 1 if is_dml_dll else 0):
            tlist.tokens.insert(0, self.nl()) if is_dml_dll else None

            # Offset everything past the opening '(' by its column + 1.
            with offset(self, self._get_offset(first) + 1):
                self._process_default(tlist, not is_dml_dll)

    def _process_identifierlist(self, tlist):
        identifiers = list(tlist.get_identifiers())
        # Align wrapped lines with the first identifier's first leaf token.
        first = next(identifiers.pop(0).flatten())
        num_offset = 1 if self.char == '\t' else self._get_offset(first)

        # Identifier lists inside function calls (argument lists) are
        # left on one line.
        if not tlist.within(sql.Function):
            with offset(self, num_offset):
                position = 0
                for token in identifiers:
                    # Add 1 for the "," separator
                    position += len(token.value) + 1
                    if position > (self.wrap_after - self.offset):
                        adjust = 0
                        if self.comma_first:
                            # Break before the comma instead of after it,
                            # pulling the comma back by two columns.
                            adjust = -2
                            _, comma = tlist.token_prev(
                                tlist.token_index(token))
                            if comma is None:
                                continue
                            token = comma
                        tlist.insert_before(token, self.nl(offset=adjust))
                        if self.comma_first:
                            # Ensure a single space follows the leading
                            # comma if none is there already.
                            _, ws = tlist.token_next(
                                tlist.token_index(token), skip_ws=False)
                            if (ws is not None
                                    and ws.ttype is not T.Text.Whitespace):
                                tlist.insert_after(
                                    token, sql.Token(T.Whitespace, ' '))
                        position = 0

        self._process_default(tlist)

    def _process_case(self, tlist):
        iterable = iter(tlist.get_cases())
        cond, _ = next(iterable)
        first = next(cond[0].flatten())

        # Align each WHEN/ELSE with the first condition of the CASE.
        with offset(self, self._get_offset(tlist[0])):
            with offset(self, self._get_offset(first)):
                for cond, value in iterable:
                    # ELSE branches have no condition; break before the
                    # value instead.
                    token = value[0] if cond is None else cond[0]
                    tlist.insert_before(token, self.nl())

                # Line breaks on group level are done. let's add an offset of
                # len "when ", "then ", "else "
                with offset(self, len("WHEN ")):
                    self._process_default(tlist)

        # Put END on its own line.
        end_idx, end = tlist.token_next_by(m=sql.Case.M_CLOSE)
        if end_idx is not None:
            tlist.insert_before(end_idx, self.nl())

    def _process_default(self, tlist, stmts=True):
        self._split_statements(tlist) if stmts else None
        self._split_kwds(tlist)
        # Recurse into sub-groups (parentheses, identifier lists, ...).
        for sgroup in tlist.get_sublists():
            self._process(sgroup)

    def process(self, stmt):
        """Reindent `stmt` in place and return it.

        Successive statements are separated by one blank line (or a
        single newline when the previous statement already ends in one).
        """
        self._curr_stmt = stmt
        self._process(stmt)

        if self._last_stmt is not None:
            nl = '\n' if text_type(self._last_stmt).endswith('\n') else '\n\n'
            stmt.tokens.insert(0, sql.Token(T.Whitespace, nl))

        self._last_stmt = stmt
        return stmt