In der Masterarbeit "Anomalie-Detektion in Zellbildern zur Anwendung der Leukämieerkennung" verwendete CSI-Methode.
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

eval.py 1.8KB

2 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657
  1. from common.eval import *
  def main():
      """Run the evaluation selected by ``P.mode`` and print the results.

      ``P``, ``model``, ``torch``, the data loaders and ``simclr_aug`` are not
      defined in this file; presumably they all come from the star import of
      ``common.eval`` -- confirm against that module.

      Supported values of ``P.mode``:

      * ``'test_acc'`` -- calls ``evals.test_classifier``.
      * ``'test_marginalized_acc'`` -- same call with ``marginal=True``.
      * ``'ood'`` / ``'ood_pre'`` -- calls ``eval_ood_detection`` (imported
        from ``evals`` or ``evals.ood_pre`` respectively), then prints one
        optional detail line per result entry followed by a tab-separated
        line with the best value of every entry.

      Raises:
          NotImplementedError: if ``P.mode`` is none of the values above.
      """
      # The model is switched to eval mode before the mode is even inspected,
      # so this also happens when the mode turns out to be unsupported.
      model.eval()
      mode = P.mode

      if mode == 'test_acc' or mode == 'test_marginalized_acc':
          from evals import test_classifier

          # The two accuracy modes differ only in the extra ``marginal`` flag.
          extra_kwargs = {'marginal': True} if mode == 'test_marginalized_acc' else {}
          with torch.no_grad():
              test_classifier(P, model, test_loader, 0, **extra_kwargs, logger=None)
          return

      if mode not in ('ood', 'ood_pre'):
          raise NotImplementedError()

      # Imported lazily so that only the implementation for the chosen mode
      # is loaded.
      if mode == 'ood':
          from evals import eval_ood_detection
      else:
          from evals.ood_pre import eval_ood_detection

      with torch.no_grad():
          auroc_dict = eval_ood_detection(
              P, model, test_loader, ood_test_loader, P.ood_score,
              train_loader=train_loader, simclr_aug=simclr_aug,
          )

      if P.one_class_idx is not None:
          # Average every score type over all entries returned above. The
          # comprehension is fully evaluated before the new key is inserted,
          # so 'one_class_mean' never contributes to its own average.
          entry_count = len(auroc_dict)
          auroc_dict['one_class_mean'] = {
              score_name: sum(auroc_dict[name][score_name] for name in auroc_dict) / entry_count
              for score_name in P.ood_score
          }

      cell = '[%s %s %.4f] '
      best_values = []
      for name, scores in auroc_dict.items():
          cells = [cell % (name, score_name, value) for score_name, value in scores.items()]
          # Seeding with 0 keeps the floor used for the "best" value: an entry
          # whose scores are all <= 0 (or that has no scores) reports 0.
          top = max([0, *scores.values()])
          cells.append(cell % (name, 'best', top))
          if P.print_score:
              print(''.join(cells))
          best_values.append(top)

      print('\t'.join('{:.4f}'.format(value) for value in best_values))
  # Only start the evaluation when this file is executed as a script, not when
  # it is imported.
  if __name__ == '__main__':
      main()