Development of an internal social media platform with personalised dashboards for students
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

brain_random.py 2.6KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192939495969798
  1. # Licensed under the LGPL: https://www.gnu.org/licenses/old-licenses/lgpl-2.1.en.html
  2. # For details: https://github.com/PyCQA/astroid/blob/master/COPYING.LESSER
  3. import random
  4. import astroid
  5. from astroid import helpers
  6. from astroid import MANAGER
  7. ACCEPTED_ITERABLES_FOR_SAMPLE = (
  8. astroid.List,
  9. astroid.Set,
  10. astroid.Tuple,
  11. )
  12. def _clone_node_with_lineno(node, parent, lineno):
  13. cls = node.__class__
  14. other_fields = node._other_fields
  15. _astroid_fields = node._astroid_fields
  16. init_params = {
  17. 'lineno': lineno,
  18. 'col_offset': node.col_offset,
  19. 'parent': parent
  20. }
  21. postinit_params = {
  22. param: getattr(node, param)
  23. for param in _astroid_fields
  24. }
  25. if other_fields:
  26. init_params.update({
  27. param: getattr(node, param)
  28. for param in other_fields
  29. })
  30. new_node = cls(**init_params)
  31. if hasattr(node, 'postinit') and _astroid_fields:
  32. new_node.postinit(**postinit_params)
  33. return new_node
  34. def infer_random_sample(node, context=None):
  35. if len(node.args) != 2:
  36. raise astroid.UseInferenceDefault
  37. length = node.args[1]
  38. if not isinstance(length, astroid.Const):
  39. raise astroid.UseInferenceDefault
  40. if not isinstance(length.value, int):
  41. raise astroid.UseInferenceDefault
  42. inferred_sequence = helpers.safe_infer(node.args[0], context=context)
  43. if inferred_sequence in (None, astroid.Uninferable):
  44. raise astroid.UseInferenceDefault
  45. # TODO: might need to support more cases
  46. if not isinstance(inferred_sequence, ACCEPTED_ITERABLES_FOR_SAMPLE):
  47. raise astroid.UseInferenceDefault
  48. if length.value > len(inferred_sequence.elts):
  49. # In this case, this will raise a ValueError
  50. raise astroid.UseInferenceDefault
  51. try:
  52. elts = random.sample(inferred_sequence.elts, length.value)
  53. except ValueError:
  54. raise astroid.UseInferenceDefault
  55. new_node = astroid.List(
  56. lineno=node.lineno,
  57. col_offset=node.col_offset,
  58. parent=node.scope(),
  59. )
  60. new_elts = [
  61. _clone_node_with_lineno(
  62. elt,
  63. parent=new_node,
  64. lineno=new_node.lineno
  65. )
  66. for elt in elts
  67. ]
  68. new_node.postinit(new_elts)
  69. return iter((new_node, ))
  70. def _looks_like_random_sample(node):
  71. func = node.func
  72. if isinstance(func, astroid.Attribute):
  73. return func.attrname == 'sample'
  74. if isinstance(func, astroid.Name):
  75. return func.name == 'sample'
  76. return False
  77. MANAGER.register_transform(
  78. astroid.Call,
  79. astroid.inference_tip(infer_random_sample),
  80. _looks_like_random_sample,
  81. )