In der Masterarbeit "Anomalie-Detektion in Zellbildern zur Anwendung der Leukämieerkennung" verwendete CSI-Methode.
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

evals.py 6.3KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201
  1. import time
  2. import itertools
  3. import diffdist.functional as distops
  4. import numpy as np
  5. import torch
  6. import torch.distributed as dist
  7. import torch.nn as nn
  8. import torch.nn.functional as F
  9. from sklearn.metrics import roc_auc_score
  10. import models.transform_layers as TL
  11. from utils.temperature_scaling import _ECELoss
  12. from utils.utils import AverageMeter, set_random_seed, normalize
# Run everything on GPU when available; all tensors in this module are moved here.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Expected-calibration-error criterion, reused by test_classifier below.
ece_criterion = _ECELoss().to(device)
  15. def error_k(output, target, ks=(1,)):
  16. """Computes the precision@k for the specified values of k"""
  17. max_k = max(ks)
  18. batch_size = target.size(0)
  19. _, pred = output.topk(max_k, 1, True, True)
  20. pred = pred.t()
  21. correct = pred.eq(target.view(1, -1).expand_as(pred))
  22. results = []
  23. for k in ks:
  24. correct_k = correct[:k].view(-1).float().sum(0)
  25. results.append(100.0 - correct_k.mul_(100.0 / batch_size))
  26. return results
  27. def test_classifier(P, model, loader, steps, marginal=False, logger=None):
  28. error_top1 = AverageMeter()
  29. error_calibration = AverageMeter()
  30. if logger is None:
  31. log_ = print
  32. else:
  33. log_ = logger.log
  34. # Switch to evaluate mode
  35. mode = model.training
  36. model.eval()
  37. for n, (images, labels) in enumerate(loader):
  38. batch_size = images.size(0)
  39. images, labels = images.to(device), labels.to(device)
  40. if marginal:
  41. outputs = 0
  42. for i in range(4):
  43. rot_images = torch.rot90(images, i, (2, 3))
  44. _, outputs_aux = model(rot_images, joint=True)
  45. outputs += outputs_aux['joint'][:, P.n_classes * i: P.n_classes * (i + 1)] / 4.
  46. else:
  47. outputs = model(images)
  48. top1, = error_k(outputs.data, labels, ks=(1,))
  49. error_top1.update(top1.item(), batch_size)
  50. ece = ece_criterion(outputs, labels) * 100
  51. error_calibration.update(ece.item(), batch_size)
  52. if n % 100 == 0:
  53. log_('[Test %3d] [Test@1 %.3f] [ECE %.3f]' %
  54. (n, error_top1.value, error_calibration.value))
  55. log_(' * [Error@1 %.3f] [ECE %.3f]' %
  56. (error_top1.average, error_calibration.average))
  57. if logger is not None:
  58. logger.scalar_summary('eval/clean_error', error_top1.average, steps)
  59. logger.scalar_summary('eval/ece', error_calibration.average, steps)
  60. model.train(mode)
  61. return error_top1.average
  62. def eval_ood_detection(P, model, id_loader, ood_loaders, ood_scores, train_loader=None, simclr_aug=None):
  63. auroc_dict = dict()
  64. for ood in ood_loaders.keys():
  65. auroc_dict[ood] = dict()
  66. for ood_score in ood_scores:
  67. # compute scores for ID and OOD samples
  68. score_func = get_ood_score_func(P, model, ood_score, simclr_aug=simclr_aug)
  69. save_path = f'plot/score_in_{P.dataset}_{ood_score}'
  70. if P.one_class_idx is not None:
  71. save_path += f'_{P.one_class_idx}'
  72. scores_id = get_scores(id_loader, score_func)
  73. if P.save_score:
  74. np.save(f'{save_path}.npy', scores_id)
  75. for ood, ood_loader in ood_loaders.items():
  76. if ood == 'interp':
  77. scores_ood = get_scores_interp(id_loader, score_func)
  78. auroc_dict['interp'][ood_score] = get_auroc(scores_id, scores_ood)
  79. else:
  80. scores_ood = get_scores(ood_loader, score_func)
  81. auroc_dict[ood][ood_score] = get_auroc(scores_id, scores_ood)
  82. if P.save_score:
  83. np.save(f'{save_path}_out_{ood}.npy', scores_ood)
  84. return auroc_dict
  85. def get_ood_score_func(P, model, ood_score, simclr_aug=None):
  86. def score_func(x):
  87. return compute_ood_score(P, model, ood_score, x, simclr_aug=simclr_aug)
  88. return score_func
  89. def get_scores(loader, score_func):
  90. scores = []
  91. for i, (x, _) in enumerate(loader):
  92. s = score_func(x.to(device))
  93. assert s.dim() == 1 and s.size(0) == x.size(0)
  94. scores.append(s.detach().cpu().numpy())
  95. return np.concatenate(scores)
  96. def get_scores_interp(loader, score_func):
  97. scores = []
  98. for i, (x, _) in enumerate(loader):
  99. x_interp = (x + last) / 2 if i > 0 else x # omit the first batch, assume batch sizes are equal
  100. last = x # save the last batch
  101. s = score_func(x_interp.to(device))
  102. assert s.dim() == 1 and s.size(0) == x.size(0)
  103. scores.append(s.detach().cpu().numpy())
  104. return np.concatenate(scores)
  105. def get_auroc(scores_id, scores_ood):
  106. scores = np.concatenate([scores_id, scores_ood])
  107. labels = np.concatenate([np.ones_like(scores_id), np.zeros_like(scores_ood)])
  108. return roc_auc_score(labels, scores)
  109. def compute_ood_score(P, model, ood_score, x, simclr_aug=None):
  110. model.eval()
  111. if ood_score == 'clean_norm':
  112. _, output_aux = model(x, penultimate=True, simclr=True)
  113. score = output_aux[P.ood_layer].norm(dim=1)
  114. return score
  115. elif ood_score == 'similar':
  116. assert simclr_aug is not None # require custom simclr augmentation
  117. sample_num = 2 # fast evaluation
  118. feats = get_features(model, simclr_aug, x, layer=P.ood_layer, sample_num=sample_num)
  119. feats_avg = sum(feats) / len(feats)
  120. scores = []
  121. for seed in range(sample_num):
  122. sim = torch.cosine_similarity(feats[seed], feats_avg)
  123. scores.append(sim)
  124. return sum(scores) / len(scores)
  125. elif ood_score == 'baseline':
  126. outputs, outputs_aux = model(x, penultimate=True)
  127. scores = F.softmax(outputs, dim=1).max(dim=1)[0]
  128. return scores
  129. elif ood_score == 'baseline_marginalized':
  130. total_outputs = 0
  131. for i in range(4):
  132. x_rot = torch.rot90(x, i, (2, 3))
  133. outputs, outputs_aux = model(x_rot, penultimate=True, joint=True)
  134. total_outputs += outputs_aux['joint'][:, P.n_classes * i:P.n_classes * (i + 1)]
  135. scores = F.softmax(total_outputs / 4., dim=1).max(dim=1)[0]
  136. return scores
  137. else:
  138. raise NotImplementedError()
  139. def get_features(model, simclr_aug, x, layer='simclr', sample_num=1):
  140. model.eval()
  141. feats = []
  142. for seed in range(sample_num):
  143. set_random_seed(seed)
  144. x_t = simclr_aug(x)
  145. with torch.no_grad():
  146. _, output_aux = model(x_t, penultimate=True, simclr=True, shift=True)
  147. feats.append(output_aux[layer])
  148. return feats