|
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657 |
- from common.eval import *
-
-
- def main():
- model.eval()
-
- if P.mode == 'test_acc':
- from evals import test_classifier
- with torch.no_grad():
- error = test_classifier(P, model, test_loader, 0, logger=None)
-
- elif P.mode == 'test_marginalized_acc':
- from evals import test_classifier
- with torch.no_grad():
- error = test_classifier(P, model, test_loader, 0, marginal=True, logger=None)
-
- elif P.mode in ['ood', 'ood_pre']:
- if P.mode == 'ood':
- from evals import eval_ood_detection
- else:
- from evals.ood_pre import eval_ood_detection
-
- with torch.no_grad():
- auroc_dict = eval_ood_detection(P, model, test_loader, ood_test_loader, P.ood_score,
- train_loader=train_loader, simclr_aug=simclr_aug)
-
- if P.one_class_idx is not None:
- mean_dict = dict()
- for ood_score in P.ood_score:
- mean = 0
- for ood in auroc_dict.keys():
- mean += auroc_dict[ood][ood_score]
- mean_dict[ood_score] = mean / len(auroc_dict.keys())
- auroc_dict['one_class_mean'] = mean_dict
-
- bests = []
- for ood in auroc_dict.keys():
- message = ''
- best_auroc = 0
- for ood_score, auroc in auroc_dict[ood].items():
- message += '[%s %s %.4f] ' % (ood, ood_score, auroc)
- if auroc > best_auroc:
- best_auroc = auroc
- message += '[%s %s %.4f] ' % (ood, 'best', best_auroc)
- if P.print_score:
- print(message)
- bests.append(best_auroc)
-
- bests = map('{:.4f}'.format, bests)
- print('\t'.join(bests))
-
- else:
- raise NotImplementedError()
-
-
- if __name__ == '__main__':
- main()
|