Development of an internal social media platform with personalised dashboards for students
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

unittest_protocols.py 7.4KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206
  1. # Copyright (c) 2015-2016 Claudiu Popa <pcmanticore@gmail.com>
  2. # Licensed under the LGPL: https://www.gnu.org/licenses/old-licenses/lgpl-2.1.en.html
  3. # For details: https://github.com/PyCQA/astroid/blob/master/COPYING.LESSER
  4. import contextlib
  5. import unittest
  6. import astroid
  7. from astroid import extract_node
  8. from astroid.test_utils import require_version
  9. from astroid import InferenceError
  10. from astroid import nodes
  11. from astroid import util
  12. from astroid.node_classes import AssignName, Const, Name, Starred
  13. @contextlib.contextmanager
  14. def _add_transform(manager, node, transform, predicate=None):
  15. manager.register_transform(node, transform, predicate)
  16. try:
  17. yield
  18. finally:
  19. manager.unregister_transform(node, transform, predicate)
  20. class ProtocolTests(unittest.TestCase):
  21. def assertConstNodesEqual(self, nodes_list_expected, nodes_list_got):
  22. self.assertEqual(len(nodes_list_expected), len(nodes_list_got))
  23. for node in nodes_list_got:
  24. self.assertIsInstance(node, Const)
  25. for node, expected_value in zip(nodes_list_got, nodes_list_expected):
  26. self.assertEqual(expected_value, node.value)
  27. def assertNameNodesEqual(self, nodes_list_expected, nodes_list_got):
  28. self.assertEqual(len(nodes_list_expected), len(nodes_list_got))
  29. for node in nodes_list_got:
  30. self.assertIsInstance(node, Name)
  31. for node, expected_name in zip(nodes_list_got, nodes_list_expected):
  32. self.assertEqual(expected_name, node.name)
  33. def test_assigned_stmts_simple_for(self):
  34. assign_stmts = extract_node("""
  35. for a in (1, 2, 3): #@
  36. pass
  37. for b in range(3): #@
  38. pass
  39. """)
  40. for1_assnode = next(assign_stmts[0].nodes_of_class(AssignName))
  41. assigned = list(for1_assnode.assigned_stmts())
  42. self.assertConstNodesEqual([1, 2, 3], assigned)
  43. for2_assnode = next(assign_stmts[1].nodes_of_class(AssignName))
  44. self.assertRaises(InferenceError,
  45. list, for2_assnode.assigned_stmts())
  46. @require_version(minver='3.0')
  47. def test_assigned_stmts_starred_for(self):
  48. assign_stmts = extract_node("""
  49. for *a, b in ((1, 2, 3), (4, 5, 6, 7)): #@
  50. pass
  51. """)
  52. for1_starred = next(assign_stmts.nodes_of_class(Starred))
  53. assigned = next(for1_starred.assigned_stmts())
  54. self.assertEqual(assigned, util.Uninferable)
  55. def _get_starred_stmts(self, code):
  56. assign_stmt = extract_node("{} #@".format(code))
  57. starred = next(assign_stmt.nodes_of_class(Starred))
  58. return next(starred.assigned_stmts())
  59. def _helper_starred_expected_const(self, code, expected):
  60. stmts = self._get_starred_stmts(code)
  61. self.assertIsInstance(stmts, nodes.List)
  62. stmts = stmts.elts
  63. self.assertConstNodesEqual(expected, stmts)
  64. def _helper_starred_expected(self, code, expected):
  65. stmts = self._get_starred_stmts(code)
  66. self.assertEqual(expected, stmts)
  67. def _helper_starred_inference_error(self, code):
  68. assign_stmt = extract_node("{} #@".format(code))
  69. starred = next(assign_stmt.nodes_of_class(Starred))
  70. self.assertRaises(InferenceError, list, starred.assigned_stmts())
  71. @require_version(minver='3.0')
  72. def test_assigned_stmts_starred_assnames(self):
  73. self._helper_starred_expected_const(
  74. "a, *b = (1, 2, 3, 4) #@", [2, 3, 4])
  75. self._helper_starred_expected_const(
  76. "*a, b = (1, 2, 3) #@", [1, 2])
  77. self._helper_starred_expected_const(
  78. "a, *b, c = (1, 2, 3, 4, 5) #@",
  79. [2, 3, 4])
  80. self._helper_starred_expected_const(
  81. "a, *b = (1, 2) #@", [2])
  82. self._helper_starred_expected_const(
  83. "*b, a = (1, 2) #@", [1])
  84. self._helper_starred_expected_const(
  85. "[*b] = (1, 2) #@", [1, 2])
  86. @require_version(minver='3.0')
  87. def test_assigned_stmts_starred_yes(self):
  88. # Not something iterable and known
  89. self._helper_starred_expected("a, *b = range(3) #@", util.Uninferable)
  90. # Not something inferrable
  91. self._helper_starred_expected("a, *b = balou() #@", util.Uninferable)
  92. # In function, unknown.
  93. self._helper_starred_expected("""
  94. def test(arg):
  95. head, *tail = arg #@""", util.Uninferable)
  96. # These cases aren't worth supporting.
  97. self._helper_starred_expected(
  98. "a, (*b, c), d = (1, (2, 3, 4), 5) #@", util.Uninferable)
  99. @require_version(minver='3.0')
  100. def test_assign_stmts_starred_fails(self):
  101. # Too many starred
  102. self._helper_starred_inference_error("a, *b, *c = (1, 2, 3) #@")
  103. # Too many lhs values
  104. self._helper_starred_inference_error("a, *b, c = (1, 2) #@")
  105. # This could be solved properly, but it complicates needlessly the
  106. # code for assigned_stmts, without offering real benefit.
  107. self._helper_starred_inference_error(
  108. "(*a, b), (c, *d) = (1, 2, 3), (4, 5, 6) #@")
  109. def test_assigned_stmts_assignments(self):
  110. assign_stmts = extract_node("""
  111. c = a #@
  112. d, e = b, c #@
  113. """)
  114. simple_assnode = next(assign_stmts[0].nodes_of_class(AssignName))
  115. assigned = list(simple_assnode.assigned_stmts())
  116. self.assertNameNodesEqual(['a'], assigned)
  117. assnames = assign_stmts[1].nodes_of_class(AssignName)
  118. simple_mul_assnode_1 = next(assnames)
  119. assigned = list(simple_mul_assnode_1.assigned_stmts())
  120. self.assertNameNodesEqual(['b'], assigned)
  121. simple_mul_assnode_2 = next(assnames)
  122. assigned = list(simple_mul_assnode_2.assigned_stmts())
  123. self.assertNameNodesEqual(['c'], assigned)
  124. @require_version(minver='3.6')
  125. def test_assigned_stmts_annassignments(self):
  126. annassign_stmts = extract_node("""
  127. a: str = "abc" #@
  128. b: str #@
  129. """)
  130. simple_annassign_node = next(annassign_stmts[0].nodes_of_class(AssignName))
  131. assigned = list(simple_annassign_node.assigned_stmts())
  132. self.assertEqual(1, len(assigned))
  133. self.assertIsInstance(assigned[0], Const)
  134. self.assertEqual(assigned[0].value, "abc")
  135. empty_annassign_node = next(annassign_stmts[1].nodes_of_class(AssignName))
  136. assigned = list(empty_annassign_node.assigned_stmts())
  137. self.assertEqual(1, len(assigned))
  138. self.assertIs(assigned[0], util.Uninferable)
  139. def test_sequence_assigned_stmts_not_accepting_empty_node(self):
  140. def transform(node):
  141. node.root().locals['__all__'] = [node.value]
  142. manager = astroid.MANAGER
  143. with _add_transform(manager, astroid.Assign, transform):
  144. module = astroid.parse('''
  145. __all__ = ['a']
  146. ''')
  147. module.wildcard_import_names()
  148. def test_not_passing_uninferable_in_seq_inference(self):
  149. class Visitor(object):
  150. def visit(self, node):
  151. for child in node.get_children():
  152. child.accept(self)
  153. visit_module = visit
  154. visit_assign = visit
  155. visit_binop = visit
  156. visit_list = visit
  157. visit_const = visit
  158. visit_name = visit
  159. def visit_assignname(self, node):
  160. for _ in node.infer():
  161. pass
  162. parsed = extract_node("""
  163. a = []
  164. x = [a*2, a]*2*2
  165. """)
  166. parsed.accept(Visitor())
  167. if __name__ == '__main__':
  168. unittest.main()