You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

astar.py 5.2KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141
import heapq
import itertools
import math
import re
import sys
from typing import Callable

from graph import AdjacencyListGraph, AdjacencyMatrixGraph, Graph, NodeColor, Vertex
  6. def a_star(self, start_name: str, end_name: str, heuristic: Callable[[Vertex], float]):
  7. color_map = {} # maps vertices to their color
  8. distance_map = {} # maps vertices to their distance from the start vertex
  9. predecessor_map = {} # maps vertices to their predecessor in the traversal tree
  10. def cost(vertex: Vertex) -> float:
  11. """Compute the cost of the path to the given vertex."""
  12. if distance_map[vertex] is None:
  13. return math.inf
  14. return distance_map[vertex] + heuristic(vertex)
  15. Vertex.__lt__ = lambda self, other: cost(self) < cost(other)
  16. # Initialize the maps
  17. for vertex in self.all_vertices():
  18. color_map[vertex] = NodeColor.WHITE
  19. distance_map[vertex] = None
  20. predecessor_map[vertex] = None
  21. # Start at the given vertex
  22. start_node = self.get_vertex(start_name)
  23. color_map[start_node] = NodeColor.GRAY
  24. distance_map[start_node] = 0
  25. # Initialize the queue with the start vertex
  26. queue = [] # Use a list instead of a deque, because we need to sort it with
  27. heapq.heappush(queue, start_node)
  28. # Process the queue
  29. while len(queue) > 0:
  30. vertex = heapq.heappop(queue)
  31. if color_map[vertex] == NodeColor.BLACK:
  32. # Skip already processed vertices (possibly with a higher cost)
  33. continue
  34. if vertex.value == end_name:
  35. # Return the distance and predecessor maps
  36. return distance_map, predecessor_map
  37. for dest, weight in self.get_adjacent_vertices_with_weight(vertex.value):
  38. if color_map[dest] == NodeColor.BLACK:
  39. continue
  40. f = distance_map[vertex] + weight + heuristic(dest)
  41. if color_map[dest] == NodeColor.GRAY and f > cost(dest):
  42. continue
  43. predecessor_map[dest] = vertex
  44. distance_map[dest] = distance_map[vertex] + weight
  45. heapq.heappush(queue, dest)
  46. if color_map[dest] == NodeColor.WHITE:
  47. color_map[dest] = NodeColor.GRAY
  48. color_map[vertex] = NodeColor.BLACK
  49. # Return the distance and predecessor maps if no path was found
  50. return None, None
  51. # Add the a_star method to the Graph classes
  52. AdjacencyListGraph.a_star = a_star
  53. AdjacencyMatrixGraph.a_star = a_star
  54. if __name__ == "__main__":
  55. def read_labyrinth_into_graph(graph: Graph, filename: str):
  56. """Read a labyrinth from a file into a graph. The file format is a grid of characters:"""
  57. start = None
  58. end = None
  59. with open(filename, "r") as file:
  60. nodes = []
  61. lines = file.readlines()
  62. for y, line in enumerate(lines):
  63. for x, char in enumerate(line):
  64. if char in ' AS':
  65. name = pos_to_nodename((x, y))
  66. graph.insert_vertex(name)
  67. nodes.append((x, y))
  68. if char == 'A':
  69. end = (x, y)
  70. if char == 'S':
  71. start = (x, y)
  72. for x, y in nodes:
  73. name1 = f"x{x}y{y}"
  74. for neighbor in [(x - 1, y), (x, y - 1), (x, y + 1), (x + 1, y)]:
  75. if neighbor in nodes:
  76. name_neighbor = pos_to_nodename(neighbor)
  77. graph.connect(name1, name_neighbor, 1)
  78. return start, end, lines
  79. def nodename_to_pos(nodename):
  80. """Convert a node name to a position (x, y)."""
  81. m = re.match(r"(^x(\d*)y(\d*))", nodename)
  82. if m:
  83. return (int(m.group(2)), int(m.group(3)))
  84. return None
  85. def pos_to_nodename(pos):
  86. """Convert a position (x, y) to a node name."""
  87. x, y = pos
  88. return f"x{x}y{y}"
  89. def get_heuristic(end) -> Callable[[Vertex], float]:
  90. """Return a heuristic function for the given end position."""
  91. def heuristic(v: Vertex) -> float:
  92. x, y = nodename_to_pos(v.value)
  93. x1, y1 = end
  94. # Euclidean distance
  95. return math.sqrt(abs(x - x1)**2 + abs(y - y1)**2)
  96. return heuristic
  97. def update_lines(lines, distance_map, path):
  98. """Update the lines with the path found by the A* algorithm."""
  99. def replace_at(x, y, replacement):
  100. if lines[y][x] not in " .":
  101. return
  102. lines[y] = lines[y][:x] + replacement + lines[y][x+1:]
  103. for node in distance_map.keys():
  104. if distance_map[node] is not None:
  105. x, y = nodename_to_pos(node.value)
  106. replace_at(x, y, ".")
  107. for node in path:
  108. x, y = nodename_to_pos(node)
  109. replace_at(x, y, "*")
  110. return lines
  111. graph = AdjacencyListGraph()
  112. #graph = AdjacencyMatrixGraph()
  113. start, end, lines = read_labyrinth_into_graph(graph, "../../labyrinth.txt")
  114. distance_map, predecessor_map = graph.a_star(pos_to_nodename(start), pos_to_nodename(end), get_heuristic(end))
  115. endname = pos_to_nodename(end)
  116. lines = update_lines(lines, distance_map, graph.path(endname, predecessor_map))
  117. for line in lines:
  118. print(line, end="")