123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201 |
- import time
- import itertools
-
- import diffdist.functional as distops
- import numpy as np
- import torch
- import torch.distributed as dist
- import torch.nn as nn
- import torch.nn.functional as F
- from sklearn.metrics import roc_auc_score
-
- import models.transform_layers as TL
- from utils.temperature_scaling import _ECELoss
- from utils.utils import AverageMeter, set_random_seed, normalize
-
# Run on GPU when available; every tensor below is moved to this device.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Expected Calibration Error criterion (module-level, reused by test_classifier).
ece_criterion = _ECELoss().to(device)
-
-
def error_k(output, target, ks=(1,)):
    """Return top-k error rates, in percent, for each k in `ks`.

    Args:
        output: (batch, n_classes) logits or scores.
        target: (batch,) integer class labels.
        ks: tuple of k values to evaluate.

    Returns:
        List of 0-dim tensors, one error percentage per k.
    """
    batch_size = target.size(0)

    # Indices of the top-max(k) predictions, shaped (max_k, batch) so that
    # hits[:k] selects the first k guesses for every sample.
    _, topk_idx = output.topk(max(ks), 1, True, True)
    topk_idx = topk_idx.t()
    hits = topk_idx.eq(target.view(1, -1).expand_as(topk_idx))

    errors = []
    for k in ks:
        n_correct = hits[:k].float().sum()
        errors.append(100.0 - n_correct * (100.0 / batch_size))
    return errors
-
-
def test_classifier(P, model, loader, steps, marginal=False, logger=None):
    """Evaluate a classifier, reporting top-1 error and calibration (ECE).

    Args:
        P: experiment config; `P.n_classes` is read when `marginal` is True.
        model: classifier module; called as `model(x)` or, when `marginal`,
            `model(x, joint=True)` returning `(_, {'joint': logits})`.
        loader: iterable of (images, labels) batches.
        steps: global step recorded with the logger's scalar summaries.
        marginal: if True, average the 'joint' logits over the 4 rotations
            (0/90/180/270 degrees) of the input.
        logger: optional object providing `.log` and `.scalar_summary`;
            falls back to `print` when None.

    Returns:
        Average top-1 error (percent) over the whole loader.
    """
    error_top1 = AverageMeter()
    error_calibration = AverageMeter()

    log_ = print if logger is None else logger.log

    # Switch to evaluate mode, remembering the previous mode to restore it.
    mode = model.training
    model.eval()

    # Pure evaluation: no gradients needed, so skip autograd bookkeeping
    # (consistent with get_features below, which already uses no_grad).
    with torch.no_grad():
        for n, (images, labels) in enumerate(loader):
            batch_size = images.size(0)

            images, labels = images.to(device), labels.to(device)

            if marginal:
                # Marginalize the rotation-joint head over the 4 rotations.
                outputs = 0
                for i in range(4):
                    rot_images = torch.rot90(images, i, (2, 3))
                    _, outputs_aux = model(rot_images, joint=True)
                    outputs += outputs_aux['joint'][:, P.n_classes * i: P.n_classes * (i + 1)] / 4.
            else:
                outputs = model(images)

            top1, = error_k(outputs.data, labels, ks=(1,))
            error_top1.update(top1.item(), batch_size)

            ece = ece_criterion(outputs, labels) * 100
            error_calibration.update(ece.item(), batch_size)

            if n % 100 == 0:
                log_('[Test %3d] [Test@1 %.3f] [ECE %.3f]' %
                     (n, error_top1.value, error_calibration.value))

    log_(' * [Error@1 %.3f] [ECE %.3f]' %
         (error_top1.average, error_calibration.average))

    if logger is not None:
        logger.scalar_summary('eval/clean_error', error_top1.average, steps)
        logger.scalar_summary('eval/ece', error_calibration.average, steps)

    # Restore the training/eval mode the model arrived in.
    model.train(mode)

    return error_top1.average
-
-
def eval_ood_detection(P, model, id_loader, ood_loaders, ood_scores, train_loader=None, simclr_aug=None):
    """Compute the AUROC of every OOD score over every OOD loader.

    Returns a nested dict: auroc_dict[ood_name][ood_score] -> AUROC, where
    ID samples are treated as the positive class.
    """
    auroc_dict = {ood: {} for ood in ood_loaders}

    for ood_score in ood_scores:
        # Score function with (P, model, augmentation) bound in.
        score_func = get_ood_score_func(P, model, ood_score, simclr_aug=simclr_aug)

        save_path = f'plot/score_in_{P.dataset}_{ood_score}'
        if P.one_class_idx is not None:
            save_path += f'_{P.one_class_idx}'

        # In-distribution scores are shared across all OOD sets.
        scores_id = get_scores(id_loader, score_func)
        if P.save_score:
            np.save(f'{save_path}.npy', scores_id)

        for ood, ood_loader in ood_loaders.items():
            if ood == 'interp':
                # 'interp' re-uses the ID loader with batch interpolation.
                scores_ood = get_scores_interp(id_loader, score_func)
            else:
                scores_ood = get_scores(ood_loader, score_func)
            auroc_dict[ood][ood_score] = get_auroc(scores_id, scores_ood)

            if P.save_score:
                np.save(f'{save_path}_out_{ood}.npy', scores_ood)

    return auroc_dict
-
-
def get_ood_score_func(P, model, ood_score, simclr_aug=None):
    """Return a closure mapping a batch `x` to its OOD scores.

    Binds (P, model, ood_score, simclr_aug) so callers only pass inputs.
    """
    def _score(x):
        return compute_ood_score(P, model, ood_score, x, simclr_aug=simclr_aug)

    return _score
-
-
def get_scores(loader, score_func):
    """Run `score_func` over every batch and concatenate per-sample scores.

    Returns a 1-D numpy array with one score per sample in the loader.
    """
    chunks = []
    for x, _ in loader:
        batch_scores = score_func(x.to(device))
        # Exactly one scalar score per input sample.
        assert batch_scores.dim() == 1 and batch_scores.size(0) == x.size(0)

        chunks.append(batch_scores.detach().cpu().numpy())
    return np.concatenate(chunks)
-
-
def get_scores_interp(loader, score_func):
    """Score inputs interpolated between consecutive batches.

    The first batch is scored as-is; every later batch is averaged with the
    previous one (assumes equal batch sizes). Returns a 1-D numpy array.
    """
    chunks = []
    prev = None
    for x, _ in loader:
        # Midpoint between the current and previous batch, except at the start.
        x_mix = x if prev is None else (x + prev) / 2
        prev = x
        batch_scores = score_func(x_mix.to(device))
        assert batch_scores.dim() == 1 and batch_scores.size(0) == x.size(0)

        chunks.append(batch_scores.detach().cpu().numpy())
    return np.concatenate(chunks)
-
-
def get_auroc(scores_id, scores_ood):
    """AUROC with in-distribution samples as the positive (label 1) class."""
    labels = np.concatenate([np.ones_like(scores_id), np.zeros_like(scores_ood)])
    all_scores = np.concatenate([scores_id, scores_ood])
    return roc_auc_score(labels, all_scores)
-
-
def compute_ood_score(P, model, ood_score, x, simclr_aug=None):
    """Compute a per-sample OOD score for the batch `x`.

    For every scoring rule, higher scores mean "more in-distribution".

    Args:
        P: experiment config; `P.ood_layer` names the feature layer, and
            `P.n_classes` is read by 'baseline_marginalized'.
        model: network returning `(outputs, aux_dict)` when called with the
            keyword flags used below.
        ood_score: one of 'clean_norm', 'similar', 'baseline',
            'baseline_marginalized'.
        x: input batch tensor.
        simclr_aug: augmentation module; required for 'similar'.

    Returns:
        1-D tensor with one score per sample of `x`.

    Raises:
        NotImplementedError: for an unknown `ood_score`.
    """
    model.eval()

    # Scoring never needs gradients (callers detach the result); avoid
    # building autograd graphs, consistent with get_features below.
    with torch.no_grad():
        if ood_score == 'clean_norm':
            # Feature-norm score at the configured layer.
            _, output_aux = model(x, penultimate=True, simclr=True)
            score = output_aux[P.ood_layer].norm(dim=1)
            return score

        elif ood_score == 'similar':
            assert simclr_aug is not None  # require custom simclr augmentation
            sample_num = 2  # fast evaluation
            feats = get_features(model, simclr_aug, x, layer=P.ood_layer, sample_num=sample_num)
            feats_avg = sum(feats) / len(feats)

            # Mean cosine similarity of each augmented view to the average view.
            scores = []
            for seed in range(sample_num):
                sim = torch.cosine_similarity(feats[seed], feats_avg)
                scores.append(sim)
            return sum(scores) / len(scores)

        elif ood_score == 'baseline':
            # Maximum softmax probability (MSP).
            outputs, outputs_aux = model(x, penultimate=True)
            scores = F.softmax(outputs, dim=1).max(dim=1)[0]
            return scores

        elif ood_score == 'baseline_marginalized':
            # MSP over joint logits marginalized across the 4 input rotations.
            total_outputs = 0
            for i in range(4):
                x_rot = torch.rot90(x, i, (2, 3))
                outputs, outputs_aux = model(x_rot, penultimate=True, joint=True)
                total_outputs += outputs_aux['joint'][:, P.n_classes * i:P.n_classes * (i + 1)]

            scores = F.softmax(total_outputs / 4., dim=1).max(dim=1)[0]
            return scores

        else:
            raise NotImplementedError()
-
-
def get_features(model, simclr_aug, x, layer='simclr', sample_num=1):
    """Collect `sample_num` augmented feature views of `x` at `layer`.

    Returns a list of feature tensors, one per augmentation seed.
    """
    model.eval()

    views = []
    for seed in range(sample_num):
        # Deterministic augmentation per sample index.
        set_random_seed(seed)
        augmented = simclr_aug(x)
        with torch.no_grad():
            _, aux = model(augmented, penultimate=True, simclr=True, shift=True)
        views.append(aux[layer])
    return views
|