123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242 |
- import os
- from copy import deepcopy
-
- import torch
- import torch.nn as nn
- import torch.nn.functional as F
- import numpy as np
-
- import models.transform_layers as TL
- from utils.utils import set_random_seed, normalize
- from evals.evals import get_auroc
-
# Run on the GPU when CUDA is available; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
# Module-level horizontal-flip layer, moved to `device` once at import time.
hflip = TL.HorizontalFlipLayer().to(device)
-
-
def eval_ood_detection(P, model, id_loader, ood_loaders, ood_scores, train_loader=None, simclr_aug=None):
    """Compute the AUROC of a single OOD score for every OOD loader.

    Features of the training data are extracted first and turned into the
    per-shift statistics that ``get_scores`` reads from ``P`` (``P.axis``,
    ``P.weight_sim`` and ``P.weight_shi``).  Features of the in-distribution
    data and of every OOD dataset are then extracted and scored.

    Args:
        P: Run configuration. Reads ``load_path``, ``ood_samples``,
            ``resize_fix``, ``resize_factor``, ``ood_layer``, ``dataset``,
            ``K_shift``, ``one_class_idx`` and ``print_score``; writes
            ``axis``, ``weight_sim`` and ``weight_shi``.
        model: Model forwarded to ``get_features``.
        id_loader: Loader of in-distribution data. It is also the source for
            the special ``'interp'`` OOD entry.
        ood_loaders: Dict mapping an OOD dataset name to its loader.
        ood_scores: Sequence holding exactly one score name, ``'simclr'`` or
            ``'CSI'``.
        train_loader: Loader used for the global statistics. It is forwarded
            to ``get_features`` as-is, so the ``None`` default is not usable
            in practice.
        simclr_aug: Augmentation forwarded to ``get_features``.

    Returns:
        Dict ``{ood_name: {ood_score: auroc}}``.

    Raises:
        AssertionError: If ``ood_scores`` does not hold exactly one entry.
        ValueError: If the requested score is neither ``'simclr'`` nor ``'CSI'``.
    """
    auroc_dict = {ood: dict() for ood in ood_loaders}

    assert len(ood_scores) == 1  # assume single ood_score for simplicity
    ood_score = ood_scores[0]
    # Fail fast: reject an unsupported score before any feature extraction starts.
    if ood_score not in ('simclr', 'CSI'):
        raise ValueError(f"unsupported ood_score {ood_score!r}; expected 'simclr' or 'CSI'")

    base_path = os.path.split(P.load_path)[0]  # checkpoint directory
    resize_tag = 'resize_fix' if P.resize_fix else 'resize_range'
    prefix = os.path.join(base_path, f'feats_{P.ood_samples}_{resize_tag}_{P.resize_factor}')

    kwargs = {
        'simclr_aug': simclr_aug,
        'sample_num': P.ood_samples,
        'layers': P.ood_layer,
    }

    print('Pre-compute global statistics...')
    feats_train = get_features(P, f'{P.dataset}_train', model, train_loader, prefix=prefix, **kwargs)  # (M, T, d)

    # One mean feature per shift transformation; reused for both the axes and the weights.
    sim_means = [chunk.mean(dim=1) for chunk in feats_train['simclr'].chunk(P.K_shift, dim=1)]  # list of (M, d)
    P.axis = [normalize(mean, dim=1).to(device) for mean in sim_means]
    print('axis size: ' + ' '.join(str(len(axis)) for axis in P.axis))

    if ood_score == 'simclr':
        P.weight_sim = [1]
        P.weight_shi = [0]
    else:  # 'CSI' (validated above)
        shi_means = [chunk.mean(dim=1) for chunk in feats_train['shift'].chunk(P.K_shift, dim=1)]  # list of (M, K)
        # Each term is weighted by the inverse of its mean magnitude on the training data.
        P.weight_sim = [1 / sim_means[shi].norm(dim=1).mean().item() for shi in range(P.K_shift)]
        P.weight_shi = [1 / shi_means[shi][:, shi].mean().item() for shi in range(P.K_shift)]

    print('weight_sim:\t' + '\t'.join(map('{:.4f}'.format, P.weight_sim)))
    print('weight_shi:\t' + '\t'.join(map('{:.4f}'.format, P.weight_shi)))

    print('Pre-compute features...')
    feats_id = get_features(P, P.dataset, model, id_loader, prefix=prefix, **kwargs)  # (N, T, d)
    feats_ood = dict()
    for ood, ood_loader in ood_loaders.items():
        if ood == 'interp':
            # 'interp' features come from the in-distribution loader with interp=True.
            feats_ood[ood] = get_features(P, ood, model, id_loader, interp=True, prefix=prefix, **kwargs)
        else:
            feats_ood[ood] = get_features(P, ood, model, ood_loader, prefix=prefix, **kwargs)

    print(f'Compute OOD scores... (score: {ood_score})')
    scores_id = get_scores(P, feats_id, ood_score).numpy()
    scores_ood = dict()
    for ood, feats in feats_ood.items():
        scores_ood[ood] = get_scores(P, feats, ood_score).numpy()
        auroc_dict[ood][ood_score] = get_auroc(scores_id, scores_ood[ood])

    if P.one_class_idx is not None:
        # Pool the scores of all OOD datasets into a single AUROC.
        one_class_score = np.concatenate(list(scores_ood.values()))
        one_class_total = get_auroc(scores_id, one_class_score)
        print(f'One_class_real_mean: {one_class_total}')

    if P.print_score:
        print_score(P.dataset, scores_id)
        for ood, scores in scores_ood.items():
            print_score(ood, scores)

    return auroc_dict
-
-
def get_scores(P, feats_dict, ood_score):
    """Return one OOD score per sample as a 1-D CPU tensor.

    For every sample, the 'simclr' and 'shift' features are split into
    ``P.K_shift`` chunks and each chunk is averaged.  Per shift ``k`` the
    score adds the maximum inner product between the mean 'simclr' feature
    and the rows of ``P.axis[k]`` (scaled by ``P.weight_sim[k]``), plus the
    ``k``-th entry of the mean 'shift' output (scaled by ``P.weight_shi[k]``).
    The sum is divided by ``P.K_shift``.

    Args:
        P: Configuration providing ``K_shift``, ``axis``, ``weight_sim`` and
            ``weight_shi``.
        feats_dict: Dict with the tensors stored under ``'simclr'`` and
            ``'shift'``.
        ood_score: Unused here; kept for interface compatibility.
    """
    # move the features to the compute device
    sim_all = feats_dict['simclr'].to(device)
    shi_all = feats_dict['shift'].to(device)
    num_samples = sim_all.size(0)

    per_sample = []
    for sample_sim, sample_shi in zip(sim_all, shi_all):
        sim_means = [part.mean(dim=0, keepdim=True) for part in sample_sim.chunk(P.K_shift)]  # list of (1, d)
        shi_means = [part.mean(dim=0, keepdim=True) for part in sample_shi.chunk(P.K_shift)]  # list of (1, K)
        total = 0
        for k in range(P.K_shift):
            best_similarity = (sim_means[k] * P.axis[k]).sum(dim=1).max().item()
            total += best_similarity * P.weight_sim[k]
            total += shi_means[k][:, k].item() * P.weight_shi[k]
        per_sample.append(total / P.K_shift)
    result = torch.tensor(per_sample)

    assert result.dim() == 1 and result.size(0) == num_samples  # (N)
    return result.cpu()
-
-
def get_features(P, data_name, model, loader, interp=False, prefix='',
                 simclr_aug=None, sample_num=1, layers=('simclr', 'shift')):
    """Extract the requested feature layers for ``loader`` and save them to disk.

    Each computed layer is written with ``torch.save`` to
    ``prefix + f'_{data_name}_{layer}.pth'`` and returned in a dict keyed by
    layer name.

    NOTE: loading previously saved features is disabled in this version, so
    the features are recomputed and the files rewritten on every call.

    Args:
        P: Configuration; ``P.dataset`` decides which extraction mode
            ``_get_features`` runs in.
        data_name: Name used in the saved file names.
        model, loader, interp, simclr_aug, sample_num: Forwarded to
            ``_get_features``.
        prefix: Path prefix of the saved feature files.
        layers: A layer name, or a list/tuple of layer names.

    Returns:
        Dict mapping each layer name to its feature tensor.
    """
    if not isinstance(layers, (list, tuple)):
        layers = [layers]

    feats_dict = {}  # stays empty until features are computed below

    # layers that still have to be computed
    missing = [name for name in layers if name not in feats_dict]
    if not missing:
        return feats_dict

    # these datasets use the 'imagenet' extraction mode of _get_features
    imagenet_mode = P.dataset in ('imagenet', 'CNMC', 'CNMC_grayscale')
    computed = _get_features(P, model, loader, interp, imagenet_mode,
                             simclr_aug, sample_num, layers=missing)

    for name, feats in computed.items():
        torch.save(feats, prefix + f'_{data_name}_{name}.pth')
        feats_dict[name] = feats

    return feats_dict
-
-
- def _get_features(P, model, loader, interp=False, imagenet=False, simclr_aug=None,
- sample_num=1, layers=('simclr', 'shift')):
-
- if not isinstance(layers, (list, tuple)):
- layers = [layers]
-
- # check if arguments are valid
- assert simclr_aug is not None
-
- if imagenet is True: # assume batch_size = 1 for ImageNet
- sample_num = 1
-
- # compute features in full dataset
- model.eval()
- feats_all = {layer: [] for layer in layers} # initialize: empty list
- for i, (x, _) in enumerate(loader):
- if interp:
- x_interp = (x + last) / 2 if i > 0 else x # omit the first batch, assume batch sizes are equal
- last = x # save the last batch
- x = x_interp # use interp as current batch
-
- if imagenet is True:
- x = torch.cat(x[0], dim=0) # augmented list of x
-
- x = x.to(device) # gpu tensor
-
- # compute features in one batch
- feats_batch = {layer: [] for layer in layers} # initialize: empty list
- for seed in range(sample_num):
- set_random_seed(seed)
-
- if P.K_shift > 1:
- x_t = torch.cat([P.shift_trans(hflip(x), k) for k in range(P.K_shift)])
- else:
- x_t = x # No shifting: SimCLR
- x_t = simclr_aug(x_t)
-
- # compute augmented features
- with torch.no_grad():
- kwargs = {layer: True for layer in layers} # only forward selected layers
- _, output_aux = model(x_t, **kwargs)
-
- # add features in one batch
- for layer in layers:
- feats = output_aux[layer].cpu()
- if imagenet is False:
- feats_batch[layer] += feats.chunk(P.K_shift)
- else:
- feats_batch[layer] += [feats] # (B, d) cpu tensor
-
- # concatenate features in one batch
- for key, val in feats_batch.items():
- if imagenet:
- feats_batch[key] = torch.stack(val, dim=0) # (B, T, d)
- else:
- feats_batch[key] = torch.stack(val, dim=1) # (B, T, d)
-
- # add features in full dataset
- for layer in layers:
- feats_all[layer] += [feats_batch[layer]]
-
- # concatenate features in full dataset
- for key, val in feats_all.items():
- feats_all[key] = torch.cat(val, dim=0) # (N, T, d)
-
- # reshape order
- if imagenet is False:
- # Convert [1,2,3,4, 1,2,3,4] -> [1,1, 2,2, 3,3, 4,4]
- for key, val in feats_all.items():
- N, T, d = val.size() # T = K * T'
- val = val.view(N, -1, P.K_shift, d) # (N, T', K, d)
- val = val.transpose(2, 1) # (N, 4, T', d)
- val = val.reshape(N, T, d) # (N, T, d)
- feats_all[key] = val
-
- return feats_all
-
-
def print_score(data_name, scores):
    """Print a one-line summary of ``scores``.

    The line holds ``data_name`` (left-aligned, padded to 18 characters), the
    mean and standard deviation, and the 0%, 10%, ..., 100% quantiles.

    Args:
        data_name: Label printed at the start of the line.
        scores: Array-like of numeric scores.
    """
    # linspace guarantees exactly 11 points ending at 1.0; a float-step
    # arange has no such guarantee on its length or its last value.
    deciles = np.quantile(scores, np.linspace(0, 1, 11))
    summary = '{:.4f} +- {:.4f} '.format(np.mean(scores), np.std(scores))
    decile_text = ' '.join('q{:d}: {:.4f}'.format(i * 10, q) for i, q in enumerate(deciles))
    print('{:18s} '.format(data_name) + summary + decile_text)
|