# -*- coding: utf-8 -*-
#
# Copyright (C) 2009-2018 the sqlparse authors and contributors
# <see AUTHORS file>
#
# This module is part of python-sqlparse and is released under
# the BSD License: https://opensource.org/licenses/BSD-3-Clause

from sqlparse import sql, tokens as T
from sqlparse.compat import text_type
from sqlparse.utils import offset, indent


class ReindentFilter(object):
    def __init__(self, width=2, char=' ', wrap_after=0, n='\n',
                 comma_first=False, indent_after_first=False,
                 indent_columns=False):
        self.n = n
        self.width = width
        self.char = char
        self.indent = 1 if indent_after_first else 0
        self.offset = 0
        self.wrap_after = wrap_after
        self.comma_first = comma_first
        self.indent_columns = indent_columns
        self._curr_stmt = None
        self._last_stmt = None
        self._last_func = None

    def _flatten_up_to_token(self, token):
        """Yields all tokens up to token but excluding current."""
        if token.is_group:
            token = next(token.flatten())

        for t in self._curr_stmt.flatten():
            if t == token:
                break
            yield t

    @property
    def leading_ws(self):
        return self.offset + self.indent * self.width

    def _get_offset(self, token):
        raw = u''.join(map(text_type, self._flatten_up_to_token(token)))
        line = (raw or '\n').splitlines()[-1]

        # Now take current offset into account and return relative offset.
        return len(line) - len(self.char * self.leading_ws)

    def nl(self, offset=0):
        return sql.Token(
            T.Whitespace,
            self.n + self.char * max(0, self.leading_ws + offset))

    def _next_token(self, tlist, idx=-1):
        split_words = ('FROM', 'STRAIGHT_JOIN$', 'JOIN$', 'AND', 'OR',
                       'GROUP BY', 'ORDER BY', 'UNION', 'VALUES',
                       'SET', 'BETWEEN', 'EXCEPT', 'HAVING', 'LIMIT')
        m_split = T.Keyword, split_words, True
        tidx, token = tlist.token_next_by(m=m_split, idx=idx)

        # In "BETWEEN x AND y" the AND belongs to the BETWEEN expression
        # and must not trigger a line break, so skip past it.
        if token and token.normalized == 'BETWEEN':
            tidx, token = self._next_token(tlist, tidx)

            if token and token.normalized == 'AND':
                tidx, token = self._next_token(tlist, tidx)

        return tidx, token

    def _split_kwds(self, tlist):
        tidx, token = self._next_token(tlist)
        while token:
            pidx, prev_ = tlist.token_prev(tidx, skip_ws=False)
            uprev = text_type(prev_)

            if prev_ and prev_.is_whitespace:
                del tlist.tokens[pidx]
                tidx -= 1

            if not (uprev.endswith('\n') or uprev.endswith('\r')):
                tlist.insert_before(tidx, self.nl())
                tidx += 1

            tidx, token = self._next_token(tlist, tidx)

    def _split_statements(self, tlist):
        ttypes = T.Keyword.DML, T.Keyword.DDL
        tidx, token = tlist.token_next_by(t=ttypes)
        while token:
            pidx, prev_ = tlist.token_prev(tidx, skip_ws=False)
            if prev_ and prev_.is_whitespace:
                del tlist.tokens[pidx]
                tidx -= 1
            # only break if it's not the first token
            if prev_:
                tlist.insert_before(tidx, self.nl())
                tidx += 1
            tidx, token = tlist.token_next_by(t=ttypes, idx=tidx)

    def _process(self, tlist):
        func_name = '_process_{cls}'.format(cls=type(tlist).__name__)
        func = getattr(self, func_name.lower(), self._process_default)
        func(tlist)

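    # Dispatch note (added for clarity): _process() above looks up handlers
    # by the group's class name, so a sql.Where group lands in
    # _process_where(), a sql.Parenthesis in _process_parenthesis(), and so
    # on; groups without a dedicated handler fall back to _process_default().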
    def _process_where(self, tlist):
        tidx, token = tlist.token_next_by(m=(T.Keyword, 'WHERE'))
        # issue121, errors in statement fixed??
        tlist.insert_before(tidx, self.nl())

        with indent(self):
            self._process_default(tlist)

    def _process_parenthesis(self, tlist):
        ttypes = T.Keyword.DML, T.Keyword.DDL
        _, is_dml_ddl = tlist.token_next_by(t=ttypes)
        fidx, first = tlist.token_next_by(m=sql.Parenthesis.M_OPEN)

        with indent(self, 1 if is_dml_ddl else 0):
            if is_dml_ddl:
                tlist.tokens.insert(0, self.nl())
            with offset(self, self._get_offset(first) + 1):
                self._process_default(tlist, not is_dml_ddl)

    def _process_function(self, tlist):
        self._last_func = tlist[0]
        self._process_default(tlist)

    def _process_identifierlist(self, tlist):
        identifiers = list(tlist.get_identifiers())
        if self.indent_columns:
            first = next(identifiers[0].flatten())
            num_offset = 1 if self.char == '\t' else self.width
        else:
            first = next(identifiers.pop(0).flatten())
            num_offset = 1 if self.char == '\t' else self._get_offset(first)

        if not tlist.within(sql.Function) and not tlist.within(sql.Values):
            with offset(self, num_offset):
                position = 0
                for token in identifiers:
                    # Add 1 for the "," separator
                    position += len(token.value) + 1
                    if position > (self.wrap_after - self.offset):
                        adjust = 0
                        if self.comma_first:
                            # comma_first: break before the comma instead of
                            # after it and back up over the ", " separator.
                            adjust = -2
                            _, comma = tlist.token_prev(
                                tlist.token_index(token))
                            if comma is None:
                                continue
                            token = comma
                        tlist.insert_before(token, self.nl(offset=adjust))
                        if self.comma_first:
                            _, ws = tlist.token_next(
                                tlist.token_index(token), skip_ws=False)
                            if (ws is not None
                                    and ws.ttype is not T.Text.Whitespace):
                                tlist.insert_after(
                                    token, sql.Token(T.Whitespace, ' '))
                        position = 0
        else:
            # ensure whitespace after each comma
            for token in tlist:
                _, next_ws = tlist.token_next(
                    tlist.token_index(token), skip_ws=False)
                if token.value == ',' and not next_ws.is_whitespace:
                    tlist.insert_after(
                        token, sql.Token(T.Whitespace, ' '))

            end_at = self.offset + sum(len(i.value) + 1 for i in identifiers)
            adjusted_offset = 0
            if (self.wrap_after > 0
                    and end_at > (self.wrap_after - self.offset)
                    and self._last_func):
                adjusted_offset = -len(self._last_func.value) - 1

            with offset(self, adjusted_offset), indent(self):
                if adjusted_offset < 0:
                    tlist.insert_before(identifiers[0], self.nl())

                position = 0
                for token in identifiers:
                    # Add 1 for the "," separator
                    position += len(token.value) + 1
                    if (self.wrap_after > 0
                            and position > (self.wrap_after - self.offset)):
                        adjust = 0
                        tlist.insert_before(token, self.nl(offset=adjust))
                        position = 0

        self._process_default(tlist)

    def _process_case(self, tlist):
        iterable = iter(tlist.get_cases())
        cond, _ = next(iterable)
        first = next(cond[0].flatten())

        with offset(self, self._get_offset(tlist[0])):
            with offset(self, self._get_offset(first)):
                for cond, value in iterable:
                    token = value[0] if cond is None else cond[0]
                    tlist.insert_before(token, self.nl())

                # Line breaks on group level are done. Let's add an offset
                # of len "when ", "then ", "else ".
                with offset(self, len("WHEN ")):
                    self._process_default(tlist)
        end_idx, end = tlist.token_next_by(m=sql.Case.M_CLOSE)
        if end_idx is not None:
            tlist.insert_before(end_idx, self.nl())

    def _process_values(self, tlist):
        tlist.insert_before(0, self.nl())
        tidx, token = tlist.token_next_by(i=sql.Parenthesis)
        first_token = token
        while token:
            ptidx, ptoken = tlist.token_next_by(m=(T.Punctuation, ','),
                                                idx=tidx)
            if ptoken:
                if self.comma_first:
                    adjust = -2
                    offset = self._get_offset(first_token) + adjust
                    tlist.insert_before(ptoken, self.nl(offset))
                else:
                    tlist.insert_after(ptoken,
                                       self.nl(self._get_offset(token)))
            tidx, token = tlist.token_next_by(i=sql.Parenthesis, idx=tidx)

    def _process_default(self, tlist, stmts=True):
        if stmts:
            self._split_statements(tlist)
        self._split_kwds(tlist)
        for sgroup in tlist.get_sublists():
            self._process(sgroup)

    def process(self, stmt):
        self._curr_stmt = stmt
        self._process(stmt)

        if self._last_stmt is not None:
            nl = '\n' if text_type(self._last_stmt).endswith('\n') else '\n\n'
            stmt.tokens.insert(0, sql.Token(T.Whitespace, nl))

        self._last_stmt = stmt
        return stmt
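
# Usage sketch (added for illustration, not part of the upstream module):
# ReindentFilter is normally installed via sqlparse.format(sql, reindent=True),
# which also wires up whitespace-stripping and serializer filters, so the
# minimal stack below may not match sqlparse.format() output exactly.
if __name__ == '__main__':
    from sqlparse import filters
    from sqlparse.engine import FilterStack

    stack = FilterStack()
    # ReindentFilter walks grouped token trees, so grouping must be enabled.
    stack.enable_grouping()
    stack.stmtprocess.append(ReindentFilter(width=4))
    # SerializerUnicode turns each processed statement back into a string.
    stack.postprocess.append(filters.SerializerUnicode())
    statement = "select foo, bar from baz where foo = 1 and bar = 2"
    print(''.join(stack.run(statement)))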