167 lines
4.7 KiB
Python
167 lines
4.7 KiB
Python
from glob import glob
|
|
|
|
import numpy as np
|
|
import matplotlib.pyplot as plt
|
|
from os.path import join
|
|
|
|
from scipy.stats import mannwhitneyu
|
|
|
|
# Maps a display label (used as table heading and plot legend entry) to the
# directory holding that configuration's runs; each run sits in its own
# sub-directory containing a ``results.npz``.
# Uncomment entries to include further configurations in tables and plots.
dataroots = {
    'PROPOSAL': 'results',
    #'model_cnmc_res_128': 'results/model_cnmc_res_128',
    #'model_cnmc_res_224': 'results/model_cnmc_res_224',
    #'model_cnmc_res_256': 'results/model_cnmc_res_256',
    #'model_cnmc_res_450': 'results/model_cnmc_res_450',
    #'model_cnmc_res_450_blue_only': 'results/model_cnmc_res_450_blue_only',
    #'model_cnmc_res_450_green_only': 'results/model_cnmc_res_450_green_only',
    #'model_cnmc_res_450_red_only': 'results/model_cnmc_res_450_red_only',
    #'model_cnmc_res_450_no_blue': 'results/model_cnmc_res_450_no_blue',
    #'model_cnmc_res_450_no_green': 'results/model_cnmc_res_450_no_green',
    #'model_cnmc_res_450_no_red': 'results/model_cnmc_res_450_no_red',
    #'model_cnmc_res_450_grayscale': 'results/model_cnmc_res_450_grayscale',
}
|
|
|
|
|
|
|
|
|
|
def get_values(dataroot, key):
    """Collect one recorded array from every run found below ``dataroot``.

    Looks for ``<dataroot>/*/results.npz`` and stacks the array stored under
    ``key`` in each archive along a new leading axis.

    Args:
        dataroot: Directory whose immediate sub-directories each hold a
            ``results.npz`` archive.
        key: Name of the array to read from every archive.

    Returns:
        ``np.ndarray`` with one leading-axis entry per archive, ordered by
        sorted archive path.

    Raises:
        ValueError: If no ``results.npz`` is found, or if the per-run arrays
            cannot be stacked (raised by ``np.stack``).
        KeyError: If ``key`` is missing from an archive.
    """
    # Sort so that row i refers to the same run on every call: glob() makes
    # no ordering guarantee, and callers index the result of one call with
    # per-run indices computed from another.
    npzs = sorted(glob(join(dataroot, '*', 'results.npz')))
    if not npzs:
        raise ValueError(
            f"no results.npz found under {join(dataroot, '*')!r}"
        )

    vals = []
    for f in npzs:
        # The archive object keeps its file open until closed; the context
        # manager releases the handle as soon as the array has been read.
        with np.load(f) as recorded_data:
            vals.append(recorded_data[key])
    return np.stack(vals, 0)
|
|
|
|
|
|
def plot_mean_std(dataroot, key, ax, **kwargs):
    """Plot the across-run mean of ``key`` with a +/- one std band on ``ax``.

    Args:
        dataroot: Directory passed on to ``get_values``.
        key: Name of the recorded array to plot.
        ax: Matplotlib axes to draw on.
        **kwargs: Forwarded to ``ax.plot`` (e.g. ``label``, ``color``).
    """
    vals = get_values(dataroot, key)
    # Axis 0 enumerates the runs, so reduce over it.
    mean = np.mean(vals, 0)
    std = np.std(vals, 0)
    epochs = np.arange(len(mean))

    # Offset by 1 so that we have nicely zoomed plots
    mean, std, epochs = mean[1:], std[1:], epochs[1:]

    lines = ax.plot(epochs, mean, **kwargs)
    # Tie the band to the colour of the line just drawn; otherwise the band
    # takes the next colour of the patch cycle and no longer matches when the
    # caller passes an explicit ``color``.
    ax.fill_between(epochs, mean - std, mean + std, alpha=0.2,
                    color=lines[0].get_color())
|
|
|
|
|
|
def plot3(key, ax):
    """Draw the mean/std curve of ``key`` for every ``dataroots`` entry on ``ax``.

    The dictionary key of each entry is used as the curve's legend label.
    """
    for label in dataroots:
        plot_mean_std(dataroots[label], key, ax, label=label)
|
|
|
|
|
|
def print_final_min_mean_max(dataroot, key, model_epochs):
    """Print min, mean ± std and max of ``key`` at the selected epochs.

    Args:
        dataroot: Directory passed on to ``get_values``.
        key: Name of the recorded metric.
        model_epochs: One epoch index per run, in the same order as the
            rows returned by ``get_values``.

    The values are multiplied by 100 before the statistics are taken
    (presumably to report percentages) and printed tab-separated on one line.
    """
    vals = get_values(dataroot, key) * 100
    # Pick, for run i, the value recorded at epoch model_epochs[i].
    vals = vals[np.arange(len(vals)), model_epochs]

    # Local names deliberately avoid shadowing the builtins min() / max().
    lowest = np.min(vals)
    mean = np.mean(vals)
    std = np.std(vals)
    highest = np.max(vals)
    print(f'{lowest:.2f}', f'{mean:.2f} ± {std:.2f}', f'{highest:.2f}', sep='\t')
|
|
|
|
|
|
def print_final_table(dataroot):
    """Print one summary row per metric for the runs below ``dataroot``.

    For every run the epoch with the highest recorded ``f1`` is selected, and
    each metric is summarised at those epochs.
    """
    best_model_epochs = get_values(dataroot, 'f1').argmax(axis=1)

    # Row order of the printed table.
    for metric in ('acc', 'acc_all', 'acc_hem', 'f1', 'precision', 'recall'):
        print_final_min_mean_max(dataroot, metric, best_model_epochs)
|
|
|
|
|
|
def get_best_f1_scores(dataroot):
    """Return, for every run below ``dataroot``, its highest recorded ``f1``."""
    scores = get_values(dataroot, 'f1')
    run_idx = np.arange(scores.shape[0])
    best_epoch = scores.argmax(axis=1)
    # Pair each run with its own best epoch.
    return scores[run_idx, best_epoch]
|
|
|
|
|
|
def is_statistically_greater(dataroot1, dataroot2):
    """One-sided Mann-Whitney U test on the best F1-scores of two result sets.

    Tests whether the F1-scores of ``dataroot1`` are greater than those of
    ``dataroot2`` and returns the ``(U statistic, p-value)`` pair.
    """
    statistic, p_value = mannwhitneyu(
        get_best_f1_scores(dataroot1),
        get_best_f1_scores(dataroot2),
        alternative='greater',
    )
    return statistic, p_value
|
|
|
|
|
|
###### Summary table for every configuration, separated by blank lines

for k in dataroots:
    v = dataroots[k]
    print(k)
    print_final_table(v)
    print()
|
|
|
|
|
|
######
|
|
|
|
#print("MWU-Test of PROPOSAL > NOSPECLR")
|
|
#print(is_statistically_greater(dataroots['PROPOSAL'], dataroots['NOSPECLR']))
|
|
#print()
|
|
#print("MWU-Test of PROPOSAL > NOROT")
|
|
#print(is_statistically_greater(dataroots['PROPOSAL'], dataroots['NOROT']))
|
|
|
|
###### Metric curves of every configuration on a 2x3 grid

fig, ax = plt.subplots(nrows=2, ncols=3, figsize=(9, 5))

# (row, column, panel title, key of the recorded metric)
panels = (
    (0, 0, 'Accuracy', 'acc'),
    (0, 1, 'Sensitivity', 'acc_all'),
    (0, 2, 'Specificity', 'acc_hem'),
    (1, 0, 'F1 score', 'f1'),
    (1, 1, 'Precision', 'precision'),
    (1, 2, 'Recall', 'recall'),
)
for row, col, title, metric in panels:
    ax[row, col].set_title(title)
    plot3(metric, ax[row, col])

# A bare fig.legend() gathers the labelled artists of *all* axes, which would
# list every configuration once per panel.  Every panel shows the same
# configurations, so the handles of the first panel are sufficient.
handles, labels = ax[0, 0].get_legend_handles_labels()
fig.legend(handles, labels, loc='lower center', ncol=3)
fig.tight_layout()
# Reserve room below the panels for the figure-level legend.
fig.subplots_adjust(bottom=0.12)
fig.savefig('results/plot_ablations.pdf')
|
|
|
|
###### Bar chart of the values stored in subj_acc.npz of a single run

npload = 'results/model_cnmc_res_128'
npload_sub = npload + '/subj_acc.npz'
npload_res = npload + '/results.npz'

# Read everything into a plain dict so the archive's file handle is released
# right away instead of staying open for the rest of the script.
with np.load(npload_sub) as archive:
    subj_acc = {name: archive[name] for name in archive.files}

# One bar per archive entry, ordered by entry name.
subj = sorted(subj_acc)
acc = [subj_acc[name] for name in subj]

fig, ax = plt.subplots(figsize=(9, 2))
ax.bar(range(len(acc)), acc, width=0.3, tick_label=subj)
fig.tight_layout()
fig.savefig('results/plot_subj_acc.pdf')
|
|
|
|
###### Loss / F1 curves of the same single run

# Read the three series inside a context manager so the archive's file handle
# is closed once they are in memory.
with np.load(npload_res) as data:
    loss_train = data['loss_train']
    # The first entry of the validation series is dropped; the x-axis below
    # therefore starts at 1 for these two curves.
    loss_valid = data['loss_valid'][1:]
    f1_valid = data['f1'][1:]

fig, ax = plt.subplots(ncols=3, figsize=(9, 2))

ax[0].plot(range(len(loss_train)), loss_train)
ax[0].set_title("Training set loss")

ax[1].plot(range(1, len(loss_valid) + 1), loss_valid)
ax[1].set_title("Preliminary test set loss")

ax[2].plot(range(1, len(f1_valid) + 1), f1_valid)
ax[2].set_title("Preliminary test set F1-score")

fig.tight_layout()
fig.savefig('results/plot_curves.pdf')

######

plt.show()
|