123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183 |
- # -*- coding: utf-8 -*-
- #
- # Copyright (C) 2016 Andi Albrecht, albrecht.andi@gmail.com
- #
- # This module is part of python-sqlparse and is released under
- # the BSD License: https://opensource.org/licenses/BSD-3-Clause
-
- from sqlparse import sql, tokens as T
- from sqlparse.compat import text_type
- from sqlparse.utils import offset, indent
-
-
- class ReindentFilter(object):
- def __init__(self, width=2, char=' ', wrap_after=0, n='\n',
- comma_first=False):
- self.n = n
- self.width = width
- self.char = char
- self.indent = 0
- self.offset = 0
- self.wrap_after = wrap_after
- self.comma_first = comma_first
- self._curr_stmt = None
- self._last_stmt = None
-
- def _flatten_up_to_token(self, token):
- """Yields all tokens up to token but excluding current."""
- if token.is_group:
- token = next(token.flatten())
-
- for t in self._curr_stmt.flatten():
- if t == token:
- break
- yield t
-
- @property
- def leading_ws(self):
- return self.offset + self.indent * self.width
-
- def _get_offset(self, token):
- raw = u''.join(map(text_type, self._flatten_up_to_token(token)))
- line = (raw or '\n').splitlines()[-1]
- # Now take current offset into account and return relative offset.
- return len(line) - len(self.char * self.leading_ws)
-
- def nl(self, offset=0):
- return sql.Token(
- T.Whitespace,
- self.n + self.char * max(0, self.leading_ws + offset))
-
- def _next_token(self, tlist, idx=-1):
- split_words = ('FROM', 'STRAIGHT_JOIN$', 'JOIN$', 'AND', 'OR',
- 'GROUP', 'ORDER', 'UNION', 'VALUES',
- 'SET', 'BETWEEN', 'EXCEPT', 'HAVING', 'LIMIT')
- m_split = T.Keyword, split_words, True
- tidx, token = tlist.token_next_by(m=m_split, idx=idx)
-
- if token and token.normalized == 'BETWEEN':
- tidx, token = self._next_token(tlist, tidx)
-
- if token and token.normalized == 'AND':
- tidx, token = self._next_token(tlist, tidx)
-
- return tidx, token
-
- def _split_kwds(self, tlist):
- tidx, token = self._next_token(tlist)
- while token:
- pidx, prev_ = tlist.token_prev(tidx, skip_ws=False)
- uprev = text_type(prev_)
-
- if prev_ and prev_.is_whitespace:
- del tlist.tokens[pidx]
- tidx -= 1
-
- if not (uprev.endswith('\n') or uprev.endswith('\r')):
- tlist.insert_before(tidx, self.nl())
- tidx += 1
-
- tidx, token = self._next_token(tlist, tidx)
-
- def _split_statements(self, tlist):
- ttypes = T.Keyword.DML, T.Keyword.DDL
- tidx, token = tlist.token_next_by(t=ttypes)
- while token:
- pidx, prev_ = tlist.token_prev(tidx, skip_ws=False)
- if prev_ and prev_.is_whitespace:
- del tlist.tokens[pidx]
- tidx -= 1
- # only break if it's not the first token
- if prev_:
- tlist.insert_before(tidx, self.nl())
- tidx += 1
- tidx, token = tlist.token_next_by(t=ttypes, idx=tidx)
-
- def _process(self, tlist):
- func_name = '_process_{cls}'.format(cls=type(tlist).__name__)
- func = getattr(self, func_name.lower(), self._process_default)
- func(tlist)
-
- def _process_where(self, tlist):
- tidx, token = tlist.token_next_by(m=(T.Keyword, 'WHERE'))
- # issue121, errors in statement fixed??
- tlist.insert_before(tidx, self.nl())
-
- with indent(self):
- self._process_default(tlist)
-
- def _process_parenthesis(self, tlist):
- ttypes = T.Keyword.DML, T.Keyword.DDL
- _, is_dml_dll = tlist.token_next_by(t=ttypes)
- fidx, first = tlist.token_next_by(m=sql.Parenthesis.M_OPEN)
-
- with indent(self, 1 if is_dml_dll else 0):
- tlist.tokens.insert(0, self.nl()) if is_dml_dll else None
- with offset(self, self._get_offset(first) + 1):
- self._process_default(tlist, not is_dml_dll)
-
- def _process_identifierlist(self, tlist):
- identifiers = list(tlist.get_identifiers())
- first = next(identifiers.pop(0).flatten())
- num_offset = 1 if self.char == '\t' else self._get_offset(first)
- if not tlist.within(sql.Function):
- with offset(self, num_offset):
- position = 0
- for token in identifiers:
- # Add 1 for the "," separator
- position += len(token.value) + 1
- if position > (self.wrap_after - self.offset):
- adjust = 0
- if self.comma_first:
- adjust = -2
- _, comma = tlist.token_prev(
- tlist.token_index(token))
- if comma is None:
- continue
- token = comma
- tlist.insert_before(token, self.nl(offset=adjust))
- if self.comma_first:
- _, ws = tlist.token_next(
- tlist.token_index(token), skip_ws=False)
- if (ws is not None
- and ws.ttype is not T.Text.Whitespace):
- tlist.insert_after(
- token, sql.Token(T.Whitespace, ' '))
- position = 0
- self._process_default(tlist)
-
- def _process_case(self, tlist):
- iterable = iter(tlist.get_cases())
- cond, _ = next(iterable)
- first = next(cond[0].flatten())
-
- with offset(self, self._get_offset(tlist[0])):
- with offset(self, self._get_offset(first)):
- for cond, value in iterable:
- token = value[0] if cond is None else cond[0]
- tlist.insert_before(token, self.nl())
-
- # Line breaks on group level are done. let's add an offset of
- # len "when ", "then ", "else "
- with offset(self, len("WHEN ")):
- self._process_default(tlist)
- end_idx, end = tlist.token_next_by(m=sql.Case.M_CLOSE)
- if end_idx is not None:
- tlist.insert_before(end_idx, self.nl())
-
- def _process_default(self, tlist, stmts=True):
- self._split_statements(tlist) if stmts else None
- self._split_kwds(tlist)
- for sgroup in tlist.get_sublists():
- self._process(sgroup)
-
- def process(self, stmt):
- self._curr_stmt = stmt
- self._process(stmt)
-
- if self._last_stmt is not None:
- nl = '\n' if text_type(self._last_stmt).endswith('\n') else '\n\n'
- stmt.tokens.insert(0, sql.Token(T.Whitespace, nl))
-
- self._last_stmt = stmt
- return stmt
|