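"""Evaluation utilities: top-k classification error, calibration error (ECE),
and OOD detection scoring with AUROC."""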
import time
import itertools

import diffdist.functional as distops
import numpy as np
import torch
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F
from sklearn.metrics import roc_auc_score

import models.transform_layers as TL
from utils.temperature_scaling import _ECELoss
from utils.utils import AverageMeter, set_random_seed, normalize

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
ece_criterion = _ECELoss().to(device)

def error_k(output, target, ks=(1,)):
    """Computes the top-k error (i.e., 100 - precision@k) for the specified values of k."""
    max_k = max(ks)
    batch_size = target.size(0)

    _, pred = output.topk(max_k, 1, True, True)
    pred = pred.t()
    correct = pred.eq(target.view(1, -1).expand_as(pred))

    results = []
    for k in ks:
        # reshape() rather than view(): the sliced tensor may be non-contiguous,
        # in which case view() raises a RuntimeError on recent PyTorch versions
        correct_k = correct[:k].reshape(-1).float().sum(0)
        results.append(100.0 - correct_k.mul_(100.0 / batch_size))
    return results
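# Illustrative example: both predictions below match their targets,
# so the returned top-1 error is 0.
#   >>> out = torch.tensor([[0.1, 0.9], [0.8, 0.2]])
#   >>> error_k(out, torch.tensor([1, 0]))
#   [tensor(0.)]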

def test_classifier(P, model, loader, steps, marginal=False, logger=None):
    error_top1 = AverageMeter()
    error_calibration = AverageMeter()

    if logger is None:
        log_ = print
    else:
        log_ = logger.log

    # Switch to evaluate mode
    mode = model.training
    model.eval()

    for n, (images, labels) in enumerate(loader):
        batch_size = images.size(0)

        images, labels = images.to(device), labels.to(device)

        if marginal:
            # Average the joint logits over the four 90-degree rotations
            outputs = 0
            for i in range(4):
                rot_images = torch.rot90(images, i, (2, 3))
                _, outputs_aux = model(rot_images, joint=True)
                outputs += outputs_aux['joint'][:, P.n_classes * i: P.n_classes * (i + 1)] / 4.
        else:
            outputs = model(images)

        top1, = error_k(outputs.data, labels, ks=(1,))
        error_top1.update(top1.item(), batch_size)

        ece = ece_criterion(outputs, labels) * 100
        error_calibration.update(ece.item(), batch_size)

        if n % 100 == 0:
            log_('[Test %3d] [Test@1 %.3f] [ECE %.3f]' %
                 (n, error_top1.value, error_calibration.value))

    log_(' * [Error@1 %.3f] [ECE %.3f]' %
         (error_top1.average, error_calibration.average))

    if logger is not None:
        logger.scalar_summary('eval/clean_error', error_top1.average, steps)
        logger.scalar_summary('eval/ece', error_calibration.average, steps)

    # Restore the original training mode
    model.train(mode)

    return error_top1.average
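# Hypothetical usage (loader and step names are assumptions, not defined here):
#   err = test_classifier(P, model, test_loader, steps=epoch, marginal=True)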

def eval_ood_detection(P, model, id_loader, ood_loaders, ood_scores, train_loader=None, simclr_aug=None):
    """Computes AUROC for every (OOD dataset, OOD score) pair; returns auroc_dict[ood][ood_score]."""
    auroc_dict = dict()
    for ood in ood_loaders.keys():
        auroc_dict[ood] = dict()

    for ood_score in ood_scores:
        # Compute scores for ID and OOD samples
        score_func = get_ood_score_func(P, model, ood_score, simclr_aug=simclr_aug)

        save_path = f'plot/score_in_{P.dataset}_{ood_score}'
        if P.one_class_idx is not None:
            save_path += f'_{P.one_class_idx}'

        scores_id = get_scores(id_loader, score_func)

        if P.save_score:
            np.save(f'{save_path}.npy', scores_id)

        for ood, ood_loader in ood_loaders.items():
            if ood == 'interp':
                # 'interp' samples are interpolations of consecutive ID batches
                scores_ood = get_scores_interp(id_loader, score_func)
                auroc_dict['interp'][ood_score] = get_auroc(scores_id, scores_ood)
            else:
                scores_ood = get_scores(ood_loader, score_func)
                auroc_dict[ood][ood_score] = get_auroc(scores_id, scores_ood)

            if P.save_score:
                np.save(f'{save_path}_out_{ood}.npy', scores_ood)

    return auroc_dict
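# Hypothetical usage sketch (loader names and the score list are assumptions):
#   ood_loaders = {'svhn': svhn_loader, 'interp': id_test_loader}
#   aurocs = eval_ood_detection(P, model, id_test_loader, ood_loaders,
#                               ood_scores=['baseline'], simclr_aug=simclr_aug)
#   aurocs['svhn']['baseline']  # AUROC of the MSP baseline against SVHN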

def get_ood_score_func(P, model, ood_score, simclr_aug=None):
    def score_func(x):
        return compute_ood_score(P, model, ood_score, x, simclr_aug=simclr_aug)
    return score_func

def get_scores(loader, score_func):
    scores = []
    for i, (x, _) in enumerate(loader):
        s = score_func(x.to(device))
        assert s.dim() == 1 and s.size(0) == x.size(0)

        scores.append(s.detach().cpu().numpy())
    return np.concatenate(scores)

def get_scores_interp(loader, score_func):
    scores = []
    for i, (x, _) in enumerate(loader):
        # Interpolate each batch with the previous one; the first batch is scored
        # as-is, since it has no predecessor (assumes equal batch sizes)
        x_interp = (x + last) / 2 if i > 0 else x
        last = x  # save the current batch for the next iteration
        s = score_func(x_interp.to(device))
        assert s.dim() == 1 and s.size(0) == x.size(0)

        scores.append(s.detach().cpu().numpy())
    return np.concatenate(scores)

def get_auroc(scores_id, scores_ood):
    scores = np.concatenate([scores_id, scores_ood])
    labels = np.concatenate([np.ones_like(scores_id), np.zeros_like(scores_ood)])
    return roc_auc_score(labels, scores)
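# Illustrative example: ID samples are labeled 1 and OOD samples 0, so a score
# function that consistently ranks ID above OOD yields an AUROC of 1.0.
#   >>> get_auroc(np.array([0.9, 0.8]), np.array([0.2, 0.1]))
#   1.0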

def compute_ood_score(P, model, ood_score, x, simclr_aug=None):
    model.eval()

    if ood_score == 'clean_norm':
        # Feature norm at the chosen layer, computed on the clean (unaugmented) input
        _, output_aux = model(x, penultimate=True, simclr=True)
        score = output_aux[P.ood_layer].norm(dim=1)
        return score

    elif ood_score == 'similar':
        assert simclr_aug is not None  # requires the custom SimCLR augmentation
        sample_num = 2  # small number of augmented views, for fast evaluation
        feats = get_features(model, simclr_aug, x, layer=P.ood_layer, sample_num=sample_num)
        feats_avg = sum(feats) / len(feats)

        # Average cosine similarity between each augmented view and the mean feature
        scores = []
        for seed in range(sample_num):
            sim = torch.cosine_similarity(feats[seed], feats_avg)
            scores.append(sim)
        return sum(scores) / len(scores)

    elif ood_score == 'baseline':
        # Maximum softmax probability (MSP) baseline
        outputs, outputs_aux = model(x, penultimate=True)
        scores = F.softmax(outputs, dim=1).max(dim=1)[0]
        return scores

    elif ood_score == 'baseline_marginalized':
        # MSP on the joint logits, marginalized over the four 90-degree rotations
        total_outputs = 0
        for i in range(4):
            x_rot = torch.rot90(x, i, (2, 3))
            outputs, outputs_aux = model(x_rot, penultimate=True, joint=True)
            total_outputs += outputs_aux['joint'][:, P.n_classes * i:P.n_classes * (i + 1)]

        scores = F.softmax(total_outputs / 4., dim=1).max(dim=1)[0]
        return scores

    else:
        raise NotImplementedError()

def get_features(model, simclr_aug, x, layer='simclr', sample_num=1):
    model.eval()

    feats = []
    for seed in range(sample_num):
        set_random_seed(seed)  # fix the augmentation per seed, so views are reproducible
        x_t = simclr_aug(x)
        with torch.no_grad():
            _, output_aux = model(x_t, penultimate=True, simclr=True, shift=True)
        feats.append(output_aux[layer])
    return feats
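# Hypothetical usage (simclr_aug and P.ood_layer come from elsewhere in the codebase):
#   feats = get_features(model, simclr_aug, x, layer=P.ood_layer, sample_num=2)
#   # -> list of per-seed feature tensors for the batch x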