"""Evaluation of OOD detection from pre-computed features.

Computes SimCLR / shifting-head features for the training, in-distribution
and OOD sets, derives per-shift score axes and weights from the training
statistics, and reports AUROC for the resulting CSI (or plain SimCLR) scores.
"""
import os
from copy import deepcopy

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

import models.transform_layers as TL
from utils.utils import set_random_seed, normalize
from evals.evals import get_auroc

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

hflip = TL.HorizontalFlipLayer().to(device)


def eval_ood_detection(P, model, id_loader, ood_loaders, ood_scores, train_loader=None, simclr_aug=None):
    auroc_dict = dict()
    for ood in ood_loaders.keys():
        auroc_dict[ood] = dict()

    assert len(ood_scores) == 1  # assume single ood_score for simplicity
    ood_score = ood_scores[0]

    base_path = os.path.split(P.load_path)[0]  # checkpoint directory

    prefix = f'{P.ood_samples}'
    if P.resize_fix:
        prefix += f'_resize_fix_{P.resize_factor}'
    else:
        prefix += f'_resize_range_{P.resize_factor}'

    prefix = os.path.join(base_path, f'feats_{prefix}')

    kwargs = {
        'simclr_aug': simclr_aug,
        'sample_num': P.ood_samples,
        'layers': P.ood_layer,
    }

    print('Pre-compute global statistics...')
    feats_train = get_features(P, f'{P.dataset}_train', model, train_loader, prefix=prefix, **kwargs)  # (M, T, d)

    P.axis = []
    for f in feats_train['simclr'].chunk(P.K_shift, dim=1):
        axis = f.mean(dim=1)  # (M, d)
        P.axis.append(normalize(axis, dim=1).to(device))
    print('axis size: ' + ' '.join(map(lambda x: str(len(x)), P.axis)))

    f_sim = [f.mean(dim=1) for f in feats_train['simclr'].chunk(P.K_shift, dim=1)]  # list of (M, d)
    f_shi = [f.mean(dim=1) for f in feats_train['shift'].chunk(P.K_shift, dim=1)]  # list of (M, 4)

    weight_sim = []
    weight_shi = []
    for shi in range(P.K_shift):
        sim_norm = f_sim[shi].norm(dim=1)  # (M)
        shi_mean = f_shi[shi][:, shi]  # (M)
        weight_sim.append(1 / sim_norm.mean().item())
        weight_shi.append(1 / shi_mean.mean().item())

    if ood_score == 'simclr':
        P.weight_sim = [1]
        P.weight_shi = [0]
    elif ood_score == 'CSI':
        P.weight_sim = weight_sim
        P.weight_shi = weight_shi
    else:
        raise ValueError()

    print('weight_sim:\t' + '\t'.join(map('{:.4f}'.format, P.weight_sim)))
    print('weight_shi:\t' + '\t'.join(map('{:.4f}'.format, P.weight_shi)))

    print('Pre-compute features...')
    feats_id = get_features(P, P.dataset, model, id_loader, prefix=prefix, **kwargs)  # (N, T, d)
    feats_ood = dict()
    for ood, ood_loader in ood_loaders.items():
        if ood == 'interp':
            feats_ood[ood] = get_features(P, ood, model, id_loader, interp=True, prefix=prefix, **kwargs)
        else:
            feats_ood[ood] = get_features(P, ood, model, ood_loader, prefix=prefix, **kwargs)

    print(f'Compute OOD scores... (score: {ood_score})')
    scores_id = get_scores(P, feats_id, ood_score).numpy()
    scores_ood = dict()
    if P.one_class_idx is not None:
        one_class_score = []

    for ood, feats in feats_ood.items():
        scores_ood[ood] = get_scores(P, feats, ood_score).numpy()
        auroc_dict[ood][ood_score] = get_auroc(scores_id, scores_ood[ood])
        if P.one_class_idx is not None:
            one_class_score.append(scores_ood[ood])

    if P.one_class_idx is not None:
        one_class_score = np.concatenate(one_class_score)
        one_class_total = get_auroc(scores_id, one_class_score)
        print(f'One_class_real_mean: {one_class_total}')

    if P.print_score:
        print_score(P.dataset, scores_id)
        for ood, scores in scores_ood.items():
            print_score(ood, scores)

    return auroc_dict
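
# Illustrative usage sketch (not part of the pipeline above): the loader and
# model construction here is hypothetical, but the call mirrors the signature
# of eval_ood_detection. `P` is the experiment argument namespace used
# throughout this repo and must already carry load_path, ood_samples,
# resize_fix, resize_factor, ood_layer, K_shift, shift_trans, dataset,
# one_class_idx and print_score.
#
#   auroc_dict = eval_ood_detection(
#       P, model, id_loader, {'svhn': svhn_loader}, ood_scores=['CSI'],
#       train_loader=train_loader, simclr_aug=simclr_aug)
#   print(auroc_dict['svhn']['CSI'])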


def get_scores(P, feats_dict, ood_score):
    # convert to gpu tensor
    feats_sim = feats_dict['simclr'].to(device)
    feats_shi = feats_dict['shift'].to(device)
    N = feats_sim.size(0)

    # compute scores: per sample, average over shifts of
    # (max cosine-similarity-times-norm against the training axes, weighted)
    # plus (shift-classifier confidence for the correct shift, weighted)
    scores = []
    for f_sim, f_shi in zip(feats_sim, feats_shi):
        f_sim = [f.mean(dim=0, keepdim=True) for f in f_sim.chunk(P.K_shift)]  # list of (1, d)
        f_shi = [f.mean(dim=0, keepdim=True) for f in f_shi.chunk(P.K_shift)]  # list of (1, 4)
        score = 0
        for shi in range(P.K_shift):
            score += (f_sim[shi] * P.axis[shi]).sum(dim=1).max().item() * P.weight_sim[shi]
            score += f_shi[shi][:, shi].item() * P.weight_shi[shi]
        score = score / P.K_shift
        scores.append(score)
    scores = torch.tensor(scores)

    assert scores.dim() == 1 and scores.size(0) == N  # (N)
    return scores.cpu()


def get_features(P, data_name, model, loader, interp=False, prefix='',
                 simclr_aug=None, sample_num=1, layers=('simclr', 'shift')):

    if not isinstance(layers, (list, tuple)):
        layers = [layers]

    # load pre-computed features if exists
    feats_dict = dict()
    # for layer in layers:
    #     path = prefix + f'_{data_name}_{layer}.pth'
    #     if os.path.exists(path):
    #         feats_dict[layer] = torch.load(path)

    # pre-compute features and save to the path
    left = [layer for layer in layers if layer not in feats_dict.keys()]
    if len(left) > 0:
        _feats_dict = _get_features(P, model, loader, interp,
                                    (P.dataset == 'imagenet' or P.dataset == 'CNMC' or P.dataset == 'CNMC_grayscale'),
                                    simclr_aug, sample_num, layers=left)

        for layer, feats in _feats_dict.items():
            path = prefix + f'_{data_name}_{layer}.pth'
            torch.save(feats, path)
            feats_dict[layer] = feats  # update value

    return feats_dict


def _get_features(P, model, loader, interp=False, imagenet=False,
                  simclr_aug=None, sample_num=1, layers=('simclr', 'shift')):

    if not isinstance(layers, (list, tuple)):
        layers = [layers]

    # check if arguments are valid
    assert simclr_aug is not None

    if imagenet is True:  # assume batch_size = 1 for ImageNet
        sample_num = 1

    # compute features in full dataset
    model.eval()
    feats_all = {layer: [] for layer in layers}  # initialize: empty list
    for i, (x, _) in enumerate(loader):
        if interp:
            x_interp = (x + last) / 2 if i > 0 else x  # omit the first batch, assume batch sizes are equal
            last = x  # save the last batch
            x = x_interp  # use interp as current batch

        if imagenet is True:
            x = torch.cat(x[0], dim=0)  # augmented list of x

        x = x.to(device)  # gpu tensor

        # compute features in one batch
        feats_batch = {layer: [] for layer in layers}  # initialize: empty list
        for seed in range(sample_num):
            set_random_seed(seed)

            if P.K_shift > 1:
                x_t = torch.cat([P.shift_trans(hflip(x), k) for k in range(P.K_shift)])
            else:
                x_t = x  # no shifting: SimCLR

            x_t = simclr_aug(x_t)

            # compute augmented features
            with torch.no_grad():
                kwargs = {layer: True for layer in layers}  # only forward selected layers
                _, output_aux = model(x_t, **kwargs)

            # add features in one batch
            for layer in layers:
                feats = output_aux[layer].cpu()
                if imagenet is False:
                    feats_batch[layer] += feats.chunk(P.K_shift)
                else:
                    feats_batch[layer] += [feats]  # (B, d) cpu tensor

        # concatenate features in one batch
        for key, val in feats_batch.items():
            if imagenet:
                feats_batch[key] = torch.stack(val, dim=0)  # (B, T, d)
            else:
                feats_batch[key] = torch.stack(val, dim=1)  # (B, T, d)

        # add features in full dataset
        for layer in layers:
            feats_all[layer] += [feats_batch[layer]]

    # concatenate features in full dataset
    for key, val in feats_all.items():
        feats_all[key] = torch.cat(val, dim=0)  # (N, T, d)

    # reshape order
    if imagenet is False:
        # convert [1, 2, 3, 4, 1, 2, 3, 4] -> [1, 1, 2, 2, 3, 3, 4, 4]
        for key, val in feats_all.items():
            N, T, d = val.size()  # T = K * T'
            val = val.view(N, -1, P.K_shift, d)  # (N, T', K, d)
            val = val.transpose(2, 1)  # (N, K, T', d)
            val = val.reshape(N, T, d)  # (N, T, d)
            feats_all[key] = val

    return feats_all
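
# A minimal, self-contained sketch (not used by the pipeline) of the reorder
# performed at the end of _get_features: along T, features arrive grouped by
# sampling round ([shift0, shift1, ..., shift0, shift1, ...]) and are regrouped
# so all rounds of the same shift become adjacent. The toy sizes are assumptions.
def _demo_shift_reorder():
    N, rounds, K, d = 2, 2, 4, 3  # 2 samples, sample_num = 2, K_shift = 4, feature dim 3
    T = rounds * K
    val = torch.arange(N * T * d, dtype=torch.float).view(N, T, d)  # (round, shift) order along T
    out = val.view(N, rounds, K, d).transpose(2, 1).reshape(N, T, d)  # (shift, round) order along T
    # out[:, :rounds] now holds every round of shift 0, then shift 1, and so on
    return out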


def print_score(data_name, scores):
    quantile = np.quantile(scores, np.arange(0, 1.1, 0.1))
    print('{:18s} '.format(data_name) +
          '{:.4f} +- {:.4f}    '.format(np.mean(scores), np.std(scores)) +
          ' '.join(['q{:d}: {:.4f}'.format(i * 10, quantile[i]) for i in range(11)]))
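
# `get_auroc` is imported from evals.evals and not shown in this file. A
# standard implementation, assuming higher scores mean "more in-distribution"
# (ID labeled 1, OOD labeled 0), would look like this sketch; sklearn being
# available is an assumption here:
def _get_auroc_sketch(scores_id, scores_ood):
    from sklearn.metrics import roc_auc_score  # assumption: sklearn is installed
    scores = np.concatenate([scores_id, scores_ood])
    labels = np.concatenate([np.ones_like(scores_id), np.zeros_like(scores_ood)])
    return roc_auc_score(labels, scores)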