You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

plot_precision_tests.py 23KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527
  1. from localization import Localization
  2. from plot_results import Plotting
  3. from ART import ART_plotting
  4. import matplotlib.pyplot as plt
  5. import numpy as np
  6. import re
  7. import math
  8. mm = 1000 # millimetre
  9. localization = Localization()
  10. plotting = Plotting()
  11. fig = plt.figure(figsize=[24, 9])
  12. xy_subplot = plt.subplot2grid((1, 2), (0, 1))
  13. yz_subplot = plt.subplot2grid((1, 2), (0, 0))
  14. # Load all the IndLoc Recordings
  15. center = np.load("recorded_data/center.npy")[1:, :, :]
  16. off_center = np.load("recorded_data/off_center.npy")[1:, :, :]
  17. horizontal = np.load("recorded_data/horizontal.npy")[1:, :, :]
  18. diagonal = np.load("recorded_data/diagonal.npy")[1:, :, :]
  19. zigzag_1 = np.load("recorded_data/zigzag_1.npy")
  20. zigzag_2 = np.load("recorded_data/zigzag_2.npy")
  21. wave_1 = np.load("recorded_data/wave_1.npy")
  22. wave_2 = np.load("recorded_data/wave_2.npy")
  23. spiral_1 = np.load("recorded_data/spiral_1.npy")
  24. half_current_horizontal = np.load("recorded_data/half_current_horizontal.npy")[1:, :, :]
  25. half_current_diagonal = np.load("recorded_data/half_current_diagonal.npy")[1:, :, :]
  26. double_current_horizontal = np.load("recorded_data/double_current_horizontal.npy")[1:, :, :]
  27. double_current_diagonal = np.load("recorded_data/half_current_diagonal.npy")[1:, :, :]
  28. z_axis = np.load("recorded_data/z_axis.npy")[1:, :, :]
  29. z_axis_off_center = np.load("recorded_data/z_axis_off_center.npy")[1:, :, :]
  30. Three_D_movement_1 = np.load("recorded_data/3D_movement_1.npy")
  31. Three_D_movement_2 = np.load("recorded_data/3D_movement_2.npy")
  32. Three_D_zigzag_1 = np.load("recorded_data/3D_zigzag_1.npy")
  33. Three_D_zigzag_2 = np.load("recorded_data/3D_zigzag_2.npy")
  34. z_axis_moving_1 = np.load("recorded_data/z_axis_moving_1.npy")
  35. # z_axis_moving_2 = np.load("recorded_data/z_axis_moving_2.npy") # Recording missing?
  36. z_wave_1 = np.load("recorded_data/z_wave_1.npy")
  37. z_wave_2 = np.load("recorded_data/z_wave_2.npy")
  38. recordings_dict = {"center": center, # shape: (26, 2000, 19) = [P1-P25, 2k samples, fv]
  39. "off_center": off_center,
  40. "horizontal": horizontal,
  41. "diagonal": diagonal,
  42. "horizontal_and_diagonal": np.append(horizontal, diagonal, axis=0),
  43. "zigzag_1": zigzag_1,
  44. "zigzag_2": zigzag_2,
  45. "wave_1": wave_1,
  46. "wave_2": wave_2,
  47. "spiral_1": spiral_1,
  48. "half_current_horizontal": half_current_horizontal,
  49. "half_current_diagonal": half_current_diagonal,
  50. "double_current_horizontal": double_current_horizontal,
  51. "double_current_diagonal": double_current_diagonal,
  52. "z_axis": z_axis,
  53. "z_axis_off_center": z_axis_off_center,
  54. "Three_D_movement_1": Three_D_movement_1,
  55. "Three_D_movement_2": Three_D_movement_2,
  56. "Three_D_zigzag_1": Three_D_zigzag_1,
  57. "Three_D_zigzag_2": Three_D_zigzag_2,
  58. "z_axis_moving_1": z_axis_moving_1,
  59. "z_wave_1": z_wave_1,
  60. "z_wave_2": z_wave_2}
  61. # this only takes the first sample of each recording
  62. recording_dict_fs = {"center": center[:, 0, :],
  63. "off_center": off_center[:, 0, :],
  64. "horizontal": horizontal[:, 0, :],
  65. "diagonal": diagonal[:, 0, :],
  66. "horizontal_and_diagonal": recordings_dict["horizontal_and_diagonal"][:, 0, :],
  67. "zigzag_1": zigzag_1[0, :],
  68. "zigzag_2": zigzag_2[0, :],
  69. "wave_1": wave_1[0, :],
  70. "wave_2": wave_2[0, :],
  71. "spiral_1": spiral_1[0, :],
  72. "half_current_horizontal": half_current_horizontal[:, 0, :],
  73. "half_current_diagonal": half_current_diagonal[:, 0, :],
  74. "double_current_horizontal": double_current_horizontal[:, 0, :],
  75. "double_current_diagonal": double_current_diagonal[:, 0, :],
  76. "z_axis": z_axis[:, 0, :],
  77. "z_axis_off_center": z_axis_off_center[:, 0, :],
  78. "Three_D_movement_1": Three_D_movement_1[0, :],
  79. "Three_D_movement_2": Three_D_movement_2[0, :],
  80. "Three_D_zigzag_1": Three_D_zigzag_1[0, :],
  81. "Three_D_zigzag_2": Three_D_zigzag_2[0, :],
  82. "z_axis_moving_1": z_axis_moving_1[0, :],
  83. "z_wave_1": z_wave_1[0, :],
  84. "z_wave_2": z_wave_2[0, :]}
  85. def plot_surroundings():
  86. font = {"family": "serif", "color": "black", "size": 15}
  87. plt.subplots_adjust(left=0.07, bottom=0.07, right=0.89, top=0.89, wspace=0.2, hspace=0.2)
  88. # Set Axis
  89. xy_subplot.set_xlim([-0.02 * mm, 0.52 * mm])
  90. xy_subplot.set_ylim([-0.02 * mm, 0.52 * mm])
  91. # xy_subplot.set_xticks(np.arange(0.0*mm, 0.55*mm, 0.05*mm))
  92. # xy_subplot.set_yticks(np.arange(0.0*mm, 0.55*mm, 0.05*mm))
  93. xy_subplot.set_xlabel('X [mm]', fontdict=font)
  94. xy_subplot.set_ylabel('Y [mm]', fontdict=font)
  95. xy_subplot.grid(alpha=0.6)
  96. # for label in xy_subplot.xaxis.get_ticklabels():
  97. # label.set_rotation(90)
  98. # Set Axis
  99. yz_subplot.set_ylim([-0.02 * mm, 0.52 * mm])
  100. yz_subplot.set_xlim([-0.02 * mm, 0.22 * mm])
  101. # yz_subplot.set_ylim([0.1 * mm, 0.5 * mm])
  102. # yz_subplot.set_xlim([0.0 * mm, 0.2 * 100])
  103. # yz_subplot.set_yticks(np.arange(0.00 * mm, 0.5 * mm, 0.05 * mm))
  104. # yz_subplot.set_xticks(np.arange(0.00 * mm, 0.20 * mm, 0.01 * mm))
  105. xy_subplot.tick_params(labelsize=13)
  106. yz_subplot.tick_params(labelsize=13)
  107. yz_subplot.set_ylabel('Y [mm]', fontdict=font)
  108. yz_subplot.set_xlabel('Z [mm]', fontdict=font)
  109. yz_subplot.grid(alpha=0.5)
  110. # Set Titles
  111. xy_subplot.set_title("XY", fontdict=font)
  112. yz_subplot.set_title("ZY", fontdict=font)
  113. # CONFIGURE X, Y PLOT
  114. # Indicate Exciter
  115. Exciter = xy_subplot.vlines(x=0.0 * mm, ymin=0.0 * mm, ymax=0.5 * mm, linestyles="-", colors="#5b9bd5", label="Exciter",
  116. linewidth=6)
  117. xy_subplot.vlines(x=0.5 * mm, ymin=0.0 * mm, ymax=0.5 * mm, linestyles="-", colors="#5b9bd5", linewidth=6)
  118. xy_subplot.hlines(y=0.0 * mm, xmin=0.0 * mm, xmax=0.5 * mm, linestyles="-", colors="#5b9bd5", linewidth=6)
  119. xy_subplot.hlines(y=0.5 * mm, xmin=0.0 * mm, xmax=0.5 * mm, linestyles="-", colors="#5b9bd5", linewidth=6)
  120. yz_subplot.vlines(x=0.0, ymin=0.0 * mm, ymax=0.5 * mm, linestyles="-", colors="#5b9bd5", linewidth=6, alpha=0.3)
  121. # Indicate Localization Area
  122. # LocArea = xy_subplot.vlines(x=0.045, ymin=0.045, ymax=0.495, linestyles="dashed", colors="#7f7f7f",
  123. # label="Loc.\nArea")
  124. # xy_subplot.vlines(x=0.495, ymin=0.045, ymax=0.495, linestyles="dashed", colors="#7f7f7f")
  125. # xy_subplot.hlines(y=0.045, xmin=0.045, xmax=0.495, linestyles="dashed", colors="#7f7f7f")
  126. # xy_subplot.hlines(y=0.495, xmin=0.045, xmax=0.495, linestyles="dashed", colors="#7f7f7f")
  127. # SHOW ANTENNAS
  128. AntPosX = [0 * mm, 0 * mm, 0.13 * mm, 0.375 * mm, 0.5 * mm, 0.5 * mm, 0.125 * mm,
  129. 0.38 * mm] # Safing each Antennas Position in a list -> e.g. : X position of frame ant 1, X pos of frame ant 2, x pos of frame ant 3, ...
  130. AntPosY = [0.125 * mm, 0.375 * mm, 0.5 * mm, 0.5 * mm, 0.125 * mm, 0.375 * mm, 0 * mm, 0 * mm]
  131. AntOrientY = [270, 270, 180, 180, 90, 90, 0,
  132. 0] # analog as AntPosX,Y , just with Orientation degrees (direction in z-axis
  133. # Select Antenna plotting options here
  134. patchAlpha = 1.0
  135. # ITERATE THROUGH CMAP-------------
  136. n = 8
  137. # color = iter(cm.jet(np.linspace(0, 1, n)))
  138. # ----------------------------
  139. for i in range(8): # iterate through antennas
  140. x = AntPosX[i] # get x and y pos of antennas
  141. y = AntPosY[i]
  142. # c = next(color) # Comment in for coloured antennas
  143. c = "k" # black antennas
  144. if i == 7: # only print one label for the Antennas (otherwise its 8 times in legend)
  145. label = label = str("Receiving\nCoils")
  146. else:
  147. label = ""
  148. if AntOrientY[i] == 0:
  149. # identify antennas orientation
  150. patchX = float(x - 0.033 * mm / 2)
  151. patchY = float(y - 0.033 * mm / 2)
  152. rect = plt.Rectangle((patchX, patchY), 0.033 * mm, 0.033 * mm,
  153. color=c, alpha=patchAlpha, label=label)
  154. xy_subplot.add_patch(rect)
  155. if AntOrientY[i] == 180:
  156. patchX = float(x - 0.033 * mm / 2)
  157. patchY = float(y - 0.033 * mm / 2)
  158. rect = plt.Rectangle((patchX, patchY), 0.033 * mm, 0.033 * mm,
  159. color=c, alpha=patchAlpha)
  160. xy_subplot.add_patch(rect)
  161. if AntOrientY[i] == 90:
  162. patchX = float(x - 0.033 * mm / 2)
  163. patchY = float(y - 0.033 * mm / 2)
  164. rect = plt.Rectangle((patchX, patchY), 0.033 * mm, 0.033 * mm,
  165. color=c, alpha=patchAlpha)
  166. xy_subplot.add_patch(rect)
  167. if AntOrientY[i] == 270:
  168. patchX = float(x - 0.033 * mm / 2)
  169. patchY = float(y - 0.033 * mm / 2)
  170. rect = plt.Rectangle((patchX, patchY), 0.033 * mm, 0.033 * mm,
  171. color=c, alpha=patchAlpha)
  172. xy_subplot.add_patch(rect)
  173. # Add antenna patches in yz plot
  174. # rect2 = plt.Rectangle((patchX, patchY), 0.033 * mm, 0.033 * mm,
  175. # color=c, alpha=0.3, label=label)
  176. # yz_subplot.add_patch(rect2)
  177. def art_file_to_position(recording_name):
  178. """
  179. :param recording_name: int, Choose which recording to print e.g. Recording_!1! , Recording_!2!, etc.
  180. """
  181. """--------------- Open, format and save ART positions from text files---------------------------------------"""
  182. # Opening a recording
  183. art_recording = open('ART Recordings (no ferrite)/' + str(recording_name) + '.drf', 'r')
  184. art_recording = art_recording.read()
  185. # Formatting a recording
  186. art_recording = re.split(' |\n', art_recording) # split string at whitespace and linebreak
  187. art_recording = np.array(art_recording) # list -> np.array
  188. # Count the total number of recorded positions
  189. total_pos_counter = 0
  190. for i in art_recording:
  191. if i == '6d': # this is the flag marking a 6d position (Syntax of the ART guys in the txt files)
  192. total_pos_counter += 1
  193. # Create an empty array which will store the ART measurement infos
  194. art_info = np.zeros((total_pos_counter, 7)) # [t, x, y, z, alpha, beta, gamma]
  195. # Fill the empty numpy array with the information's of the art_recording
  196. pos_counter = 0
  197. for i in range(len(art_recording) - 9):
  198. if art_recording[i] == 'ts': # following entry is timestamp
  199. art_info[pos_counter] = float(art_recording[i + 1])
  200. if art_recording[i] == '6d': # following entries are x,y,z,alpha,beta,gamma
  201. art_info[pos_counter, 1] = float(art_recording[i + 3].split('[')[1]) # x (had to delete some excess text)
  202. art_info[pos_counter, 2] = float(art_recording[i + 4]) # y
  203. art_info[pos_counter, 3] = float(art_recording[i + 5]) # z
  204. art_info[pos_counter, 4] = float(art_recording[i + 6]) # alpha
  205. art_info[pos_counter, 5] = float(art_recording[i + 7]) # beta
  206. art_info[pos_counter, 6] = float(
  207. art_recording[i + 8].split(']')[0]) # gamma (had to delete some excess text
  208. pos_counter += 1
  209. # print("x=", art_info[:, 1])
  210. # print("y=", art_info[:, 2])
  211. pos = [art_info[0, 1], art_info[0, 2], art_info[0, 3]] # only take the first x,y,z position of each recording (as the 0.33mm precision wont have a lot of noise anyways)
  212. return pos
  213. def plot_art_positions(positions, annotate_flag, markersize, annotation_size):
  214. for i in range(len(positions)):
  215. if i == 0: # only add the label to one of the points
  216. xy_subplot.scatter(positions[i, 0], positions[i, 1], color="green", label="ART System", s=markersize, marker="x")
  217. yz_subplot.scatter(positions[i, 2], positions[i, 1], color="green", label="ART System", s=markersize, marker="x")
  218. else:
  219. xy_subplot.scatter(positions[i, 0], positions[i, 1], color="green", s=markersize, marker="x")
  220. yz_subplot.scatter(positions[i, 2], positions[i, 1], color="green", s=markersize, marker="x")
  221. #print("ART: x=", art_info[0, 1], ", y=", art_info[0, 2], ", z=", art_info[0, 3])
  222. if i == len(positions)-1:
  223. if annotate_flag:
  224. xy_subplot.annotate("P13", (positions[i, 0], positions[i, 1]), color="green", size=annotation_size)
  225. yz_subplot.annotate("P13", (positions[i, 2], positions[i, 1]), color="green", size=annotation_size)
  226. def calc_multiple_positions(data, output_filename, only_fs, sf_list, k_list):
  227. """
  228. Takes a recording name and calculates a lot of positions for it, for each given scaling_factor and k_nearest_factor
  229. and stores them all as numpy arrays in folder "calculated_positions/..."
  230. :param recording_name: string name of recording (see recordings_dict)
  231. :param only_fs: only take first sample of recording
  232. :param sf_list: list with all scaling factors to localize with
  233. :param k_list: list with all k_nearest factors to localize with
  234. :param output_filename: it saves the positions under this name. It adds stuff like the scaling factor , k-nearest automatically
  235. :return returns an array containing all positions [3, len(sf_list)*len(k_list)]
  236. """
  237. positions = np.zeros((len(sf_list)*len(k_list), np.shape(data)[0], 3)) # shape: [1000, 42, 3] = [Combinations of sf and k, P1-P42, 3 xyz]
  238. i = 0
  239. j = 0
  240. print("----Calculating positions for ", len(sf_list) * len(k_list), " combinations of sf and k-----")
  241. for sf in sf_list:
  242. localization.scale_factor = sf
  243. for k in k_list:
  244. localization.k_nearest = k
  245. positions[(i+j), :, :] = localization.localize_all_samples_direct(data, output_filename + "_sf=" + str(sf) + "_k=" + str(k))
  246. if len(k_list) != 1: # only increment this if we have multiple k_nearests getting tested
  247. j += 1
  248. if len(sf_list) != 1: # only increment i if we have mutliple scaling factors getting tested
  249. i += 1
  250. print("progress ", i, "/", (len(sf_list)*len(k_list)))
  251. print("----Finished all position calculations-----")
  252. return positions
  253. def plot_multiple_positions(positions_filename, sf_list, k_list, only_label_once, xy_title, yz_title, annotate_points, labels, color_all_black, markersize, annotationsize):
  254. """
  255. You can give this function the name of an IndLoc recording and a bunch of scaling factors and k-nearests
  256. and it will plot you all of them.
  257. :param recording_name:
  258. :param sf_list:
  259. :param k_list:
  260. :param only_label_once:
  261. :param xy_title:
  262. :param yz_title:
  263. :param annotate_points:
  264. :param labels:
  265. :param color_all_black:
  266. :return:
  267. """
  268. colors = iter(plt.cm.jet(np.linspace(0, 1, len(sf_list)*len(k_list))))
  269. #markers = iter(['o', '8', 'p', '>', 'd', 'H', 'x','v', '^','*', 'D','s', 'h', '+', '<'])
  270. markers = iter(['x', 'x'])
  271. markers = iter(['.', '.'])
  272. #colors = iter(["blue", "orange", "red"])
  273. for sf in sf_list:
  274. for k in k_list:
  275. color = next(colors)
  276. marker = next(markers)
  277. positions = np.load("calculated_positions/"+positions_filename + "_sf=" + str(sf) + "_k=" + str(k)+".npy")
  278. # print("Plotting positions for: sf=", sf, ", k=", k, positions) # show all positions
  279. print("Plotting positions for: sf=", sf, ", k=", k) # minimal print view
  280. if color_all_black:
  281. xy_subplot.scatter(positions[:, 0], positions[:, 1], marker=marker, color="black", label=next(labels), s=markersize)
  282. yz_subplot.scatter(positions[:, 2], positions[:, 1], marker=marker, color="black", s=markersize)
  283. if annotate_points:
  284. for i in range(len(positions[:, 0])):
  285. xy_subplot.annotate("P" + str(i + 1), (positions[i, 0], positions[i, 1]), color="black", size=annotationsize)
  286. yz_subplot.annotate("P" + str(i + 1), (positions[i, 2], positions[i, 1]), color="black", size=annotationsize)
  287. else:
  288. xy_subplot.scatter(positions[:, 0], positions[:, 1], marker=marker, facecolor=color, label=next(labels))
  289. yz_subplot.scatter(positions[:, 2], positions[:, 1], marker=marker, facecolor=color)
  290. if annotate_points:
  291. for i in range(len(positions[:, 0])):
  292. xy_subplot.annotate("P" + str(i + 1), (positions[i, 0], positions[i, 1]), color=color, size=annotationsize)
  293. yz_subplot.annotate("P" + str(i + 1), (positions[i, 2], positions[i, 1]), color=color, size=annotationsize)
  294. yz_subplot.set_title(str(xy_title), size=20)
  295. xy_subplot.set_title(str(yz_title), size=20)
  296. def eval_loc(IndLoc_pos, ART_pos):
  297. """
  298. Takes 2 arrays full of x,y,z positions. Calculated the distance between the two for each position.
  299. Then calculates the distance for each point. Then the mean_distance and the standard_deviation of all positions together.
  300. :param IndLoc_pos: np.array[i, 3] (a number of positions: x,y,z)
  301. :param ART_pos: np.array[i, 3] (a number of positions: x,y,z)
  302. :return: distances, mean_distance, std_dev_distance
  303. """
  304. if np.shape(IndLoc_pos) != np.shape(ART_pos):
  305. print("During calc distance function. Positions arrays dont have the same shape. \n IndLoc (shape):",
  306. np.shape(IndLoc_pos), " ART (shape):", np.shape(ART_pos))
  307. else:
  308. distances = np.zeros((np.shape(IndLoc_pos)[0], 1)) # create an array containing all distances
  309. for i in range(np.shape(IndLoc_pos)[0]): # iterate through all positions
  310. x_IndLoc = IndLoc_pos[i, 0]
  311. y_IndLoc = IndLoc_pos[i, 1]
  312. z_IndLoc = IndLoc_pos[i, 2]
  313. x_ART = ART_pos[i, 0]
  314. y_ART = ART_pos[i, 1]
  315. z_ART = ART_pos[i, 2]
  316. distance = math.sqrt(
  317. (x_IndLoc - x_ART) ** 2 + (y_IndLoc - y_ART) ** 2) # calculate difference for each position
  318. distances[i] = distance
  319. # print("i", i)
  320. # print("IndLoc: x=", x_IndLoc, ", y=", y_IndLoc)
  321. # print("ART: x=", x_ART, ", y=", y_ART)
  322. mean_distance = np.mean(distances)
  323. std_dev_distance = np.std(distances)
  324. return distances, round(mean_distance, 2), round(std_dev_distance, 2)
  325. def load_entire_ART(recording_names):
  326. """
  327. :param recording_names: list of all recording that are to be loaded ["center", "horizontal", "diagonal"]
  328. :return: an array containing them all appended together
  329. """
  330. x_list = []
  331. y_list = []
  332. z_list = []
  333. for i in range(len(recording_names)): # iterate through list of wanted recordings
  334. print("Loading ART", recording_names[i])
  335. for j in range(1, 30): # it always starts with "P1 and ends variable but never over P30"
  336. try:
  337. current_pos = art_file_to_position(recording_names[i]+"_P"+(str(j)))
  338. x_list.append(current_pos[0])
  339. y_list.append(current_pos[1])
  340. z_list.append(current_pos[2])
  341. except FileNotFoundError:
  342. break
  343. pos = np.zeros((len(x_list), 3))
  344. pos[:, 0] = x_list
  345. pos[:, 1] = y_list
  346. pos[:, 2] = z_list
  347. return pos
  348. # Normal ART Loading (1 Recording)
  349. # for i in range(1, 26):
  350. # print("i", i)
  351. # # if i <= 21:
  352. # # pos = plot_art("horizontal_P" + str(i), annotate_point=False)
  353. # if i == 26:
  354. # label_flag = True
  355. # pos = plot_art(recording_selection+"_P" + str(i), annotate_point=False)
  356. # x_list.append(pos[0])
  357. # y_list.append(pos[1])
  358. # z_list.append(pos[2])
  359. # Input here -----------------------------------------------------------------------------------------------------------
  360. recording_selection = ["z_axis"]
  361. sf_list = [0.0021] # np.arange(0.0015, 0.0025, 0.0001)
  362. k_list = [12]
  363. annotate_points = False
  364. plot_title = "Noise investigation"
  365. label = "IndLoc"
  366. # End of input ---------------------------------------------------------------------------------------------------------
  367. # Make the plot look nice
  368. plot_surroundings()
  369. # Load data positions
  370. if len(recording_selection) == 1: # if we only have to look at simply load ART and IndLoc
  371. ART_pos = load_entire_ART(recording_selection[0]) # Load ART positions
  372. IndLoc_data = recording_dict_fs["z_axis"] # Load IndLoc data
  373. if len(recording_selection) == 2: # if we want to look at multiple recordings its a little more complicated
  374. ART_pos = load_entire_ART([recording_selection[0], recording_selection[1]]) # Load both ART positions
  375. rec1 = recording_dict_fs[recording_selection[0]] # Load IndLoc1
  376. rec2 = recording_dict_fs[recording_selection[1]] # Load IndLoc2
  377. IndLoc_data = np.append(rec1, rec2, axis=0) # Append both IndLoc arrays correctly
  378. # Average IndLoc data
  379. #IndLoc_data = np.average(IndLoc_data, axis=1) # shape (42, 19) = (P1-P42, fv_avg)
  380. # Localize IndLoc data for each scaling factor in sf_list and each k in k_list
  381. IndLoc_pos = calc_multiple_positions(data=IndLoc_data,
  382. output_filename=str(recording_selection),
  383. only_fs=True,
  384. sf_list=sf_list,
  385. k_list=k_list)
  386. # Find the best localization parameters
  387. # best_mean = 10000
  388. # best_std_dev = 10000
  389. # best_both = 10000
  390. # for i in range((len(sf_list)*len(k_list))):
  391. # print(i)
  392. # distances, curr_mean, curr_std_dev = eval_loc(ART_pos=ART_pos, IndLoc_pos=IndLoc_pos[i])
  393. # #print("- current k_nearest:", sf_list[i], " --> mean=", curr_mean, ", std_dev=",curr_std_dev)
  394. # curr_both = curr_mean + curr_std_dev
  395. #
  396. # if curr_both <= best_both:
  397. # best_both = curr_both
  398. # best_mean = curr_mean
  399. # best_std_dev = curr_std_dev
  400. # index = i
  401. #
  402. # if len(sf_list) != 1: # it looks like we`re analyzing multiple sf`s
  403. # best_k = k_list[0]
  404. # best_sf = sf_list[index]
  405. # if len(k_list) != 1:
  406. # best_sf = sf_list[0]
  407. # best_k = k_list[index]
  408. # if (len(sf_list) == 1) and (len(k_list) == 1):
  409. # best_sf = sf_list[0]
  410. # best_k = k_list[0]
  411. #
  412. # print("\n----Best localization parameters were found!-----\n scaling_factor = ", best_sf, "\n mean_error = ", best_mean,
  413. # "\n standard_deviation = ", best_std_dev)
  414. # Plot IndLoc pos
  415. plot_multiple_positions(positions_filename=str(recording_selection),
  416. sf_list=[0.001755],
  417. k_list=[19],
  418. only_label_once=True,
  419. xy_title=plot_title,
  420. yz_title=plot_title,
  421. annotate_points=annotate_points,
  422. annotationsize=15,
  423. labels=iter([label]), # labels=iter(["k = " + str(best_k)]),
  424. color_all_black=True,
  425. markersize=10,)
  426. # Plot ART positions
  427. plot_art_positions(ART_pos,
  428. annotate_flag=True,
  429. markersize=100,
  430. annotation_size=15)
  431. xy_subplot.legend(bbox_to_anchor=(1.01, 1.00), loc=2, borderaxespad=0., fontsize=10)
  432. plt.show()