"""Summarize and plot training results stored as ``results.npz`` archives.

Every run directory directly below a data root is expected to contain a
``results.npz`` file. The keys read here are ``acc``, ``acc_all``,
``acc_hem``, ``f1``, ``precision``, ``recall``, ``loss_train`` and
``loss_valid``; the metric arrays are indexed per epoch.

Running this module as a script does three things:
- prints one summary table per entry of ``dataroots``;
- writes ``plot_ablations.pdf``, ``plot_subj_acc.pdf`` and ``plot_curves.pdf``
  into ``results/``;
- shows the figures.
"""
from glob import glob
from os.path import join

import numpy as np
from scipy.stats import mannwhitneyu

dataroots = {
    'PROPOSAL': 'results',
    # 'model_cnmc_res_128': 'results/model_cnmc_res_128',
    # 'model_cnmc_res_224': 'results/model_cnmc_res_224',
    # 'model_cnmc_res_256': 'results/model_cnmc_res_256',
    # 'model_cnmc_res_450': 'results/model_cnmc_res_450',
    # 'model_cnmc_res_450_blue_only': 'results/model_cnmc_res_450_blue_only',
    # 'model_cnmc_res_450_green_only': 'results/model_cnmc_res_450_green_only',
    # 'model_cnmc_res_450_red_only': 'results/model_cnmc_res_450_red_only',
    # 'model_cnmc_res_450_no_blue': 'results/model_cnmc_res_450_no_blue',
    # 'model_cnmc_res_450_no_green': 'results/model_cnmc_res_450_no_green',
    # 'model_cnmc_res_450_no_red': 'results/model_cnmc_res_450_no_red',
    # 'model_cnmc_res_450_grayscale': 'results/model_cnmc_res_450_grayscale',
}


def get_values(dataroot, key):
    """Load ``key`` from every ``<dataroot>/*/results.npz`` and stack the runs.

    Files are visited in sorted path order. This makes repeated calls on the
    same data root return the runs in the same row order, which matters
    because callers pair rows from separate calls: best-epoch indices are
    computed from ``'f1'`` and then applied to other metrics.

    Args:
        dataroot: Directory whose immediate subdirectories hold ``results.npz``.
        key: Name of the array to read from each archive.

    Returns:
        ``np.ndarray`` with one row per run (``np.stack`` along axis 0).

    Raises:
        FileNotFoundError: If no ``results.npz`` file exists under ``dataroot``.
    """
    paths = sorted(glob(join(dataroot, '*', 'results.npz')))
    if not paths:
        raise FileNotFoundError(f'no results.npz files found under {dataroot!r}')
    vals = []
    for path in paths:
        # NpzFile keeps its file open until closed; indexing materializes the
        # array, so it stays valid after the context manager closes the file.
        with np.load(path) as recorded_data:
            vals.append(recorded_data[key])
    return np.stack(vals, 0)


def plot_mean_std(dataroot, key, ax, **kwargs):
    """Plot the across-run mean of ``key`` with a +/- one std band on ``ax``.

    ``kwargs`` are forwarded to ``ax.plot``; the band reuses the line's colour.
    """
    vals = get_values(dataroot, key)
    mean = np.mean(vals, 0)
    std = np.std(vals, 0)
    epochs = np.arange(len(mean))
    # Skip epoch 0 so that we have nicely zoomed plots.
    mean, std, epochs = mean[1:], std[1:], epochs[1:]
    line, = ax.plot(epochs, mean, **kwargs)
    # Pin the band to its curve's colour instead of relying on the colour cycle.
    ax.fill_between(epochs, mean - std, mean + std,
                    color=line.get_color(), alpha=0.2)


def plot3(key, ax):
    """Draw the mean/std curve of ``key`` for every entry of ``dataroots``."""
    for name, dataroot in dataroots.items():
        plot_mean_std(dataroot, key, ax, label=name)


def print_final_min_mean_max(dataroot, key, model_epochs):
    """Print min, mean ± std and max (in percent) of ``key`` at chosen epochs.

    Args:
        dataroot: Data root passed to ``get_values``.
        key: Metric to summarize.
        model_epochs: One epoch index per run, in ``get_values`` row order.
    """
    vals = get_values(dataroot, key) * 100
    vals = vals[np.arange(len(vals)), model_epochs]
    lowest, mean, std, highest = (np.min(vals), np.mean(vals),
                                  np.std(vals), np.max(vals))
    print(f'{lowest:.2f}', f'{mean:.2f} ± {std:.2f}', f'{highest:.2f}', sep='\t')


def print_final_table(dataroot):
    """Print one summary row per metric, each taken at the run's best-F1 epoch.

    Rows are printed in this order: acc, acc_all, acc_hem, f1, precision, recall.
    """
    best_model_epochs = np.argmax(get_values(dataroot, 'f1'), axis=1)
    for key in ('acc', 'acc_all', 'acc_hem', 'f1', 'precision', 'recall'):
        print_final_min_mean_max(dataroot, key, best_model_epochs)


def get_best_f1_scores(dataroot):
    """Return each run's highest F1 score across epochs (one value per run)."""
    return np.max(get_values(dataroot, 'f1'), axis=1)


def is_statistically_greater(dataroot1, dataroot2):
    """One-sided Mann-Whitney U test: are dataroot1's best F1 scores greater?

    Returns:
        ``(u, p)`` as produced by ``scipy.stats.mannwhitneyu`` with
        ``alternative='greater'``.
    """
    a = get_best_f1_scores(dataroot1)
    b = get_best_f1_scores(dataroot2)
    u, p = mannwhitneyu(a, b, alternative='greater')
    return u, p


def main():
    """Print the summary tables and write/show the three result figures."""
    # Imported here so the analysis helpers above are usable without matplotlib.
    import matplotlib.pyplot as plt

    for name, dataroot in dataroots.items():
        print(name)
        print_final_table(dataroot)
        print()

    # The MWU tests below need 'NOSPECLR' / 'NOROT' entries in dataroots.
    # print("MWU-Test of PROPOSAL > NOSPECLR")
    # print(is_statistically_greater(dataroots['PROPOSAL'], dataroots['NOSPECLR']))
    # print()
    # print("MWU-Test of PROPOSAL > NOROT")
    # print(is_statistically_greater(dataroots['PROPOSAL'], dataroots['NOROT']))

    fig, ax = plt.subplots(nrows=2, ncols=3, figsize=(9, 5))
    panels = [
        ((0, 0), 'Accuracy', 'acc'),
        ((0, 1), 'Sensitivity', 'acc_all'),
        ((0, 2), 'Specificity', 'acc_hem'),
        ((1, 0), 'F1 score', 'f1'),
        ((1, 1), 'Precision', 'precision'),
        ((1, 2), 'Recall', 'recall'),
    ]
    for pos, title, key in panels:
        ax[pos].set_title(title)
        plot3(key, ax[pos])
    # Figure.legend() without handles gathers the labelled lines of *all* axes,
    # listing every run once per subplot. All panels show the same runs, so
    # take the entries from a single axis.
    handles, labels = ax[0, 0].get_legend_handles_labels()
    fig.legend(handles, labels, loc='lower center', ncol=3)
    fig.tight_layout()
    fig.subplots_adjust(bottom=0.12)
    fig.savefig('results/plot_ablations.pdf')

    npload = 'results/model_cnmc_res_128'

    with np.load(join(npload, 'subj_acc.npz')) as subj_acc:
        subj = sorted(subj_acc.keys())
        acc = [subj_acc[k] for k in subj]
    fig, ax = plt.subplots(figsize=(9, 2))
    ax.bar(range(len(acc)), acc, width=0.3, tick_label=subj)
    fig.tight_layout()
    fig.savefig('results/plot_subj_acc.pdf')

    with np.load(join(npload, 'results.npz')) as data:
        loss_train = data['loss_train']
        loss_valid = data['loss_valid'][1:]
        f1_valid = data['f1'][1:]
    fig, ax = plt.subplots(ncols=3, figsize=(9, 2))
    ax[0].plot(range(len(loss_train)), loss_train)
    ax[0].set_title("Training set loss")
    ax[1].plot(range(1, len(loss_valid) + 1), loss_valid)
    ax[1].set_title("Preliminary test set loss")
    ax[2].plot(range(1, len(f1_valid) + 1), f1_valid)
    ax[2].set_title("Preliminary test set F1-score")
    fig.tight_layout()
    fig.savefig('results/plot_curves.pdf')

    plt.show()


if __name__ == '__main__':
    main()