Development of an internal social media platform with personalised dashboards for students
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

transforms.py 3.2KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384
  1. # Copyright (c) 2015-2016 Claudiu Popa <pcmanticore@gmail.com>
  2. # Licensed under the LGPL: https://www.gnu.org/licenses/old-licenses/lgpl-2.1.en.html
  3. # For details: https://github.com/PyCQA/astroid/blob/master/COPYING.LESSER
  4. import collections
  5. import warnings
  6. class TransformVisitor(object):
  7. """A visitor for handling transforms.
  8. The standard approach of using it is to call
  9. :meth:`~visit` with an *astroid* module and the class
  10. will take care of the rest, walking the tree and running the
  11. transforms for each encountered node.
  12. """
  13. def __init__(self):
  14. self.transforms = collections.defaultdict(list)
  15. def _transform(self, node):
  16. """Call matching transforms for the given node if any and return the
  17. transformed node.
  18. """
  19. cls = node.__class__
  20. if cls not in self.transforms:
  21. # no transform registered for this class of node
  22. return node
  23. transforms = self.transforms[cls]
  24. orig_node = node # copy the reference
  25. for transform_func, predicate in transforms:
  26. if predicate is None or predicate(node):
  27. ret = transform_func(node)
  28. # if the transformation function returns something, it's
  29. # expected to be a replacement for the node
  30. if ret is not None:
  31. if node is not orig_node:
  32. # node has already be modified by some previous
  33. # transformation, warn about it
  34. warnings.warn('node %s substituted multiple times' % node)
  35. node = ret
  36. return node
  37. def _visit(self, node):
  38. if hasattr(node, '_astroid_fields'):
  39. for field in node._astroid_fields:
  40. value = getattr(node, field)
  41. visited = self._visit_generic(value)
  42. setattr(node, field, visited)
  43. return self._transform(node)
  44. def _visit_generic(self, node):
  45. if isinstance(node, list):
  46. return [self._visit_generic(child) for child in node]
  47. elif isinstance(node, tuple):
  48. return tuple(self._visit_generic(child) for child in node)
  49. return self._visit(node)
  50. def register_transform(self, node_class, transform, predicate=None):
  51. """Register `transform(node)` function to be applied on the given
  52. astroid's `node_class` if `predicate` is None or returns true
  53. when called with the node as argument.
  54. The transform function may return a value which is then used to
  55. substitute the original node in the tree.
  56. """
  57. self.transforms[node_class].append((transform, predicate))
  58. def unregister_transform(self, node_class, transform, predicate=None):
  59. """Unregister the given transform."""
  60. self.transforms[node_class].remove((transform, predicate))
  61. def visit(self, module):
  62. """Walk the given astroid *tree* and transform each encountered node
  63. Only the nodes which have transforms registered will actually
  64. be replaced or changed.
  65. """
  66. module.body = [self._visit(child) for child in module.body]
  67. return self._transform(module)