You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

grouping.py 12KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405
  1. # -*- coding: utf-8 -*-
  2. #
  3. # Copyright (C) 2009-2018 the sqlparse authors and contributors
  4. # <see AUTHORS file>
  5. #
  6. # This module is part of python-sqlparse and is released under
  7. # the BSD License: https://opensource.org/licenses/BSD-3-Clause
  8. from sqlparse import sql
  9. from sqlparse import tokens as T
  10. from sqlparse.utils import recurse, imt
# Token-type tuples shared by several grouping functions below.
T_NUMERICAL = (T.Number, T.Number.Integer, T.Number.Float)  # numeric literals
T_STRING = (T.String, T.String.Single, T.String.Symbol)  # string literals
T_NAME = (T.Name, T.Name.Placeholder)  # names and bind placeholders
def _group_matching(tlist, cls):
    """Groups Tokens that have beginning and end.

    ``cls`` supplies ``M_OPEN``/``M_CLOSE`` match specs (e.g. Parenthesis,
    Case, Begin).  Iterates over a snapshot of ``tlist`` while grouping
    mutates the live list in place; ``tidx_offset`` converts snapshot
    indexes back to positions in the shrinking live list.
    """
    opens = []  # stack of open-token indexes awaiting their close token
    tidx_offset = 0
    for idx, token in enumerate(list(tlist)):
        tidx = idx - tidx_offset
        if token.is_whitespace:
            # ~50% of tokens will be whitespace. Will checking early
            # for them avoid 3 comparisons, but then add 1 more comparison
            # for the other ~50% of tokens...
            continue
        if token.is_group and not isinstance(token, cls):
            # Check inside previously grouped (i.e. parenthesis) if group
            # of different type is inside (i.e., case). though ideally  should
            # should check for all open/close tokens at once to avoid recursion
            _group_matching(token, cls)
            continue
        if token.match(*cls.M_OPEN):
            opens.append(tidx)
        elif token.match(*cls.M_CLOSE):
            try:
                # innermost open token pairs with this close token
                open_idx = opens.pop()
            except IndexError:
                # this indicates invalid sql and unbalanced tokens.
                # instead of break, continue in case other "valid" groups exist
                continue
            close_idx = tidx
            tlist.group_tokens(cls, open_idx, close_idx)
            # (close_idx - open_idx + 1) tokens collapsed into one group node
            tidx_offset += close_idx - open_idx
def group_brackets(tlist):
    """Group tokens between SquareBrackets open/close markers."""
    _group_matching(tlist, sql.SquareBrackets)


def group_parenthesis(tlist):
    """Group tokens between Parenthesis open/close markers."""
    _group_matching(tlist, sql.Parenthesis)


def group_case(tlist):
    """Group tokens between Case open/close markers."""
    _group_matching(tlist, sql.Case)


def group_if(tlist):
    """Group tokens between If open/close markers."""
    _group_matching(tlist, sql.If)


def group_for(tlist):
    """Group tokens between For open/close markers."""
    _group_matching(tlist, sql.For)


def group_begin(tlist):
    """Group tokens between Begin open/close markers."""
    _group_matching(tlist, sql.Begin)
  55. def group_typecasts(tlist):
  56. def match(token):
  57. return token.match(T.Punctuation, '::')
  58. def valid(token):
  59. return token is not None
  60. def post(tlist, pidx, tidx, nidx):
  61. return pidx, nidx
  62. valid_prev = valid_next = valid
  63. _group(tlist, sql.Identifier, match, valid_prev, valid_next, post)
  64. def group_period(tlist):
  65. def match(token):
  66. return token.match(T.Punctuation, '.')
  67. def valid_prev(token):
  68. sqlcls = sql.SquareBrackets, sql.Identifier
  69. ttypes = T.Name, T.String.Symbol
  70. return imt(token, i=sqlcls, t=ttypes)
  71. def valid_next(token):
  72. # issue261, allow invalid next token
  73. return True
  74. def post(tlist, pidx, tidx, nidx):
  75. # next_ validation is being performed here. issue261
  76. sqlcls = sql.SquareBrackets, sql.Function
  77. ttypes = T.Name, T.String.Symbol, T.Wildcard
  78. next_ = tlist[nidx] if nidx is not None else None
  79. valid_next = imt(next_, i=sqlcls, t=ttypes)
  80. return (pidx, nidx) if valid_next else (pidx, tidx)
  81. _group(tlist, sql.Identifier, match, valid_prev, valid_next, post)
  82. def group_as(tlist):
  83. def match(token):
  84. return token.is_keyword and token.normalized == 'AS'
  85. def valid_prev(token):
  86. return token.normalized == 'NULL' or not token.is_keyword
  87. def valid_next(token):
  88. ttypes = T.DML, T.DDL
  89. return not imt(token, t=ttypes) and token is not None
  90. def post(tlist, pidx, tidx, nidx):
  91. return pidx, nidx
  92. _group(tlist, sql.Identifier, match, valid_prev, valid_next, post)
  93. def group_assignment(tlist):
  94. def match(token):
  95. return token.match(T.Assignment, ':=')
  96. def valid(token):
  97. return token is not None and token.ttype not in (T.Keyword)
  98. def post(tlist, pidx, tidx, nidx):
  99. m_semicolon = T.Punctuation, ';'
  100. snidx, _ = tlist.token_next_by(m=m_semicolon, idx=nidx)
  101. nidx = snidx or nidx
  102. return pidx, nidx
  103. valid_prev = valid_next = valid
  104. _group(tlist, sql.Assignment, match, valid_prev, valid_next, post)
  105. def group_comparison(tlist):
  106. sqlcls = (sql.Parenthesis, sql.Function, sql.Identifier,
  107. sql.Operation)
  108. ttypes = T_NUMERICAL + T_STRING + T_NAME
  109. def match(token):
  110. return token.ttype == T.Operator.Comparison
  111. def valid(token):
  112. if imt(token, t=ttypes, i=sqlcls):
  113. return True
  114. elif token and token.is_keyword and token.normalized == 'NULL':
  115. return True
  116. else:
  117. return False
  118. def post(tlist, pidx, tidx, nidx):
  119. return pidx, nidx
  120. valid_prev = valid_next = valid
  121. _group(tlist, sql.Comparison, match,
  122. valid_prev, valid_next, post, extend=False)
  123. @recurse(sql.Identifier)
  124. def group_identifier(tlist):
  125. ttypes = (T.String.Symbol, T.Name)
  126. tidx, token = tlist.token_next_by(t=ttypes)
  127. while token:
  128. tlist.group_tokens(sql.Identifier, tidx, tidx)
  129. tidx, token = tlist.token_next_by(t=ttypes, idx=tidx)
  130. def group_arrays(tlist):
  131. sqlcls = sql.SquareBrackets, sql.Identifier, sql.Function
  132. ttypes = T.Name, T.String.Symbol
  133. def match(token):
  134. return isinstance(token, sql.SquareBrackets)
  135. def valid_prev(token):
  136. return imt(token, i=sqlcls, t=ttypes)
  137. def valid_next(token):
  138. return True
  139. def post(tlist, pidx, tidx, nidx):
  140. return pidx, tidx
  141. _group(tlist, sql.Identifier, match,
  142. valid_prev, valid_next, post, extend=True, recurse=False)
  143. def group_operator(tlist):
  144. ttypes = T_NUMERICAL + T_STRING + T_NAME
  145. sqlcls = (sql.SquareBrackets, sql.Parenthesis, sql.Function,
  146. sql.Identifier, sql.Operation)
  147. def match(token):
  148. return imt(token, t=(T.Operator, T.Wildcard))
  149. def valid(token):
  150. return imt(token, i=sqlcls, t=ttypes)
  151. def post(tlist, pidx, tidx, nidx):
  152. tlist[tidx].ttype = T.Operator
  153. return pidx, nidx
  154. valid_prev = valid_next = valid
  155. _group(tlist, sql.Operation, match,
  156. valid_prev, valid_next, post, extend=False)
  157. def group_identifier_list(tlist):
  158. m_role = T.Keyword, ('null', 'role')
  159. sqlcls = (sql.Function, sql.Case, sql.Identifier, sql.Comparison,
  160. sql.IdentifierList, sql.Operation)
  161. ttypes = (T_NUMERICAL + T_STRING + T_NAME
  162. + (T.Keyword, T.Comment, T.Wildcard))
  163. def match(token):
  164. return token.match(T.Punctuation, ',')
  165. def valid(token):
  166. return imt(token, i=sqlcls, m=m_role, t=ttypes)
  167. def post(tlist, pidx, tidx, nidx):
  168. return pidx, nidx
  169. valid_prev = valid_next = valid
  170. _group(tlist, sql.IdentifierList, match,
  171. valid_prev, valid_next, post, extend=True)
@recurse(sql.Comment)
def group_comments(tlist):
    """Group adjacent comment tokens into Comment nodes."""
    tidx, token = tlist.token_next_by(t=T.Comment)
    while token:
        # find the first token after tidx that is neither comment nor
        # whitespace; the comment run ends before it
        eidx, end = tlist.token_not_matching(
            lambda tk: imt(tk, t=T.Comment) or tk.is_whitespace, idx=tidx)
        if end is not None:
            # step back one token (whitespace included, skip_ws=False)
            # to get the last index belonging to the run
            eidx, end = tlist.token_prev(eidx, skip_ws=False)
            tlist.group_tokens(sql.Comment, tidx, eidx)

        tidx, token = tlist.token_next_by(t=T.Comment, idx=tidx)
@recurse(sql.Where)
def group_where(tlist):
    """Group a WHERE clause and its condition tokens into a Where node."""
    tidx, token = tlist.token_next_by(m=sql.Where.M_OPEN)
    while token:
        # the clause runs until the next token matching Where.M_CLOSE
        eidx, end = tlist.token_next_by(m=sql.Where.M_CLOSE, idx=tidx)

        if end is None:
            # no terminator: clause extends to the last groupable token
            end = tlist._groupable_tokens[-1]
        else:
            # stop just before the closing keyword
            end = tlist.tokens[eidx - 1]
        # TODO: convert this to eidx instead of end token.
        # i think above values are len(tlist) and eidx-1
        eidx = tlist.token_index(end)
        tlist.group_tokens(sql.Where, tidx, eidx)
        tidx, token = tlist.token_next_by(m=sql.Where.M_OPEN, idx=tidx)
  196. @recurse()
  197. def group_aliased(tlist):
  198. I_ALIAS = (sql.Parenthesis, sql.Function, sql.Case, sql.Identifier,
  199. sql.Operation, sql.Comparison)
  200. tidx, token = tlist.token_next_by(i=I_ALIAS, t=T.Number)
  201. while token:
  202. nidx, next_ = tlist.token_next(tidx)
  203. if isinstance(next_, sql.Identifier):
  204. tlist.group_tokens(sql.Identifier, tidx, nidx, extend=True)
  205. tidx, token = tlist.token_next_by(i=I_ALIAS, t=T.Number, idx=tidx)
  206. @recurse(sql.Function)
  207. def group_functions(tlist):
  208. has_create = False
  209. has_table = False
  210. for tmp_token in tlist.tokens:
  211. if tmp_token.value == 'CREATE':
  212. has_create = True
  213. if tmp_token.value == 'TABLE':
  214. has_table = True
  215. if has_create and has_table:
  216. return
  217. tidx, token = tlist.token_next_by(t=T.Name)
  218. while token:
  219. nidx, next_ = tlist.token_next(tidx)
  220. if isinstance(next_, sql.Parenthesis):
  221. tlist.group_tokens(sql.Function, tidx, nidx)
  222. tidx, token = tlist.token_next_by(t=T.Name, idx=tidx)
  223. def group_order(tlist):
  224. """Group together Identifier and Asc/Desc token"""
  225. tidx, token = tlist.token_next_by(t=T.Keyword.Order)
  226. while token:
  227. pidx, prev_ = tlist.token_prev(tidx)
  228. if imt(prev_, i=sql.Identifier, t=T.Number):
  229. tlist.group_tokens(sql.Identifier, pidx, tidx)
  230. tidx = pidx
  231. tidx, token = tlist.token_next_by(t=T.Keyword.Order, idx=tidx)
  232. @recurse()
  233. def align_comments(tlist):
  234. tidx, token = tlist.token_next_by(i=sql.Comment)
  235. while token:
  236. pidx, prev_ = tlist.token_prev(tidx)
  237. if isinstance(prev_, sql.TokenList):
  238. tlist.group_tokens(sql.TokenList, pidx, tidx, extend=True)
  239. tidx = pidx
  240. tidx, token = tlist.token_next_by(i=sql.Comment, idx=tidx)
  241. def group_values(tlist):
  242. tidx, token = tlist.token_next_by(m=(T.Keyword, 'VALUES'))
  243. start_idx = tidx
  244. end_idx = -1
  245. while token:
  246. if isinstance(token, sql.Parenthesis):
  247. end_idx = tidx
  248. tidx, token = tlist.token_next(tidx)
  249. if end_idx != -1:
  250. tlist.group_tokens(sql.Values, start_idx, end_idx, extend=True)
  251. def group(stmt):
  252. for func in [
  253. group_comments,
  254. # _group_matching
  255. group_brackets,
  256. group_parenthesis,
  257. group_case,
  258. group_if,
  259. group_for,
  260. group_begin,
  261. group_functions,
  262. group_where,
  263. group_period,
  264. group_arrays,
  265. group_identifier,
  266. group_order,
  267. group_typecasts,
  268. group_operator,
  269. group_comparison,
  270. group_as,
  271. group_aliased,
  272. group_assignment,
  273. align_comments,
  274. group_identifier_list,
  275. group_values,
  276. ]:
  277. func(stmt)
  278. return stmt
def _group(tlist, cls, match,
           valid_prev=lambda t: True,
           valid_next=lambda t: True,
           post=None,
           extend=True,
           recurse=True
           ):
    """Groups together tokens that are joined by a middle token. i.e. x < y

    ``match(token)`` picks the middle token; ``valid_prev``/``valid_next``
    validate the tokens on either side of it; ``post(tlist, pidx, tidx,
    nidx)`` returns the (from_idx, to_idx) span that gets grouped into
    ``cls``.  ``recurse`` descends into sub-groups that are not already
    instances of ``cls``.
    """
    tidx_offset = 0
    pidx, prev_ = None, None
    for idx, token in enumerate(list(tlist)):
        # translate snapshot index into the live (mutated) list
        tidx = idx - tidx_offset
        if token.is_whitespace:
            continue

        if recurse and token.is_group and not isinstance(token, cls):
            _group(token, cls, match, valid_prev, valid_next, post, extend)

        if match(token):
            nidx, next_ = tlist.token_next(tidx)
            if prev_ and valid_prev(prev_) and valid_next(next_):
                from_idx, to_idx = post(tlist, pidx, tidx, nidx)
                grp = tlist.group_tokens(cls, from_idx, to_idx, extend=extend)

                # account for the tokens collapsed into ``grp``
                tidx_offset += to_idx - from_idx
                pidx, prev_ = from_idx, grp
                continue

        pidx, prev_ = tidx, token