CSI method as used in the Master's thesis: "Anomalie-Detektion in Zellbildern zur Anwendung der Leukämieerkennung" (Anomaly detection in cell images applied to leukemia detection).
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

ood_pre.py 8.6KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242
import os
from copy import deepcopy

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

import models.transform_layers as TL
from utils.utils import set_random_seed, normalize
from evals.evals import get_auroc

# Evaluate on GPU when available, otherwise fall back to CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Shared horizontal-flip layer, reused by _get_features for every batch.
hflip = TL.HorizontalFlipLayer().to(device)
def eval_ood_detection(P, model, id_loader, ood_loaders, ood_scores, train_loader=None, simclr_aug=None):
    """Evaluate OOD detection: compute CSI scores for ID and OOD data and report AUROC.

    Side effects: mutates the shared config object ``P`` (sets ``P.axis``,
    ``P.weight_sim``, ``P.weight_shi``), which ``get_scores`` reads later,
    and saves pre-computed features to disk via ``get_features``.

    Args:
        P: experiment configuration/namespace (load_path, ood_samples, K_shift, ...).
        model: trained network; forwarded inside ``get_features``.
        id_loader: data loader for the in-distribution test set.
        ood_loaders: dict mapping OOD-set name -> data loader
                     ('interp' is special-cased to interpolate ID batches).
        ood_scores: list with exactly one score name ('simclr' or 'CSI').
        train_loader: loader for the training set, used for global statistics.
        simclr_aug: augmentation module applied inside feature extraction.

    Returns:
        dict: auroc_dict[ood_name][ood_score] = AUROC value.
    """
    auroc_dict = dict()
    for ood in ood_loaders.keys():
        auroc_dict[ood] = dict()

    assert len(ood_scores) == 1  # assume single ood_score for simplicity
    ood_score = ood_scores[0]

    base_path = os.path.split(P.load_path)[0]  # checkpoint directory

    # Prefix encodes the sampling/resize configuration so cached feature
    # files from different settings do not collide.
    prefix = f'{P.ood_samples}'
    if P.resize_fix:
        prefix += f'_resize_fix_{P.resize_factor}'
    else:
        prefix += f'_resize_range_{P.resize_factor}'

    prefix = os.path.join(base_path, f'feats_{prefix}')

    kwargs = {
        'simclr_aug': simclr_aug,
        'sample_num': P.ood_samples,
        'layers': P.ood_layer,
    }

    print('Pre-compute global statistics...')
    feats_train = get_features(P, f'{P.dataset}_train', model, train_loader, prefix=prefix, **kwargs)  # (M, T, d)

    # One class-conditional "axis" (normalized mean feature) per shifting transform.
    P.axis = []
    for f in feats_train['simclr'].chunk(P.K_shift, dim=1):
        axis = f.mean(dim=1)  # (M, d)
        P.axis.append(normalize(axis, dim=1).to(device))
    print('axis size: ' + ' '.join(map(lambda x: str(len(x)), P.axis)))

    f_sim = [f.mean(dim=1) for f in feats_train['simclr'].chunk(P.K_shift, dim=1)]  # list of (M, d)
    f_shi = [f.mean(dim=1) for f in feats_train['shift'].chunk(P.K_shift, dim=1)]  # list of (M, 4)

    # Per-shift weights: inverse of the mean feature norm (sim) and of the
    # mean shift-classifier confidence (shi), balancing the two score terms.
    weight_sim = []
    weight_shi = []
    for shi in range(P.K_shift):
        sim_norm = f_sim[shi].norm(dim=1)  # (M)
        shi_mean = f_shi[shi][:, shi]  # (M)
        weight_sim.append(1 / sim_norm.mean().item())
        weight_shi.append(1 / shi_mean.mean().item())

    if ood_score == 'simclr':
        # SimCLR baseline: cosine-similarity term only.
        P.weight_sim = [1]
        P.weight_shi = [0]
    elif ood_score == 'CSI':
        P.weight_sim = weight_sim
        P.weight_shi = weight_shi
    else:
        raise ValueError()

    print(f'weight_sim:\t' + '\t'.join(map('{:.4f}'.format, P.weight_sim)))
    print(f'weight_shi:\t' + '\t'.join(map('{:.4f}'.format, P.weight_shi)))

    print('Pre-compute features...')
    feats_id = get_features(P, P.dataset, model, id_loader, prefix=prefix, **kwargs)  # (N, T, d)
    feats_ood = dict()
    for ood, ood_loader in ood_loaders.items():
        if ood == 'interp':
            # 'interp' OOD set: interpolations of consecutive ID batches.
            feats_ood[ood] = get_features(P, ood, model, id_loader, interp=True, prefix=prefix, **kwargs)
        else:
            feats_ood[ood] = get_features(P, ood, model, ood_loader, prefix=prefix, **kwargs)

    print(f'Compute OOD scores... (score: {ood_score})')
    scores_id = get_scores(P, feats_id, ood_score).numpy()
    scores_ood = dict()
    if P.one_class_idx is not None:
        one_class_score = []

    for ood, feats in feats_ood.items():
        scores_ood[ood] = get_scores(P, feats, ood_score).numpy()
        auroc_dict[ood][ood_score] = get_auroc(scores_id, scores_ood[ood])
        if P.one_class_idx is not None:
            one_class_score.append(scores_ood[ood])

    if P.one_class_idx is not None:
        # One-class setting: also report AUROC against all OOD sets pooled.
        one_class_score = np.concatenate(one_class_score)
        one_class_total = get_auroc(scores_id, one_class_score)
        print(f'One_class_real_mean: {one_class_total}')

    if P.print_score:
        print_score(P.dataset, scores_id)
        for ood, scores in scores_ood.items():
            print_score(ood, scores)

    return auroc_dict
  83. def get_scores(P, feats_dict, ood_score):
  84. # convert to gpu tensor
  85. feats_sim = feats_dict['simclr'].to(device)
  86. feats_shi = feats_dict['shift'].to(device)
  87. N = feats_sim.size(0)
  88. # compute scores
  89. scores = []
  90. for f_sim, f_shi in zip(feats_sim, feats_shi):
  91. f_sim = [f.mean(dim=0, keepdim=True) for f in f_sim.chunk(P.K_shift)] # list of (1, d)
  92. f_shi = [f.mean(dim=0, keepdim=True) for f in f_shi.chunk(P.K_shift)] # list of (1, 4)
  93. score = 0
  94. for shi in range(P.K_shift):
  95. score += (f_sim[shi] * P.axis[shi]).sum(dim=1).max().item() * P.weight_sim[shi]
  96. score += f_shi[shi][:, shi].item() * P.weight_shi[shi]
  97. score = score / P.K_shift
  98. scores.append(score)
  99. scores = torch.tensor(scores)
  100. assert scores.dim() == 1 and scores.size(0) == N # (N)
  101. return scores.cpu()
  102. def get_features(P, data_name, model, loader, interp=False, prefix='',
  103. simclr_aug=None, sample_num=1, layers=('simclr', 'shift')):
  104. if not isinstance(layers, (list, tuple)):
  105. layers = [layers]
  106. # load pre-computed features if exists
  107. feats_dict = dict()
  108. # for layer in layers:
  109. # path = prefix + f'_{data_name}_{layer}.pth'
  110. # if os.path.exists(path):
  111. # feats_dict[layer] = torch.load(path)
  112. # pre-compute features and save to the path
  113. left = [layer for layer in layers if layer not in feats_dict.keys()]
  114. if len(left) > 0:
  115. _feats_dict = _get_features(P, model, loader, interp, (P.dataset == 'imagenet' or
  116. P.dataset == 'CNMC' or
  117. P.dataset == 'CNMC_grayscale'),
  118. simclr_aug, sample_num, layers=left)
  119. for layer, feats in _feats_dict.items():
  120. path = prefix + f'_{data_name}_{layer}.pth'
  121. torch.save(_feats_dict[layer], path)
  122. feats_dict[layer] = feats # update value
  123. return feats_dict
def _get_features(P, model, loader, interp=False, imagenet=False, simclr_aug=None,
                  sample_num=1, layers=('simclr', 'shift')):
    """Forward the whole loader through the model and collect per-layer features.

    For every batch, ``sample_num`` stochastic augmentation passes are run
    (seeded deterministically per pass); each pass applies the K shifting
    transforms, so the final tensors hold K * sample_num views per sample.

    Args:
        P: config carrying K_shift and shift_trans.
        model: network whose auxiliary outputs are collected.
        loader: iterable of (x, label) batches; labels are ignored.
        interp: if True, replace each batch with the mean of itself and the
                previous batch (synthetic near-distribution samples).
        imagenet: if True, use the one-image-per-batch path (batch is a list
                  of pre-augmented views; sampling is forced to 1).
        simclr_aug: augmentation module; must not be None.
        sample_num: stochastic passes per batch (ignored when imagenet=True).
        layers: layer name(s) to collect from the model's aux outputs.

    Returns:
        dict: layer name -> tensor of shape (N, T, d), views ordered so that
        all views of one shift transform are adjacent (see reshape below).
    """
    if not isinstance(layers, (list, tuple)):
        layers = [layers]

    # check if arguments are valid
    assert simclr_aug is not None

    if imagenet is True:  # assume batch_size = 1 for ImageNet
        sample_num = 1

    # compute features in full dataset
    model.eval()
    feats_all = {layer: [] for layer in layers}  # initialize: empty list
    for i, (x, _) in enumerate(loader):
        if interp:
            x_interp = (x + last) / 2 if i > 0 else x  # omit the first batch, assume batch sizes are equal
            last = x  # save the last batch
            x = x_interp  # use interp as current batch

        if imagenet is True:
            x = torch.cat(x[0], dim=0)  # augmented list of x

        x = x.to(device)  # gpu tensor

        # compute features in one batch
        feats_batch = {layer: [] for layer in layers}  # initialize: empty list
        for seed in range(sample_num):
            # Fixed seed per pass: augmentations are reproducible across runs.
            set_random_seed(seed)

            if P.K_shift > 1:
                # Concatenate all K shifted versions of the flipped batch.
                x_t = torch.cat([P.shift_trans(hflip(x), k) for k in range(P.K_shift)])
            else:
                x_t = x  # No shifting: SimCLR
            x_t = simclr_aug(x_t)

            # compute augmented features
            with torch.no_grad():
                kwargs = {layer: True for layer in layers}  # only forward selected layers
                _, output_aux = model(x_t, **kwargs)

            # add features in one batch
            for layer in layers:
                feats = output_aux[layer].cpu()
                if imagenet is False:
                    # Split back into the K per-shift chunks.
                    feats_batch[layer] += feats.chunk(P.K_shift)
                else:
                    feats_batch[layer] += [feats]  # (B, d) cpu tensor

        # concatenate features in one batch
        for key, val in feats_batch.items():
            if imagenet:
                feats_batch[key] = torch.stack(val, dim=0)  # (B, T, d)
            else:
                feats_batch[key] = torch.stack(val, dim=1)  # (B, T, d)

        # add features in full dataset
        for layer in layers:
            feats_all[layer] += [feats_batch[layer]]

    # concatenate features in full dataset
    for key, val in feats_all.items():
        feats_all[key] = torch.cat(val, dim=0)  # (N, T, d)

    # reshape order
    if imagenet is False:
        # Convert [1,2,3,4, 1,2,3,4] -> [1,1, 2,2, 3,3, 4,4]
        # i.e. group views by shift transform so chunk(K_shift, dim=1)
        # downstream yields one contiguous block per transform.
        for key, val in feats_all.items():
            N, T, d = val.size()  # T = K * T'
            val = val.view(N, -1, P.K_shift, d)  # (N, T', K, d)
            val = val.transpose(2, 1)  # (N, 4, T', d)
            val = val.reshape(N, T, d)  # (N, T, d)
            feats_all[key] = val

    return feats_all
  185. def print_score(data_name, scores):
  186. quantile = np.quantile(scores, np.arange(0, 1.1, 0.1))
  187. print('{:18s} '.format(data_name) +
  188. '{:.4f} +- {:.4f} '.format(np.mean(scores), np.std(scores)) +
  189. ' '.join(['q{:d}: {:.4f}'.format(i * 10, quantile[i]) for i in range(11)]))