Development of an internal social media platform with personalised dashboards for students
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

unittest_transforms.py 7.9KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235
  1. # Copyright (c) 2015-2016 Cara Vinson <ceridwenv@gmail.com>
  2. # Copyright (c) 2015-2016 Claudiu Popa <pcmanticore@gmail.com>
  3. # Licensed under the LGPL: https://www.gnu.org/licenses/old-licenses/lgpl-2.1.en.html
  4. # For details: https://github.com/PyCQA/astroid/blob/master/COPYING.LESSER
  5. from __future__ import print_function
  6. import contextlib
  7. import time
  8. import unittest
  9. from astroid import builder
  10. from astroid import nodes
  11. from astroid import parse
  12. from astroid import transforms
  13. @contextlib.contextmanager
  14. def add_transform(manager, node, transform, predicate=None):
  15. manager.register_transform(node, transform, predicate)
  16. try:
  17. yield
  18. finally:
  19. manager.unregister_transform(node, transform, predicate)
  20. class TestTransforms(unittest.TestCase):
  21. def setUp(self):
  22. self.transformer = transforms.TransformVisitor()
  23. def parse_transform(self, code):
  24. module = parse(code, apply_transforms=False)
  25. return self.transformer.visit(module)
  26. def test_function_inlining_transform(self):
  27. def transform_call(node):
  28. # Let's do some function inlining
  29. inferred = next(node.infer())
  30. return inferred
  31. self.transformer.register_transform(nodes.Call,
  32. transform_call)
  33. module = self.parse_transform('''
  34. def test(): return 42
  35. test() #@
  36. ''')
  37. self.assertIsInstance(module.body[1], nodes.Expr)
  38. self.assertIsInstance(module.body[1].value, nodes.Const)
  39. self.assertEqual(module.body[1].value.value, 42)
  40. def test_recursive_transforms_into_astroid_fields(self):
  41. # Test that the transformer walks properly the tree
  42. # by going recursively into the _astroid_fields per each node.
  43. def transform_compare(node):
  44. # Let's check the values of the ops
  45. _, right = node.ops[0]
  46. # Assume they are Consts and they were transformed before
  47. # us.
  48. return nodes.const_factory(node.left.value < right.value)
  49. def transform_name(node):
  50. # Should be Consts
  51. return next(node.infer())
  52. self.transformer.register_transform(nodes.Compare, transform_compare)
  53. self.transformer.register_transform(nodes.Name, transform_name)
  54. module = self.parse_transform('''
  55. a = 42
  56. b = 24
  57. a < b
  58. ''')
  59. self.assertIsInstance(module.body[2], nodes.Expr)
  60. self.assertIsInstance(module.body[2].value, nodes.Const)
  61. self.assertFalse(module.body[2].value.value)
  62. def test_transform_patches_locals(self):
  63. def transform_function(node):
  64. assign = nodes.Assign()
  65. name = nodes.AssignName()
  66. name.name = 'value'
  67. assign.targets = [name]
  68. assign.value = nodes.const_factory(42)
  69. node.body.append(assign)
  70. self.transformer.register_transform(nodes.FunctionDef,
  71. transform_function)
  72. module = self.parse_transform('''
  73. def test():
  74. pass
  75. ''')
  76. func = module.body[0]
  77. self.assertEqual(len(func.body), 2)
  78. self.assertIsInstance(func.body[1], nodes.Assign)
  79. self.assertEqual(func.body[1].as_string(), 'value = 42')
  80. def test_predicates(self):
  81. def transform_call(node):
  82. inferred = next(node.infer())
  83. return inferred
  84. def should_inline(node):
  85. return node.func.name.startswith('inlineme')
  86. self.transformer.register_transform(nodes.Call,
  87. transform_call,
  88. should_inline)
  89. module = self.parse_transform('''
  90. def inlineme_1():
  91. return 24
  92. def dont_inline_me():
  93. return 42
  94. def inlineme_2():
  95. return 2
  96. inlineme_1()
  97. dont_inline_me()
  98. inlineme_2()
  99. ''')
  100. values = module.body[-3:]
  101. self.assertIsInstance(values[0], nodes.Expr)
  102. self.assertIsInstance(values[0].value, nodes.Const)
  103. self.assertEqual(values[0].value.value, 24)
  104. self.assertIsInstance(values[1], nodes.Expr)
  105. self.assertIsInstance(values[1].value, nodes.Call)
  106. self.assertIsInstance(values[2], nodes.Expr)
  107. self.assertIsInstance(values[2].value, nodes.Const)
  108. self.assertEqual(values[2].value.value, 2)
  109. def test_transforms_are_separated(self):
  110. # Test that the transforming is done at a separate
  111. # step, which means that we are not doing inference
  112. # on a partially constructed tree anymore, which was the
  113. # source of crashes in the past when certain inference rules
  114. # were used in a transform.
  115. def transform_function(node):
  116. if node.decorators:
  117. for decorator in node.decorators.nodes:
  118. inferred = next(decorator.infer())
  119. if inferred.qname() == 'abc.abstractmethod':
  120. return next(node.infer_call_result(node))
  121. return None
  122. manager = builder.MANAGER
  123. with add_transform(manager, nodes.FunctionDef, transform_function):
  124. module = builder.parse('''
  125. import abc
  126. from abc import abstractmethod
  127. class A(object):
  128. @abc.abstractmethod
  129. def ala(self):
  130. return 24
  131. @abstractmethod
  132. def bala(self):
  133. return 42
  134. ''')
  135. cls = module['A']
  136. ala = cls.body[0]
  137. bala = cls.body[1]
  138. self.assertIsInstance(ala, nodes.Const)
  139. self.assertEqual(ala.value, 24)
  140. self.assertIsInstance(bala, nodes.Const)
  141. self.assertEqual(bala.value, 42)
  142. def test_transforms_are_called_for_builtin_modules(self):
  143. # Test that transforms are called for builtin modules.
  144. def transform_function(node):
  145. name = nodes.AssignName()
  146. name.name = 'value'
  147. node.args.args = [name]
  148. return node
  149. manager = builder.MANAGER
  150. predicate = lambda node: node.root().name == 'time'
  151. with add_transform(manager, nodes.FunctionDef,
  152. transform_function, predicate):
  153. builder_instance = builder.AstroidBuilder()
  154. module = builder_instance.module_build(time)
  155. asctime = module['asctime']
  156. self.assertEqual(len(asctime.args.args), 1)
  157. self.assertIsInstance(asctime.args.args[0], nodes.AssignName)
  158. self.assertEqual(asctime.args.args[0].name, 'value')
  159. def test_builder_apply_transforms(self):
  160. def transform_function(node):
  161. return nodes.const_factory(42)
  162. manager = builder.MANAGER
  163. with add_transform(manager, nodes.FunctionDef, transform_function):
  164. astroid_builder = builder.AstroidBuilder(apply_transforms=False)
  165. module = astroid_builder.string_build('''def test(): pass''')
  166. # The transform wasn't applied.
  167. self.assertIsInstance(module.body[0], nodes.FunctionDef)
  168. def test_transform_crashes_on_is_subtype_of(self):
  169. # Test that we don't crash when having is_subtype_of
  170. # in a transform, as per issue #188. This happened
  171. # before, when the transforms weren't in their own step.
  172. def transform_class(cls):
  173. if cls.is_subtype_of('django.db.models.base.Model'):
  174. return cls
  175. return cls
  176. self.transformer.register_transform(nodes.ClassDef,
  177. transform_class)
  178. self.parse_transform('''
  179. # Change environ to automatically call putenv() if it exists
  180. import os
  181. putenv = os.putenv
  182. try:
  183. # This will fail if there's no putenv
  184. putenv
  185. except NameError:
  186. pass
  187. else:
  188. import UserDict
  189. ''')
  190. if __name__ == '__main__':
  191. unittest.main()