123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166 |
- from glob import glob
-
- import numpy as np
- import matplotlib.pyplot as plt
- from os.path import join
-
- from scipy.stats import mannwhitneyu
-
- dataroots = {
- 'PROPOSAL' : 'results',
- #'model_cnmc_res_128' : 'results/model_cnmc_res_128',
- #'model_cnmc_res_224' : 'results/model_cnmc_res_224',
- #'model_cnmc_res_256' : 'results/model_cnmc_res_256',
- #'model_cnmc_res_450' : 'results/model_cnmc_res_450',
- #'model_cnmc_res_450_blue_only' : 'results/model_cnmc_res_450_blue_only',
- #'model_cnmc_res_450_green_only' : 'results/model_cnmc_res_450_green_only',
- #'model_cnmc_res_450_red_only' : 'results/model_cnmc_res_450_red_only',
- #'model_cnmc_res_450_no_blue' : 'results/model_cnmc_res_450_no_blue',
- #'model_cnmc_res_450_no_green' : 'results/model_cnmc_res_450_no_green',
- #'model_cnmc_res_450_no_red' : 'results/model_cnmc_res_450_no_red',
- #'model_cnmc_res_450_grayscale' : 'results/model_cnmc_res_450_grayscale',
- }
-
-
-
-
- def get_values(dataroot, key):
- npzs = list(glob(join(dataroot, '*', 'results.npz')))
- vals = []
- for f in npzs:
- recorded_data = np.load(f)
- val = recorded_data[key]
- vals.append(val)
- vals = np.stack(vals, 0)
- return vals
-
-
- def plot_mean_std(dataroot, key, ax, **kwargs):
- vals = get_values(dataroot, key)
- mean = np.mean(vals, 0)
- std = np.std(vals, 0)
- epochs = np.arange(len(mean))
-
- # Offset by 1 so that we have nicely zoomed plots
- mean = mean[1:]
- std = std[1:]
- epochs = epochs[1:]
-
- ax.plot(epochs, mean, **kwargs)
- ax.fill_between(epochs, mean - std, mean + std, alpha=0.2)
-
-
- def plot3(key, ax):
- for k, v in dataroots.items():
- plot_mean_std(v, key, ax, label=k)
-
-
- def print_final_min_mean_max(dataroot, key, model_epochs):
- vals = get_values(dataroot, key) * 100
- vals = vals[np.arange(len(vals)), model_epochs]
- min = np.min(vals)
- mean = np.mean(vals)
- std = np.std(vals)
- max = np.max(vals)
- print(f'{min:.2f}', f'{mean:.2f} ± {std:.2f}', f'{max:.2f}', sep='\t')
-
-
- def print_final_table(dataroot):
- best_model_epochs = np.argmax(get_values(dataroot, 'f1'), axis=1)
-
- print_final_min_mean_max(dataroot, 'acc', best_model_epochs)
- print_final_min_mean_max(dataroot, 'acc_all', best_model_epochs)
- print_final_min_mean_max(dataroot, 'acc_hem', best_model_epochs)
- print_final_min_mean_max(dataroot, 'f1', best_model_epochs)
- print_final_min_mean_max(dataroot, 'precision', best_model_epochs)
- print_final_min_mean_max(dataroot, 'recall', best_model_epochs)
-
-
- def get_best_f1_scores(dataroot):
- f1_scores = get_values(dataroot, 'f1')
- best_model_epochs = np.argmax(f1_scores, axis=1)
- return f1_scores[np.arange(len(f1_scores)), best_model_epochs]
-
-
- def is_statistically_greater(dataroot1, dataroot2):
- # Tests if F1-score of dataroot1 is greater than dataroot2
- a = get_best_f1_scores(dataroot1)
- b = get_best_f1_scores(dataroot2)
- u, p = mannwhitneyu(a, b, alternative='greater')
- return u, p
-
-
- ######
-
- for k, v in dataroots.items():
- print(k)
- print_final_table(v)
- print()
-
-
- ######
-
- #print("MWU-Test of PROPOSAL > NOSPECLR")
- #print(is_statistically_greater(dataroots['PROPOSAL'], dataroots['NOSPECLR']))
- #print()
- #print("MWU-Test of PROPOSAL > NOROT")
- #print(is_statistically_greater(dataroots['PROPOSAL'], dataroots['NOROT']))
-
- ######
-
- fig, ax = plt.subplots(nrows=2, ncols=3, figsize=(9, 5))
-
- ax[0, 0].set_title('Accuracy')
- plot3('acc', ax[0, 0])
-
- ax[0, 1].set_title('Sensitivity')
- plot3('acc_all', ax[0, 1])
-
- ax[0, 2].set_title('Specificity')
- plot3('acc_hem', ax[0, 2])
-
- ax[1, 0].set_title('F1 score')
- plot3('f1', ax[1, 0])
-
- ax[1, 1].set_title('Precision')
- plot3('precision', ax[1, 1])
-
- ax[1, 2].set_title('Recall')
- plot3('recall', ax[1, 2])
-
- fig.legend(loc='lower center', ncol=3)
- fig.tight_layout()
- fig.subplots_adjust(bottom=0.12)
- fig.savefig('results/plot_ablations.pdf')
-
- ######
- npload= 'results/model_cnmc_res_128'
- npload_sub=npload + '/subj_acc.npz'
- npload_res=npload + '/results.npz'
- subj_acc = np.load(npload_sub)
- subj = list(sorted(subj_acc.keys()))
- acc = [subj_acc[k] for k in subj]
- fig, ax = plt.subplots(figsize=(9, 2))
- ax.bar(range(len(acc)), acc, width=0.3, tick_label=subj)
- fig.tight_layout()
- fig.savefig('results/plot_subj_acc.pdf')
-
- ######
-
- data = np.load(npload_res)
- loss_train = data['loss_train']
- loss_valid = data['loss_valid'][1:]
- f1_valid = data['f1'][1:]
- fig, ax = plt.subplots(ncols=3, figsize=(9, 2))
- ax[0].plot(range(len(loss_train)), loss_train)
- ax[0].set_title("Training set loss")
- ax[1].plot(range(1, len(loss_valid) + 1), loss_valid)
- ax[1].set_title("Preliminary test set loss")
- ax[2].plot(range(1, len(f1_valid) + 1), f1_valid)
- ax[2].set_title("Preliminary test set F1-score")
- fig.tight_layout()
- fig.savefig('results/plot_curves.pdf')
-
- ######
-
- plt.show()
|