# astar.py
#
# A* search attached to the graph classes, plus a demo (run as a script) that
# solves a text-file labyrinth and prints it with the explored cells and the
# found path marked.
import heapq
import math
import re
import sys
from typing import Callable

from graph import Graph, AdjacencyListGraph, AdjacencyMatrixGraph, NodeColor, Vertex
  5. def a_star(self, start_name: str, end_name: str, heuristic: Callable[[Vertex], float]):
  6. color_map = {} # maps vertices to their color
  7. distance_map = {} # maps vertices to their distance from the start vertex
  8. predecessor_map = {} # maps vertices to their predecessor in the traversal tree
  9. def cost(vertex: Vertex) -> float:
  10. """Compute the cost of the path to the given vertex."""
  11. if distance_map[vertex] is None:
  12. return math.inf
  13. return distance_map[vertex] + heuristic(vertex)
  14. # Initialize the maps
  15. for vertex in self.all_vertices():
  16. color_map[vertex] = NodeColor.WHITE
  17. distance_map[vertex] = None
  18. predecessor_map[vertex] = None
  19. # Start at the given vertex
  20. start_node = self.get_vertex(start_name)
  21. color_map[start_node] = NodeColor.GRAY
  22. distance_map[start_node] = 0
  23. # Initialize the queue with the start vertex
  24. queue = [start_node]
  25. # Process the queue
  26. while len(queue) > 0:
  27. queue.sort(key=cost)
  28. vertex = queue.pop(0)
  29. if vertex.value == end_name:
  30. # Return the distance and predecessor maps
  31. return distance_map, predecessor_map
  32. for dest, weight in self.get_adjacent_vertices_with_weight(vertex.value):
  33. if color_map[dest] == NodeColor.BLACK:
  34. continue
  35. f = distance_map[vertex] + weight + heuristic(dest)
  36. if color_map[dest] == NodeColor.GRAY and f > cost(dest):
  37. continue
  38. predecessor_map[dest] = vertex
  39. distance_map[dest] = distance_map[vertex] + weight
  40. if color_map[dest] == NodeColor.WHITE:
  41. queue.append(dest)
  42. color_map[dest] = NodeColor.GRAY
  43. color_map[vertex] = NodeColor.BLACK
  44. # Return the distance and predecessor maps if no path was found
  45. return None, None
  46. # Add the a_star method to the Graph classes
  47. AdjacencyListGraph.a_star = a_star
  48. AdjacencyMatrixGraph.a_star = a_star
  49. if __name__ == "__main__":
  50. def read_labyrinth_into_graph(graph: Graph, filename: str):
  51. """Read a labyrinth from a file into a graph. The file format is a grid of characters:"""
  52. start = None
  53. end = None
  54. with open(filename, "r") as file:
  55. nodes = []
  56. lines = file.readlines()
  57. for y, line in enumerate(lines):
  58. for x, char in enumerate(line):
  59. if char in ' AS':
  60. name = pos_to_nodename((x, y))
  61. graph.insert_vertex(name)
  62. nodes.append((x, y))
  63. if char == 'A':
  64. end = (x, y)
  65. if char == 'S':
  66. start = (x, y)
  67. for x, y in nodes:
  68. name1 = f"x{x}y{y}"
  69. for neighbor in [(x - 1, y), (x, y - 1), (x, y + 1), (x + 1, y)]:
  70. if neighbor in nodes:
  71. name_neighbor = pos_to_nodename(neighbor)
  72. graph.connect(name1, name_neighbor, 1)
  73. return start, end, lines
  74. def nodename_to_pos(nodename):
  75. """Convert a node name to a position (x, y)."""
  76. m = re.match(r"(^x(\d*)y(\d*))", nodename)
  77. if m:
  78. return (int(m.group(2)), int(m.group(3)))
  79. return None
  80. def pos_to_nodename(pos):
  81. """Convert a position (x, y) to a node name."""
  82. x, y = pos
  83. return f"x{x}y{y}"
  84. def get_heuristic(end) -> Callable[[Vertex], float]:
  85. """Return a heuristic function for the given end position."""
  86. def heuristic(v: Vertex) -> float:
  87. x, y = nodename_to_pos(v.value)
  88. x1, y1 = end
  89. # Euclidean distance
  90. return math.sqrt(abs(x - x1)**2 + abs(y - y1)**2)
  91. return heuristic
  92. def update_lines(lines, distance_map, path):
  93. """Update the lines with the path found by the A* algorithm."""
  94. def replace_at(x, y, replacement):
  95. if lines[y][x] not in " .":
  96. return
  97. lines[y] = lines[y][:x] + replacement + lines[y][x+1:]
  98. for node in distance_map.keys():
  99. if distance_map[node] is not None:
  100. x, y = nodename_to_pos(node.value)
  101. replace_at(x, y, ".")
  102. for node in path:
  103. x, y = nodename_to_pos(node)
  104. replace_at(x, y, "*")
  105. return lines
  106. graph = AdjacencyListGraph()
  107. #graph = AdjacencyMatrixGraph()
  108. start, end, lines = read_labyrinth_into_graph(graph, "../../labyrinth.txt")
  109. distance_map, predecessor_map = graph.a_star(pos_to_nodename(start), pos_to_nodename(end), get_heuristic(end))
  110. endname = pos_to_nodename(end)
  111. lines = update_lines(lines, distance_map, graph.path(endname, predecessor_map))
  112. for line in lines:
  113. print(line, end="")