In der Masterarbeit „Anomalie-Detektion in Zellbildern zur Anwendung der Leukämieerkennung“ verwendete Methode des 3. Platzes der ISBI 2019.
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

plot.py 4.7KB

2 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166
  1. from glob import glob
  2. import numpy as np
  3. import matplotlib.pyplot as plt
  4. from os.path import join
  5. from scipy.stats import mannwhitneyu
  6. dataroots = {
  7. 'PROPOSAL' : 'results',
  8. #'model_cnmc_res_128' : 'results/model_cnmc_res_128',
  9. #'model_cnmc_res_224' : 'results/model_cnmc_res_224',
  10. #'model_cnmc_res_256' : 'results/model_cnmc_res_256',
  11. #'model_cnmc_res_450' : 'results/model_cnmc_res_450',
  12. #'model_cnmc_res_450_blue_only' : 'results/model_cnmc_res_450_blue_only',
  13. #'model_cnmc_res_450_green_only' : 'results/model_cnmc_res_450_green_only',
  14. #'model_cnmc_res_450_red_only' : 'results/model_cnmc_res_450_red_only',
  15. #'model_cnmc_res_450_no_blue' : 'results/model_cnmc_res_450_no_blue',
  16. #'model_cnmc_res_450_no_green' : 'results/model_cnmc_res_450_no_green',
  17. #'model_cnmc_res_450_no_red' : 'results/model_cnmc_res_450_no_red',
  18. #'model_cnmc_res_450_grayscale' : 'results/model_cnmc_res_450_grayscale',
  19. }
  20. def get_values(dataroot, key):
  21. npzs = list(glob(join(dataroot, '*', 'results.npz')))
  22. vals = []
  23. for f in npzs:
  24. recorded_data = np.load(f)
  25. val = recorded_data[key]
  26. vals.append(val)
  27. vals = np.stack(vals, 0)
  28. return vals
  29. def plot_mean_std(dataroot, key, ax, **kwargs):
  30. vals = get_values(dataroot, key)
  31. mean = np.mean(vals, 0)
  32. std = np.std(vals, 0)
  33. epochs = np.arange(len(mean))
  34. # Offset by 1 so that we have nicely zoomed plots
  35. mean = mean[1:]
  36. std = std[1:]
  37. epochs = epochs[1:]
  38. ax.plot(epochs, mean, **kwargs)
  39. ax.fill_between(epochs, mean - std, mean + std, alpha=0.2)
  40. def plot3(key, ax):
  41. for k, v in dataroots.items():
  42. plot_mean_std(v, key, ax, label=k)
  43. def print_final_min_mean_max(dataroot, key, model_epochs):
  44. vals = get_values(dataroot, key) * 100
  45. vals = vals[np.arange(len(vals)), model_epochs]
  46. min = np.min(vals)
  47. mean = np.mean(vals)
  48. std = np.std(vals)
  49. max = np.max(vals)
  50. print(f'{min:.2f}', f'{mean:.2f} ± {std:.2f}', f'{max:.2f}', sep='\t')
  51. def print_final_table(dataroot):
  52. best_model_epochs = np.argmax(get_values(dataroot, 'f1'), axis=1)
  53. print_final_min_mean_max(dataroot, 'acc', best_model_epochs)
  54. print_final_min_mean_max(dataroot, 'acc_all', best_model_epochs)
  55. print_final_min_mean_max(dataroot, 'acc_hem', best_model_epochs)
  56. print_final_min_mean_max(dataroot, 'f1', best_model_epochs)
  57. print_final_min_mean_max(dataroot, 'precision', best_model_epochs)
  58. print_final_min_mean_max(dataroot, 'recall', best_model_epochs)
  59. def get_best_f1_scores(dataroot):
  60. f1_scores = get_values(dataroot, 'f1')
  61. best_model_epochs = np.argmax(f1_scores, axis=1)
  62. return f1_scores[np.arange(len(f1_scores)), best_model_epochs]
  63. def is_statistically_greater(dataroot1, dataroot2):
  64. # Tests if F1-score of dataroot1 is greater than dataroot2
  65. a = get_best_f1_scores(dataroot1)
  66. b = get_best_f1_scores(dataroot2)
  67. u, p = mannwhitneyu(a, b, alternative='greater')
  68. return u, p
  69. ######
  70. for k, v in dataroots.items():
  71. print(k)
  72. print_final_table(v)
  73. print()
  74. ######
  75. #print("MWU-Test of PROPOSAL > NOSPECLR")
  76. #print(is_statistically_greater(dataroots['PROPOSAL'], dataroots['NOSPECLR']))
  77. #print()
  78. #print("MWU-Test of PROPOSAL > NOROT")
  79. #print(is_statistically_greater(dataroots['PROPOSAL'], dataroots['NOROT']))
  80. ######
  81. fig, ax = plt.subplots(nrows=2, ncols=3, figsize=(9, 5))
  82. ax[0, 0].set_title('Accuracy')
  83. plot3('acc', ax[0, 0])
  84. ax[0, 1].set_title('Sensitivity')
  85. plot3('acc_all', ax[0, 1])
  86. ax[0, 2].set_title('Specificity')
  87. plot3('acc_hem', ax[0, 2])
  88. ax[1, 0].set_title('F1 score')
  89. plot3('f1', ax[1, 0])
  90. ax[1, 1].set_title('Precision')
  91. plot3('precision', ax[1, 1])
  92. ax[1, 2].set_title('Recall')
  93. plot3('recall', ax[1, 2])
  94. fig.legend(loc='lower center', ncol=3)
  95. fig.tight_layout()
  96. fig.subplots_adjust(bottom=0.12)
  97. fig.savefig('results/plot_ablations.pdf')
  98. ######
  99. npload= 'results/model_cnmc_res_128'
  100. npload_sub=npload + '/subj_acc.npz'
  101. npload_res=npload + '/results.npz'
  102. subj_acc = np.load(npload_sub)
  103. subj = list(sorted(subj_acc.keys()))
  104. acc = [subj_acc[k] for k in subj]
  105. fig, ax = plt.subplots(figsize=(9, 2))
  106. ax.bar(range(len(acc)), acc, width=0.3, tick_label=subj)
  107. fig.tight_layout()
  108. fig.savefig('results/plot_subj_acc.pdf')
  109. ######
  110. data = np.load(npload_res)
  111. loss_train = data['loss_train']
  112. loss_valid = data['loss_valid'][1:]
  113. f1_valid = data['f1'][1:]
  114. fig, ax = plt.subplots(ncols=3, figsize=(9, 2))
  115. ax[0].plot(range(len(loss_train)), loss_train)
  116. ax[0].set_title("Training set loss")
  117. ax[1].plot(range(1, len(loss_valid) + 1), loss_valid)
  118. ax[1].set_title("Preliminary test set loss")
  119. ax[2].plot(range(1, len(f1_valid) + 1), f1_valid)
  120. ax[2].set_title("Preliminary test set F1-score")
  121. fig.tight_layout()
  122. fig.savefig('results/plot_curves.pdf')
  123. ######
  124. plt.show()