# CSI: Novelty Detection via Contrastive Learning on Distributionally Shifted Instances | |||||
Official PyTorch implementation of | |||||
["**CSI: Novelty Detection via Contrastive Learning on Distributionally Shifted Instances**"]( | |||||
https://arxiv.org/abs/2007.08176) (NeurIPS 2020) by | |||||
[Jihoon Tack*](https://jihoontack.github.io), | |||||
[Sangwoo Mo*](https://sites.google.com/view/sangwoomo), | |||||
[Jongheon Jeong](https://sites.google.com/view/jongheonj), | |||||
and [Jinwoo Shin](http://alinlab.kaist.ac.kr/shin.html). | |||||
<p align="center"> | |||||
<img src=figures/shifting_transformations.png width="900"> | |||||
</p> | |||||
## 1. Requirements | |||||
### Environments | |||||
Currently, requires following packages | |||||
- python 3.6+ | |||||
- torch 1.4+ | |||||
- torchvision 0.5+ | |||||
- CUDA 10.1+ | |||||
- scikit-learn 0.22+ | |||||
- tensorboard 2.0+ | |||||
- [torchlars](https://github.com/kakaobrain/torchlars) == 0.1.2 | |||||
- [pytorch-gradual-warmup-lr](https://github.com/ildoonet/pytorch-gradual-warmup-lr) packages | |||||
- [apex](https://github.com/NVIDIA/apex) == 0.1 | |||||
- [diffdist](https://github.com/ag14774/diffdist) == 0.1 | |||||
### Datasets | |||||
For CIFAR, please download the following datasets to `~/data`. | |||||
* [LSUN_resize](https://www.dropbox.com/s/moqh2wh8696c3yl/LSUN_resize.tar.gz), | |||||
[ImageNet_resize](https://www.dropbox.com/s/kp3my3412u5k9rl/Imagenet_resize.tar.gz) | |||||
* [LSUN_fix](https://drive.google.com/file/d/1KVWj9xpHfVwGcErH5huVujk9snhEGOxE/view?usp=sharing), | |||||
[ImageNet_fix](https://drive.google.com/file/d/1sO_-noq10mmziB1ECDyNhD5T4u5otyKA/view?usp=sharing) | |||||
For ImageNet-30, please download the following datasets to `~/data`. | |||||
* [ImageNet-30-train](https://drive.google.com/file/d/1B5c39Fc3haOPzlehzmpTLz6xLtGyKEy4/view), | |||||
[ImageNet-30-test](https://drive.google.com/file/d/13xzVuQMEhSnBRZr-YaaO08coLU2dxAUq/view) | |||||
* [CUB-200](http://www.vision.caltech.edu/visipedia/CUB-200-2011.html), | |||||
[Stanford Dogs](http://vision.stanford.edu/aditya86/ImageNetDogs/), | |||||
[Oxford Pets](https://www.robots.ox.ac.uk/~vgg/data/pets/), | |||||
[Oxford flowers](https://www.robots.ox.ac.uk/~vgg/data/flowers/), | |||||
[Food-101](https://www.kaggle.com/dansbecker/food-101), | |||||
[Places-365](http://data.csail.mit.edu/places/places365/val_256.tar), | |||||
[Caltech-256](https://www.kaggle.com/jessicali9530/caltech256), | |||||
[DTD](https://www.robots.ox.ac.uk/~vgg/data/dtd/) | |||||
For Food-101, remove hotdog class to avoid overlap. | |||||
## 2. Training | |||||
Currently, all code examples are assuming distributed launch with 4 multi GPUs. | |||||
To run the code with single GPU, remove `-m torch.distributed.launch --nproc_per_node=4`. | |||||
### Unlabeled one-class & multi-class | |||||
To train unlabeled one-class & multi-class models in the paper, run this command: | |||||
```train | |||||
CUDA_VISIBLE_DEVICES=0,1,2,3 python -m torch.distributed.launch --nproc_per_node=4 train.py --dataset <DATASET> --model <NETWORK> --mode simclr_CSI --shift_trans_type rotation --batch_size 32 --one_class_idx <One-Class-Index> | |||||
``` | |||||
> Option --one_class_idx denotes the in-distribution of one-class training. | |||||
> For multi-class training, set --one_class_idx as None. | |||||
> To run SimCLR simply change --mode to simclr. | |||||
> Total batch size should be 512 = 4 (GPU) * 32 (--batch_size option) * 4 (cardinality of shifted transformation set). | |||||
### Labeled multi-class | |||||
To train labeled multi-class model (confidence calibrated classifier) in the paper, run this command: | |||||
```train | |||||
# Representation train | |||||
CUDA_VISIBLE_DEVICES=0,1,2,3 python -m torch.distributed.launch --nproc_per_node=4 train.py --dataset <DATASET> --model <NETWORK> --mode sup_simclr_CSI --shift_trans_type rotation --batch_size 32 --epoch 700 | |||||
# Linear layer train | |||||
python train.py --mode sup_CSI_linear --dataset <DATASET> --model <NETWORK> --batch_size 32 --epoch 100 --shift_trans_type rotation --load_path <MODEL_PATH> | |||||
``` | |||||
> To run SupCLR simply change --mode to sup_simclr, sup_linear for representation training and linear layer training respectively. | |||||
> Total batch size should be same as above. Currently only supports rotation for shifted transformation. | |||||
## 3. Evaluation | |||||
We provide the checkpoint of the CSI pre-trained model. Download the checkpoint from the following link: | |||||
- One-class CIFAR-10: [ResNet-18](https://drive.google.com/drive/folders/1z02i0G_lzrZe0NwpH-tnjpO8pYHV7mE9?usp=sharing) | |||||
- Unlabeled (multi-class) CIFAR-10: [ResNet-18](https://drive.google.com/file/d/1yUq6Si6hWaMa1uYyLDHk0A4BrPIa8ECV/view?usp=sharing) | |||||
- Unlabeled (multi-class) ImageNet-30: [ResNet-18](https://drive.google.com/file/d/1KucQWSik8RyoJgU-fz8XLmCWhvMOP7fT/view?usp=sharing) | |||||
- Labeled (multi-class) CIFAR-10: [ResNet-18](https://drive.google.com/file/d/1rW2-0MJEzPHLb_PAW-LvCivHt-TkDpRO/view?usp=sharing) | |||||
### Unlabeled one-class & multi-class | |||||
To evaluate my model on unlabeled one-class & multi-class out-of-distribution (OOD) detection setting, run this command: | |||||
```eval | |||||
python eval.py --mode ood_pre --dataset <DATASET> --model <NETWORK> --ood_score CSI --shift_trans_type rotation --print_score --ood_samples 10 --resize_factor 0.54 --resize_fix --one_class_idx <One-Class-Index> --load_path <MODEL_PATH> | |||||
``` | |||||
> Option --one_class_idx denotes the in-distribution of one-class evaluation. | |||||
> For multi-class evaluation, set --one_class_idx as None. | |||||
> The resize_factor & resize fix option fix the cropping size of RandomResizedCrop(). | |||||
> For SimCLR evaluation, change --ood_score to simclr. | |||||
### Labeled multi-class | |||||
To evaluate my model on labeled multi-class accuracy, ECE, OOD detection setting, run this command: | |||||
```eval | |||||
# OOD AUROC | |||||
python eval.py --mode ood --ood_score baseline_marginalized --print_score --dataset <DATASET> --model <NETWORK> --shift_trans_type rotation --load_path <MODEL_PATH> | |||||
# Accuray & ECE | |||||
python eval.py --mode test_marginalized_acc --dataset <DATASET> --model <NETWORK> --shift_trans_type rotation --load_path <MODEL_PATH> | |||||
``` | |||||
> This option is for marginalized inference. | |||||
> For single inference (also used for SupCLR) change --ood_score baseline in first command, | |||||
> and --mode test_acc in second command. | |||||
## 4. Results | |||||
Our model achieves the following performance on: | |||||
### One-Class Out-of-Distribution Detection | |||||
| Method | Dataset | AUROC (Mean) | | |||||
| --------------|------------------ | --------------| | |||||
| SimCLR | CIFAR-10-OC | 87.9% | | |||||
| Rot+Trans | CIFAR-10-OC | 90.0% | | |||||
| CSI (ours) | CIFAR-10-OC | 94.3% | | |||||
We only show CIFAR-10 one-class result in this repo. For other setting, please see our paper. | |||||
### Unlabeled Multi-Class Out-of-Distribution Detection | |||||
| Method | Dataset | OOD Dataset | AUROC (Mean) | | |||||
| --------------|------------------ |---------------|--------------| | |||||
| Rot+Trans | CIFAR-10 | CIFAR-100 | 82.5% | | |||||
| CSI (ours) | CIFAR-10 | CIFAR-100 | 89.3% | | |||||
We only show CIFAR-10 to CIFAR-100 OOD detection result in this repo. For other OOD dataset results, see our paper. | |||||
### Labeled Multi-Class Result | |||||
| Method | Dataset | OOD Dataset | Acc | ECE | AUROC (Mean) | | |||||
| ---------------- |------------------ |---------------|-------|-------|--------------| | |||||
| SupCLR | CIFAR-10 | CIFAR-100 | 93.9% | 5.54% | 88.3% | | |||||
| CSI (ours) | CIFAR-10 | CIFAR-100 | 94.8% | 4.24% | 90.6% | | |||||
| CSI-ensem (ours) | CIFAR-10 | CIFAR-100 | 96.0% | 3.64% | 92.3% | | |||||
We only show CIFAR-10 with CIFAR-100 as OOD in this repo. For other dataset results, please see our paper. | |||||
## 5. New OOD dataset | |||||
<p align="center"> | |||||
<img src=figures/fixed_ood_benchmarks.png width="600"> | |||||
</p> | |||||
We find that current benchmark datasets for OOD detection, are visually far from in-distribution datasets (e.g. CIFAR). | |||||
To address this issue, we provide new datasets for OOD detection evaluation: | |||||
[LSUN_fix](https://drive.google.com/file/d/1KVWj9xpHfVwGcErH5huVujk9snhEGOxE/view?usp=sharing), | |||||
[ImageNet_fix](https://drive.google.com/file/d/1sO_-noq10mmziB1ECDyNhD5T4u5otyKA/view?usp=sharing). | |||||
See the above figure for the visualization of current benchmark and our dataset. | |||||
To generate OOD datasets, run the following codes inside the `./datasets` folder: | |||||
```OOD dataset generation | |||||
# ImageNet FIX generation code | |||||
python imagenet_fix_preprocess.py | |||||
# LSUN FIX generation code | |||||
python lsun_fix_preprocess.py | |||||
``` | |||||
## Citation | |||||
``` | |||||
@inproceedings{tack2020csi, | |||||
title={CSI: Novelty Detection via Contrastive Learning on Distributionally Shifted Instances}, | |||||
author={Jihoon Tack and Sangwoo Mo and Jongheon Jeong and Jinwoo Shin}, | |||||
booktitle={Advances in Neural Information Processing Systems}, | |||||
year={2020} | |||||
} | |||||
``` |
""" | |||||
References: | |||||
- https://github.com/PyTorchLightning/PyTorch-Lightning-Bolts/blob/master/pl_bolts/optimizers/lars_scheduling.py | |||||
- https://github.com/NVIDIA/apex/blob/master/apex/parallel/LARC.py | |||||
- https://arxiv.org/pdf/1708.03888.pdf | |||||
- https://github.com/noahgolmant/pytorch-lars/blob/master/lars.py | |||||
""" | |||||
import torch | |||||
from .wrapper import OptimWrapper | |||||
# from torchlars._adaptive_lr import compute_adaptive_lr # Impossible to build extensions | |||||
__all__ = ["LARS"] | |||||
class LARS(OptimWrapper): | |||||
"""Implements 'LARS (Layer-wise Adaptive Rate Scaling)'__ as Optimizer a | |||||
:class:`~torch.optim.Optimizer` wrapper. | |||||
__ : https://arxiv.org/abs/1708.03888 | |||||
Wraps an arbitrary optimizer like :class:`torch.optim.SGD` to use LARS. If | |||||
you want to the same performance obtained with small-batch training when | |||||
you use large-batch training, LARS will be helpful:: | |||||
Args: | |||||
optimizer (Optimizer): | |||||
optimizer to wrap | |||||
eps (float, optional): | |||||
epsilon to help with numerical stability while calculating the | |||||
adaptive learning rate | |||||
trust_coef (float, optional): | |||||
trust coefficient for calculating the adaptive learning rate | |||||
Example:: | |||||
base_optimizer = optim.SGD(model.parameters(), lr=0.1) | |||||
optimizer = LARS(optimizer=base_optimizer) | |||||
output = model(input) | |||||
loss = loss_fn(output, target) | |||||
loss.backward() | |||||
optimizer.step() | |||||
""" | |||||
def __init__(self, optimizer, trust_coef=0.02, clip=True, eps=1e-8): | |||||
if eps < 0.0: | |||||
raise ValueError("invalid epsilon value: , %f" % eps) | |||||
if trust_coef < 0.0: | |||||
raise ValueError("invalid trust coefficient: %f" % trust_coef) | |||||
self.optim = optimizer | |||||
self.eps = eps | |||||
self.trust_coef = trust_coef | |||||
self.clip = clip | |||||
def __getstate__(self): | |||||
self.optim.__get | |||||
lars_dict = {} | |||||
lars_dict["trust_coef"] = self.trust_coef | |||||
lars_dict["clip"] = self.clip | |||||
lars_dict["eps"] = self.eps | |||||
return (self.optim, lars_dict) | |||||
def __setstate__(self, state): | |||||
self.optim, lars_dict = state | |||||
self.trust_coef = lars_dict["trust_coef"] | |||||
self.clip = lars_dict["clip"] | |||||
self.eps = lars_dict["eps"] | |||||
@torch.no_grad() | |||||
def step(self, closure=None): | |||||
weight_decays = [] | |||||
for group in self.optim.param_groups: | |||||
weight_decay = group.get("weight_decay", 0) | |||||
weight_decays.append(weight_decay) | |||||
# reset weight decay | |||||
group["weight_decay"] = 0 | |||||
# update the parameters | |||||
for p in group["params"]: | |||||
if p.grad is not None: | |||||
self.update_p(p, group, weight_decay) | |||||
# update the optimizer | |||||
self.optim.step(closure=closure) | |||||
# return weight decay control to optimizer | |||||
for group_idx, group in enumerate(self.optim.param_groups): | |||||
group["weight_decay"] = weight_decays[group_idx] | |||||
def update_p(self, p, group, weight_decay): | |||||
# calculate new norms | |||||
p_norm = torch.norm(p.data) | |||||
g_norm = torch.norm(p.grad.data) | |||||
if p_norm != 0 and g_norm != 0: | |||||
# calculate new lr | |||||
divisor = g_norm + p_norm * weight_decay + self.eps | |||||
adaptive_lr = (self.trust_coef * p_norm) / divisor | |||||
# clip lr | |||||
if self.clip: | |||||
adaptive_lr = min(adaptive_lr / group["lr"], 1) | |||||
# update params with clipped lr | |||||
p.grad.data += weight_decay * p.data | |||||
p.grad.data *= adaptive_lr | |||||
from torch.optim import SGD | |||||
from pylot.util import delegates, separate_kwargs | |||||
class SGDLARS(LARS): | |||||
@delegates(to=LARS.__init__) | |||||
@delegates(to=SGD.__init__, keep=True, but=["eps", "trust_coef"]) | |||||
def __init__(self, params, **kwargs): | |||||
sgd_kwargs, lars_kwargs = separate_kwargs(kwargs, SGD.__init__) | |||||
optim = SGD(params, **sgd_kwargs) | |||||
super().__init__(optim, **lars_kwargs) |
from argparse import ArgumentParser | |||||
def parse_args(default=False): | |||||
"""Command-line argument parser for training.""" | |||||
parser = ArgumentParser(description='Pytorch implementation of CSI') | |||||
parser.add_argument('--dataset', help='Dataset', | |||||
choices=['cifar10', 'cifar100', 'imagenet', 'CNMC', 'CNMC_grayscale'], type=str) | |||||
parser.add_argument('--one_class_idx', help='None: multi-class, Not None: one-class', | |||||
default=None, type=int) | |||||
parser.add_argument('--model', help='Model', | |||||
choices=['resnet18', 'resnet18_imagenet'], type=str) | |||||
parser.add_argument('--mode', help='Training mode', | |||||
default='simclr', type=str) | |||||
parser.add_argument('--simclr_dim', help='Dimension of simclr layer', | |||||
default=128, type=int) | |||||
parser.add_argument('--shift_trans_type', help='shifting transformation type', default='none', | |||||
choices=['rotation', 'cutperm', 'blur', 'randpers', 'sharp', 'blur_randpers', | |||||
'blur_sharp', 'randpers_sharp', 'blur_randpers_sharp', 'noise', 'none'], type=str) | |||||
parser.add_argument("--local_rank", type=int, | |||||
default=0, help='Local rank for distributed learning') | |||||
parser.add_argument('--resume_path', help='Path to the resume checkpoint', | |||||
default=None, type=str) | |||||
parser.add_argument('--load_path', help='Path to the loading checkpoint', | |||||
default=None, type=str) | |||||
parser.add_argument("--no_strict", help='Do not strictly load state_dicts', | |||||
action='store_true') | |||||
parser.add_argument('--suffix', help='Suffix for the log dir', | |||||
default=None, type=str) | |||||
parser.add_argument('--error_step', help='Epoch steps to compute errors', | |||||
default=5, type=int) | |||||
parser.add_argument('--save_step', help='Epoch steps to save models', | |||||
default=10, type=int) | |||||
##### Training Configurations ##### | |||||
parser.add_argument('--epochs', help='Epochs', | |||||
default=1000, type=int) | |||||
parser.add_argument('--optimizer', help='Optimizer', | |||||
choices=['sgd', 'lars'], | |||||
default='lars', type=str) | |||||
parser.add_argument('--lr_scheduler', help='Learning rate scheduler', | |||||
choices=['step_decay', 'cosine'], | |||||
default='cosine', type=str) | |||||
parser.add_argument('--warmup', help='Warm-up epochs', | |||||
default=10, type=int) | |||||
parser.add_argument('--lr_init', help='Initial learning rate', | |||||
default=1e-1, type=float) | |||||
parser.add_argument('--weight_decay', help='Weight decay', | |||||
default=1e-6, type=float) | |||||
parser.add_argument('--batch_size', help='Batch size', | |||||
default=128, type=int) | |||||
parser.add_argument('--test_batch_size', help='Batch size for test loader', | |||||
default=100, type=int) | |||||
parser.add_argument('--blur_sigma', help='Distortion grade', | |||||
default=2.0, type=float) | |||||
parser.add_argument('--color_distort', help='Color distortion grade', | |||||
default=0.5, type=float) | |||||
parser.add_argument('--distortion_scale', help='Perspective distortion grade', | |||||
default=0.6, type=float) | |||||
parser.add_argument('--sharpness_factor', help='Sharpening or blurring factor of image. ' | |||||
'Can be any non negative number. 0 gives a blurred image, ' | |||||
'1 gives the original image while 2 increases the sharpness ' | |||||
'by a factor of 2.', | |||||
default=2, type=float) | |||||
parser.add_argument('--noise_mean', help='mean', | |||||
default=0, type=float) | |||||
parser.add_argument('--noise_std', help='std', | |||||
default=0.3, type=float) | |||||
##### Objective Configurations ##### | |||||
parser.add_argument('--sim_lambda', help='Weight for SimCLR loss', | |||||
default=1.0, type=float) | |||||
parser.add_argument('--temperature', help='Temperature for similarity', | |||||
default=0.5, type=float) | |||||
##### Evaluation Configurations ##### | |||||
parser.add_argument("--ood_dataset", help='Datasets for OOD detection', | |||||
default=None, nargs="*", type=str) | |||||
parser.add_argument("--ood_score", help='score function for OOD detection', | |||||
default=['norm_mean'], nargs="+", type=str) | |||||
parser.add_argument("--ood_layer", help='layer for OOD scores', | |||||
choices=['penultimate', 'simclr', 'shift'], | |||||
default=['simclr', 'shift'], nargs="+", type=str) | |||||
parser.add_argument("--ood_samples", help='number of samples to compute OOD score', | |||||
default=1, type=int) | |||||
parser.add_argument("--ood_batch_size", help='batch size to compute OOD score', | |||||
default=100, type=int) | |||||
parser.add_argument("--resize_factor", help='resize scale is sampled from [resize_factor, 1.0]', | |||||
default=0.08, type=float) | |||||
parser.add_argument("--resize_fix", help='resize scale is fixed to resize_factor (not (resize_factor, 1.0])', | |||||
action='store_true') | |||||
parser.add_argument("--print_score", help='print quantiles of ood score', | |||||
action='store_true') | |||||
parser.add_argument("--save_score", help='save ood score for plotting histogram', | |||||
action='store_true') | |||||
##### Process configuration option ##### | |||||
parser.add_argument("--proc_step", help='choose process to initiate.', | |||||
choices=['eval', 'train'], | |||||
default=None, type=str) | |||||
parser.add_argument("--res", help='resolution of dataset', | |||||
default="32px", type=str) | |||||
if default: | |||||
return parser.parse_args('') # empty string | |||||
else: | |||||
return parser.parse_args() |
from copy import deepcopy | |||||
import torch | |||||
import torch.nn as nn | |||||
from torch.utils.data import DataLoader | |||||
from common.common import parse_args | |||||
import models.classifier as C | |||||
from datasets import get_dataset, get_superclass_list, get_subclass_dataset | |||||
P = parse_args() | |||||
### Set torch device ### | |||||
P.n_gpus = torch.cuda.device_count() | |||||
assert P.n_gpus <= 1 # no multi GPU | |||||
P.multi_gpu = False | |||||
if torch.cuda.is_available(): | |||||
torch.cuda.set_device(P.local_rank) | |||||
device = torch.device(f"cuda" if torch.cuda.is_available() else "cpu") | |||||
### Initialize dataset ### | |||||
ood_eval = P.mode == 'ood_pre' | |||||
if P.dataset == 'imagenet' and ood_eval or P.dataset == 'CNMC' and ood_eval or P.dataset == 'CNMC_grayscale' and ood_eval: | |||||
P.batch_size = 1 | |||||
P.test_batch_size = 1 | |||||
train_set, test_set, image_size, n_classes = get_dataset(P, dataset=P.dataset, eval=ood_eval) | |||||
P.image_size = image_size | |||||
P.n_classes = n_classes | |||||
if P.one_class_idx is not None: | |||||
cls_list = get_superclass_list(P.dataset) | |||||
P.n_superclasses = len(cls_list) | |||||
full_test_set = deepcopy(test_set) # test set of full classes | |||||
train_set = get_subclass_dataset(train_set, classes=cls_list[P.one_class_idx]) | |||||
test_set = get_subclass_dataset(test_set, classes=cls_list[P.one_class_idx]) | |||||
kwargs = {'pin_memory': False, 'num_workers': 2} | |||||
train_loader = DataLoader(train_set, shuffle=True, batch_size=P.batch_size, **kwargs) | |||||
test_loader = DataLoader(test_set, shuffle=False, batch_size=P.test_batch_size, **kwargs) | |||||
if P.ood_dataset is None: | |||||
if P.one_class_idx is not None: | |||||
P.ood_dataset = list(range(P.n_superclasses)) | |||||
P.ood_dataset.pop(P.one_class_idx) | |||||
elif P.dataset == 'cifar10': | |||||
P.ood_dataset = ['svhn', 'lsun_resize', 'imagenet_resize', 'lsun_fix', 'imagenet_fix', 'cifar100', 'interp'] | |||||
elif P.dataset == 'imagenet': | |||||
P.ood_dataset = ['cub', 'stanford_dogs', 'flowers102', 'places365', 'food_101', 'caltech_256', 'dtd', 'pets'] | |||||
ood_test_loader = dict() | |||||
for ood in P.ood_dataset: | |||||
if ood == 'interp': | |||||
ood_test_loader[ood] = None # dummy loader | |||||
continue | |||||
if P.one_class_idx is not None: | |||||
ood_test_set = get_subclass_dataset(full_test_set, classes=cls_list[ood]) | |||||
ood = f'one_class_{ood}' # change save name | |||||
else: | |||||
ood_test_set = get_dataset(P, dataset=ood, test_only=True, image_size=P.image_size, eval=ood_eval) | |||||
ood_test_loader[ood] = DataLoader(ood_test_set, shuffle=False, batch_size=P.test_batch_size, **kwargs) | |||||
### Initialize model ### | |||||
simclr_aug = C.get_simclr_augmentation(P, image_size=P.image_size).to(device) | |||||
P.shift_trans, P.K_shift = C.get_shift_module(P, eval=True) | |||||
P.shift_trans = P.shift_trans.to(device) | |||||
model = C.get_classifier(P.model, n_classes=P.n_classes).to(device) | |||||
model = C.get_shift_classifer(model, P.K_shift).to(device) | |||||
criterion = nn.CrossEntropyLoss().to(device) | |||||
if P.load_path is not None: | |||||
checkpoint = torch.load(P.load_path) | |||||
model.load_state_dict(checkpoint, strict=not P.no_strict) |
from copy import deepcopy | |||||
import torch | |||||
import torch.nn as nn | |||||
import torch.optim as optim | |||||
import torch.optim.lr_scheduler as lr_scheduler | |||||
from torch.utils.data import DataLoader | |||||
from common.common import parse_args | |||||
import models.classifier as C | |||||
from datasets import get_dataset, get_superclass_list, get_subclass_dataset | |||||
from utils.utils import load_checkpoint | |||||
P = parse_args() | |||||
### Set torch device ### | |||||
if torch.cuda.is_available(): | |||||
torch.cuda.set_device(P.local_rank) | |||||
device = torch.device(f"cuda" if torch.cuda.is_available() else "cpu") | |||||
P.n_gpus = torch.cuda.device_count() | |||||
if P.n_gpus > 1: | |||||
import apex | |||||
import torch.distributed as dist | |||||
from torch.utils.data.distributed import DistributedSampler | |||||
P.multi_gpu = True | |||||
torch.distributed.init_process_group( | |||||
'nccl', | |||||
init_method='env://', | |||||
world_size=P.n_gpus, | |||||
rank=P.local_rank, | |||||
) | |||||
else: | |||||
P.multi_gpu = False | |||||
### only use one ood_layer while training | |||||
P.ood_layer = P.ood_layer[0] | |||||
### Initialize dataset ### | |||||
train_set, test_set, image_size, n_classes = get_dataset(P, dataset=P.dataset) | |||||
P.image_size = image_size | |||||
P.n_classes = n_classes | |||||
if P.one_class_idx is not None: | |||||
cls_list = get_superclass_list(P.dataset) | |||||
P.n_superclasses = len(cls_list) | |||||
full_test_set = deepcopy(test_set) # test set of full classes | |||||
train_set = get_subclass_dataset(train_set, classes=cls_list[P.one_class_idx]) | |||||
test_set = get_subclass_dataset(test_set, classes=cls_list[P.one_class_idx]) | |||||
kwargs = {'pin_memory': False, 'num_workers': 2} | |||||
if P.multi_gpu: | |||||
train_sampler = DistributedSampler(train_set, num_replicas=P.n_gpus, rank=P.local_rank) | |||||
test_sampler = DistributedSampler(test_set, num_replicas=P.n_gpus, rank=P.local_rank) | |||||
train_loader = DataLoader(train_set, sampler=train_sampler, batch_size=P.batch_size, **kwargs) | |||||
test_loader = DataLoader(test_set, sampler=test_sampler, batch_size=P.test_batch_size, **kwargs) | |||||
else: | |||||
train_loader = DataLoader(train_set, shuffle=True, batch_size=P.batch_size, **kwargs) | |||||
test_loader = DataLoader(test_set, shuffle=False, batch_size=P.test_batch_size, **kwargs) | |||||
if P.ood_dataset is None: | |||||
if P.one_class_idx is not None: | |||||
P.ood_dataset = list(range(P.n_superclasses)) | |||||
P.ood_dataset.pop(P.one_class_idx) | |||||
elif P.dataset == 'cifar10': | |||||
P.ood_dataset = ['svhn', 'lsun_resize', 'imagenet_resize', 'lsun_fix', 'imagenet_fix', 'cifar100', 'interp'] | |||||
elif P.dataset == 'imagenet': | |||||
P.ood_dataset = ['cub', 'stanford_dogs', 'flowers102'] | |||||
ood_test_loader = dict() | |||||
for ood in P.ood_dataset: | |||||
if ood == 'interp': | |||||
ood_test_loader[ood] = None # dummy loader | |||||
continue | |||||
if P.one_class_idx is not None: | |||||
ood_test_set = get_subclass_dataset(full_test_set, classes=cls_list[ood]) | |||||
ood = f'one_class_{ood}' # change save name | |||||
else: | |||||
ood_test_set = get_dataset(P, dataset=ood, test_only=True, image_size=P.image_size) | |||||
if P.multi_gpu: | |||||
ood_sampler = DistributedSampler(ood_test_set, num_replicas=P.n_gpus, rank=P.local_rank) | |||||
ood_test_loader[ood] = DataLoader(ood_test_set, sampler=ood_sampler, batch_size=P.test_batch_size, **kwargs) | |||||
else: | |||||
ood_test_loader[ood] = DataLoader(ood_test_set, shuffle=False, batch_size=P.test_batch_size, **kwargs) | |||||
### Initialize model ### | |||||
simclr_aug = C.get_simclr_augmentation(P, image_size=P.image_size).to(device) | |||||
P.shift_trans, P.K_shift = C.get_shift_module(P, eval=True) | |||||
P.shift_trans = P.shift_trans.to(device) | |||||
model = C.get_classifier(P.model, n_classes=P.n_classes).to(device) | |||||
model = C.get_shift_classifer(model, P.K_shift).to(device) | |||||
criterion = nn.CrossEntropyLoss().to(device) | |||||
if P.optimizer == 'sgd': | |||||
optimizer = optim.SGD(model.parameters(), lr=P.lr_init, momentum=0.9, weight_decay=P.weight_decay) | |||||
lr_decay_gamma = 0.1 | |||||
elif P.optimizer == 'lars': | |||||
from torchlars import LARS | |||||
base_optimizer = optim.SGD(model.parameters(), lr=P.lr_init, momentum=0.9, weight_decay=P.weight_decay) | |||||
optimizer = LARS(base_optimizer, eps=1e-8, trust_coef=0.001) | |||||
lr_decay_gamma = 0.1 | |||||
else: | |||||
raise NotImplementedError() | |||||
if P.lr_scheduler == 'cosine': | |||||
scheduler = lr_scheduler.CosineAnnealingLR(optimizer, P.epochs) | |||||
elif P.lr_scheduler == 'step_decay': | |||||
milestones = [int(0.5 * P.epochs), int(0.75 * P.epochs)] | |||||
scheduler = lr_scheduler.MultiStepLR(optimizer, gamma=lr_decay_gamma, milestones=milestones) | |||||
else: | |||||
raise NotImplementedError() | |||||
from training.scheduler import GradualWarmupScheduler | |||||
scheduler_warmup = GradualWarmupScheduler(optimizer, multiplier=10.0, total_epoch=P.warmup, after_scheduler=scheduler) | |||||
if P.resume_path is not None: | |||||
resume = True | |||||
model_state, optim_state, config = load_checkpoint(P.resume_path, mode='last') | |||||
model.load_state_dict(model_state, strict=not P.no_strict) | |||||
optimizer.load_state_dict(optim_state) | |||||
start_epoch = config['epoch'] | |||||
best = config['best'] | |||||
error = 100.0 | |||||
else: | |||||
resume = False | |||||
start_epoch = 1 | |||||
best = 100.0 | |||||
error = 100.0 | |||||
if P.mode == 'sup_linear' or P.mode == 'sup_CSI_linear': | |||||
assert P.load_path is not None | |||||
checkpoint = torch.load(P.load_path) | |||||
model.load_state_dict(checkpoint, strict=not P.no_strict) | |||||
if P.multi_gpu: | |||||
simclr_aug = apex.parallel.DistributedDataParallel(simclr_aug, delay_allreduce=True) | |||||
model = apex.parallel.convert_syncbn_model(model) | |||||
model = apex.parallel.DistributedDataParallel(model, delay_allreduce=True) |
from datasets.datasets import get_dataset, get_superclass_list, get_subclass_dataset | |||||
import os | |||||
import numpy as np | |||||
import torch | |||||
from torch.utils.data.dataset import Subset | |||||
from torchvision import datasets, transforms | |||||
from utils.utils import set_random_seed | |||||
DATA_PATH = '~/data/' | |||||
IMAGENET_PATH = '~/data/ImageNet' | |||||
CNMC_PATH = r'~/data/CSI/CNMC_orig' | |||||
CNMC_GRAY_PATH = r'~/data/CSI/CNMC_orig_gray' | |||||
CNMC_ROT4_PATH = r'~/data/CSI/CNMC_rotated_4' | |||||
CIFAR10_SUPERCLASS = list(range(10)) # one class | |||||
IMAGENET_SUPERCLASS = list(range(30)) # one class | |||||
CNMC_SUPERCLASS = list(range(2)) # one class | |||||
STD_RES = 450 | |||||
STD_CENTER_CROP = 300 | |||||
CIFAR100_SUPERCLASS = [ | |||||
[4, 31, 55, 72, 95], | |||||
[1, 33, 67, 73, 91], | |||||
[54, 62, 70, 82, 92], | |||||
[9, 10, 16, 29, 61], | |||||
[0, 51, 53, 57, 83], | |||||
[22, 25, 40, 86, 87], | |||||
[5, 20, 26, 84, 94], | |||||
[6, 7, 14, 18, 24], | |||||
[3, 42, 43, 88, 97], | |||||
[12, 17, 38, 68, 76], | |||||
[23, 34, 49, 60, 71], | |||||
[15, 19, 21, 32, 39], | |||||
[35, 63, 64, 66, 75], | |||||
[27, 45, 77, 79, 99], | |||||
[2, 11, 36, 46, 98], | |||||
[28, 30, 44, 78, 93], | |||||
[37, 50, 65, 74, 80], | |||||
[47, 52, 56, 59, 96], | |||||
[8, 13, 48, 58, 90], | |||||
[41, 69, 81, 85, 89], | |||||
] | |||||
class MultiDataTransform(object): | |||||
def __init__(self, transform): | |||||
self.transform1 = transform | |||||
self.transform2 = transform | |||||
def __call__(self, sample): | |||||
x1 = self.transform1(sample) | |||||
x2 = self.transform2(sample) | |||||
return x1, x2 | |||||
class MultiDataTransformList(object): | |||||
def __init__(self, transform, clean_trasform, sample_num): | |||||
self.transform = transform | |||||
self.clean_transform = clean_trasform | |||||
self.sample_num = sample_num | |||||
def __call__(self, sample): | |||||
set_random_seed(0) | |||||
sample_list = [] | |||||
for i in range(self.sample_num): | |||||
sample_list.append(self.transform(sample)) | |||||
return sample_list, self.clean_transform(sample) | |||||
def get_transform(image_size=None): | |||||
# Note: data augmentation is implemented in the layers | |||||
# Hence, we only define the identity transformation here | |||||
if image_size: # use pre-specified image size | |||||
train_transform = transforms.Compose([ | |||||
transforms.Resize((image_size[0], image_size[1])), | |||||
transforms.RandomHorizontalFlip(), | |||||
transforms.ToTensor(), | |||||
]) | |||||
test_transform = transforms.Compose([ | |||||
transforms.Resize((image_size[0], image_size[1])), | |||||
transforms.ToTensor(), | |||||
]) | |||||
else: # use default image size | |||||
train_transform = transforms.Compose([ | |||||
transforms.ToTensor(), | |||||
]) | |||||
test_transform = transforms.ToTensor() | |||||
return train_transform, test_transform | |||||
def get_subset_with_len(dataset, length, shuffle=False): | |||||
set_random_seed(0) | |||||
dataset_size = len(dataset) | |||||
index = np.arange(dataset_size) | |||||
if shuffle: | |||||
np.random.shuffle(index) | |||||
index = torch.from_numpy(index[0:length]) | |||||
subset = Subset(dataset, index) | |||||
assert len(subset) == length | |||||
return subset | |||||
def get_transform_imagenet(): | |||||
train_transform = transforms.Compose([ | |||||
transforms.Resize(256), | |||||
transforms.RandomResizedCrop(224), | |||||
transforms.RandomHorizontalFlip(), | |||||
transforms.ToTensor(), | |||||
]) | |||||
test_transform = transforms.Compose([ | |||||
transforms.Resize(256), | |||||
transforms.CenterCrop(224), | |||||
transforms.ToTensor(), | |||||
]) | |||||
train_transform = MultiDataTransform(train_transform) | |||||
return train_transform, test_transform | |||||
def get_transform_cnmc(res, center_crop_size): | |||||
train_transform = transforms.Compose([ | |||||
transforms.Resize(res), | |||||
transforms.CenterCrop(center_crop_size), | |||||
transforms.RandomVerticalFlip(), | |||||
transforms.RandomHorizontalFlip(), | |||||
transforms.ToTensor(), | |||||
]) | |||||
test_transform = transforms.Compose([ | |||||
transforms.Resize(res), | |||||
transforms.CenterCrop(center_crop_size), | |||||
transforms.ToTensor(), | |||||
]) | |||||
train_transform = MultiDataTransform(train_transform) | |||||
return train_transform, test_transform | |||||
def get_dataset(P, dataset, test_only=False, image_size=None, download=False, eval=False): | |||||
if P.res != '': | |||||
res = int(P.res.replace('px', '')) | |||||
size_factor = int(STD_RES/res) # always remove same portion | |||||
center_crop_size = int(STD_CENTER_CROP/size_factor) # remove black border | |||||
if dataset in ['CNMC', 'CNMC_grayscale', 'CNMC_ROT4_PATH']: | |||||
if eval: | |||||
train_transform, test_transform = get_simclr_eval_transform_cnmc(P.ood_samples, | |||||
P.resize_factor, P.resize_fix, res, center_crop_size) | |||||
else: | |||||
train_transform, test_transform = get_transform_cnmc(res, center_crop_size) | |||||
elif dataset in ['imagenet', 'cub', 'stanford_dogs', 'flowers102', | |||||
'places365', 'food_101', 'caltech_256', 'dtd', 'pets']: | |||||
if eval: | |||||
train_transform, test_transform = get_simclr_eval_transform_imagenet(P.ood_samples, | |||||
P.resize_factor, P.resize_fix) | |||||
else: | |||||
train_transform, test_transform = get_transform_imagenet() | |||||
else: | |||||
train_transform, test_transform = get_transform(image_size=image_size) | |||||
if dataset == 'CNMC': | |||||
image_size = (center_crop_size, center_crop_size, 3) #original 450,450,3 | |||||
n_classes = 2 | |||||
train_dir = os.path.join(CNMC_PATH, '0_training') | |||||
test_dir = os.path.join(CNMC_PATH, '1_validation') | |||||
train_set = datasets.ImageFolder(train_dir, transform=train_transform) | |||||
test_set = datasets.ImageFolder(test_dir, transform=test_transform) | |||||
elif dataset == 'CNMC_grayscale': | |||||
image_size = (center_crop_size, center_crop_size, 3) #original 450,450,3 | |||||
n_classes = 2 | |||||
train_dir = os.path.join(CNMC_GRAY_PATH, '0_training') | |||||
test_dir = os.path.join(CNMC_GRAY_PATH, '1_validation') | |||||
train_set = datasets.ImageFolder(train_dir, transform=train_transform) | |||||
test_set = datasets.ImageFolder(test_dir, transform=test_transform) | |||||
elif dataset == 'cifar10': | |||||
image_size = (32, 32, 3) | |||||
n_classes = 10 | |||||
train_set = datasets.CIFAR10(DATA_PATH, train=True, download=download, transform=train_transform) | |||||
test_set = datasets.CIFAR10(DATA_PATH, train=False, download=download, transform=test_transform) | |||||
elif dataset == 'cifar100': | |||||
image_size = (32, 32, 3) | |||||
n_classes = 100 | |||||
train_set = datasets.CIFAR100(DATA_PATH, train=True, download=download, transform=train_transform) | |||||
test_set = datasets.CIFAR100(DATA_PATH, train=False, download=download, transform=test_transform) | |||||
elif dataset == 'svhn': | |||||
assert test_only and image_size is not None | |||||
test_set = datasets.SVHN(DATA_PATH, split='test', download=download, transform=test_transform) | |||||
elif dataset == 'lsun_resize': | |||||
assert test_only and image_size is not None | |||||
test_dir = os.path.join(DATA_PATH, 'LSUN_resize') | |||||
test_set = datasets.ImageFolder(test_dir, transform=test_transform) | |||||
elif dataset == 'lsun_fix': | |||||
assert test_only and image_size is not None | |||||
test_dir = os.path.join(DATA_PATH, 'LSUN_fix') | |||||
test_set = datasets.ImageFolder(test_dir, transform=test_transform) | |||||
elif dataset == 'imagenet_resize': | |||||
assert test_only and image_size is not None | |||||
test_dir = os.path.join(DATA_PATH, 'Imagenet_resize') | |||||
test_set = datasets.ImageFolder(test_dir, transform=test_transform) | |||||
elif dataset == 'imagenet_fix': | |||||
assert test_only and image_size is not None | |||||
test_dir = os.path.join(DATA_PATH, 'Imagenet_fix') | |||||
test_set = datasets.ImageFolder(test_dir, transform=test_transform) | |||||
elif dataset == 'imagenet': | |||||
image_size = (224, 224, 3) | |||||
n_classes = 30 | |||||
train_dir = os.path.join(IMAGENET_PATH, 'one_class_train') | |||||
test_dir = os.path.join(IMAGENET_PATH, 'one_class_test') | |||||
train_set = datasets.ImageFolder(train_dir, transform=train_transform) | |||||
test_set = datasets.ImageFolder(test_dir, transform=test_transform) | |||||
elif dataset == 'stanford_dogs': | |||||
assert test_only and image_size is not None | |||||
test_dir = os.path.join(DATA_PATH, 'stanford_dogs') | |||||
test_set = datasets.ImageFolder(test_dir, transform=test_transform) | |||||
test_set = get_subset_with_len(test_set, length=3000, shuffle=True) | |||||
elif dataset == 'cub': | |||||
assert test_only and image_size is not None | |||||
test_dir = os.path.join(DATA_PATH, 'cub200') | |||||
test_set = datasets.ImageFolder(test_dir, transform=test_transform) | |||||
test_set = get_subset_with_len(test_set, length=3000, shuffle=True) | |||||
elif dataset == 'flowers102': | |||||
assert test_only and image_size is not None | |||||
test_dir = os.path.join(DATA_PATH, 'flowers102') | |||||
test_set = datasets.ImageFolder(test_dir, transform=test_transform) | |||||
test_set = get_subset_with_len(test_set, length=3000, shuffle=True) | |||||
elif dataset == 'places365': | |||||
assert test_only and image_size is not None | |||||
test_dir = os.path.join(DATA_PATH, 'places365') | |||||
test_set = datasets.ImageFolder(test_dir, transform=test_transform) | |||||
test_set = get_subset_with_len(test_set, length=3000, shuffle=True) | |||||
elif dataset == 'food_101': | |||||
assert test_only and image_size is not None | |||||
test_dir = os.path.join(DATA_PATH, 'food-101', 'images') | |||||
test_set = datasets.ImageFolder(test_dir, transform=test_transform) | |||||
test_set = get_subset_with_len(test_set, length=3000, shuffle=True) | |||||
elif dataset == 'caltech_256': | |||||
assert test_only and image_size is not None | |||||
test_dir = os.path.join(DATA_PATH, 'caltech-256') | |||||
test_set = datasets.ImageFolder(test_dir, transform=test_transform) | |||||
test_set = get_subset_with_len(test_set, length=3000, shuffle=True) | |||||
elif dataset == 'dtd': | |||||
assert test_only and image_size is not None | |||||
test_dir = os.path.join(DATA_PATH, 'dtd', 'images') | |||||
test_set = datasets.ImageFolder(test_dir, transform=test_transform) | |||||
test_set = get_subset_with_len(test_set, length=3000, shuffle=True) | |||||
elif dataset == 'pets': | |||||
assert test_only and image_size is not None | |||||
test_dir = os.path.join(DATA_PATH, 'pets') | |||||
test_set = datasets.ImageFolder(test_dir, transform=test_transform) | |||||
test_set = get_subset_with_len(test_set, length=3000, shuffle=True) | |||||
else: | |||||
raise NotImplementedError() | |||||
if test_only: | |||||
return test_set | |||||
else: | |||||
return train_set, test_set, image_size, n_classes | |||||
def get_superclass_list(dataset): | |||||
if dataset == 'CNMC': | |||||
return CNMC_SUPERCLASS | |||||
if dataset == 'CNMC_grayscale': | |||||
return CNMC_SUPERCLASS | |||||
elif dataset == 'cifar10': | |||||
return CIFAR10_SUPERCLASS | |||||
elif dataset == 'cifar100': | |||||
return CIFAR100_SUPERCLASS | |||||
elif dataset == 'imagenet': | |||||
return IMAGENET_SUPERCLASS | |||||
else: | |||||
raise NotImplementedError() | |||||
def get_subclass_dataset(dataset, classes): | |||||
if not isinstance(classes, list): | |||||
classes = [classes] | |||||
indices = [] | |||||
for idx, tgt in enumerate(dataset.targets): | |||||
if tgt in classes: | |||||
indices.append(idx) | |||||
dataset = Subset(dataset, indices) | |||||
return dataset | |||||
def get_simclr_eval_transform_imagenet(sample_num, resize_factor, resize_fix): | |||||
resize_scale = (resize_factor, 1.0) # resize scaling factor | |||||
if resize_fix: # if resize_fix is True, use same scale | |||||
resize_scale = (resize_factor, resize_factor) | |||||
transform = transforms.Compose([ | |||||
transforms.Resize(256), | |||||
transforms.RandomResizedCrop(224, scale=resize_scale), | |||||
transforms.RandomHorizontalFlip(), | |||||
transforms.ToTensor(), | |||||
]) | |||||
clean_trasform = transforms.Compose([ | |||||
transforms.Resize(256), | |||||
transforms.CenterCrop(224), | |||||
transforms.ToTensor(), | |||||
]) | |||||
transform = MultiDataTransformList(transform, clean_trasform, sample_num) | |||||
return transform, transform | |||||
def get_simclr_eval_transform_cnmc(sample_num, resize_factor, resize_fix, res, center_crop_size): | |||||
resize_scale = (resize_factor, 1.0) # resize scaling factor | |||||
if resize_fix: # if resize_fix is True, use same scale | |||||
resize_scale = (resize_factor, resize_factor) | |||||
transform = transforms.Compose([ | |||||
transforms.Resize(res), | |||||
transforms.CenterCrop(center_crop_size), | |||||
transforms.RandomVerticalFlip(), | |||||
transforms.RandomHorizontalFlip(), | |||||
transforms.ToTensor(), | |||||
]) | |||||
clean_trasform = transforms.Compose([ | |||||
transforms.Resize(res), | |||||
transforms.CenterCrop(center_crop_size), | |||||
transforms.ToTensor(), | |||||
]) | |||||
transform = MultiDataTransformList(transform, clean_trasform, sample_num) | |||||
return transform, transform | |||||
import os | |||||
import time | |||||
import random | |||||
import cv2 | |||||
import numpy as np | |||||
import torch | |||||
import torch.nn.functional as F | |||||
from torchvision import datasets, transforms | |||||
from torch.utils.data import DataLoader | |||||
from torchvision.utils import save_image | |||||
from datasets import get_subclass_dataset | |||||
def set_random_seed(seed): | |||||
random.seed(seed) | |||||
np.random.seed(seed) | |||||
torch.manual_seed(seed) | |||||
torch.cuda.manual_seed(seed) | |||||
IMAGENET_PATH = '~/data/ImageNet' | |||||
check = time.time() | |||||
transform = transforms.Compose([ | |||||
transforms.Resize(256), | |||||
transforms.CenterCrop(256), | |||||
transforms.Resize(32), | |||||
transforms.ToTensor(), | |||||
]) | |||||
# remove airliner(1), ambulance(2), parking_meter(18), schooner(22) since similar class exist in CIFAR-10 | |||||
class_idx_list = list(range(30)) | |||||
remove_idx_list = [1, 2, 18, 22] | |||||
for remove_idx in remove_idx_list: | |||||
class_idx_list.remove(remove_idx) | |||||
set_random_seed(0) | |||||
train_dir = os.path.join(IMAGENET_PATH, 'one_class_train') | |||||
Imagenet_set = datasets.ImageFolder(train_dir, transform=transform) | |||||
Imagenet_set = get_subclass_dataset(Imagenet_set, class_idx_list) | |||||
Imagenet_dataloader = DataLoader(Imagenet_set, batch_size=100, shuffle=True, pin_memory=False) | |||||
total_test_image = None | |||||
for n, (test_image, target) in enumerate(Imagenet_dataloader): | |||||
if n == 0: | |||||
total_test_image = test_image | |||||
else: | |||||
total_test_image = torch.cat((total_test_image, test_image), dim=0).cpu() | |||||
if total_test_image.size(0) >= 10000: | |||||
break | |||||
print (f'Preprocessing time {time.time()-check}') | |||||
if not os.path.exists('./Imagenet_fix'): | |||||
os.mkdir('./Imagenet_fix') | |||||
check = time.time() | |||||
for i in range(10000): | |||||
save_image(total_test_image[i], f'Imagenet_fix/correct_resize_{i}.png') | |||||
print (f'Saving time {time.time()-check}') | |||||
import os | |||||
import time | |||||
import random | |||||
import numpy as np | |||||
import torch | |||||
from torchvision import datasets, transforms | |||||
from torch.utils.data import DataLoader | |||||
from torchvision.utils import save_image | |||||
def set_random_seed(seed): | |||||
random.seed(seed) | |||||
np.random.seed(seed) | |||||
torch.manual_seed(seed) | |||||
torch.cuda.manual_seed(seed) | |||||
check = time.time() | |||||
transform = transforms.Compose([ | |||||
transforms.Resize(256), | |||||
transforms.CenterCrop(256), | |||||
transforms.Resize(32), | |||||
transforms.ToTensor(), | |||||
]) | |||||
set_random_seed(0) | |||||
LSUN_class_list = ['bedroom', 'bridge', 'church_outdoor', 'classroom', | |||||
'conference_room', 'dining_room', 'kitchen', 'living_room', 'restaurant', 'tower'] | |||||
total_test_image_all_class = [] | |||||
for LSUN_class in LSUN_class_list: | |||||
LSUN_set = datasets.LSUN('~/data/lsun/', classes=LSUN_class + '_train', transform=transform) | |||||
LSUN_loader = DataLoader(LSUN_set, batch_size=100, shuffle=True, pin_memory=False) | |||||
total_test_image = None | |||||
for n, (test_image, _) in enumerate(LSUN_loader): | |||||
if n == 0: | |||||
total_test_image = test_image | |||||
else: | |||||
total_test_image = torch.cat((total_test_image, test_image), dim=0).cpu() | |||||
if total_test_image.size(0) >= 1000: | |||||
break | |||||
total_test_image_all_class.append(total_test_image) | |||||
total_test_image_all_class = torch.cat(total_test_image_all_class, dim=0) | |||||
print (f'Preprocessing time {time.time()-check}') | |||||
if not os.path.exists('./LSUN_fix'): | |||||
os.mkdir('./LSUN_fix') | |||||
check = time.time() | |||||
for i in range(10000): | |||||
save_image(total_test_image_all_class[i], f'LSUN_fix/correct_resize_{i}.png') | |||||
print (f'Saving time {time.time()-check}') | |||||
import re | |||||
import matplotlib.pyplot as plt | |||||
PATH = r'C:\Users\feokt\PycharmProjects\CSI\CSI\logs' | |||||
def postprocess_data(log: list): | |||||
for pth in log: | |||||
loss_sim = [] | |||||
loss_shift = [] | |||||
with open(PATH + pth) as f: | |||||
lines = f.readlines() | |||||
for line in lines: | |||||
# line = '[2022-01-31 20:40:23.947855] [DONE] [Time 0.179] [Data 0.583] [LossC 0.000000] [LossSim 4.024234] [LossShift 0.065126]' | |||||
part = re.search('\[DONE\]', line) | |||||
if part is not None: | |||||
l_sim = re.search('(\[LossSim.[0-9]*.[0-9]*\])', line).group() | |||||
if l_sim is not None: | |||||
loss_sim.append(float(re.search('(\s[0-9].*[0-9])', l_sim).group())) | |||||
l_shift = re.search('(\[LossShift.[0-9]*.[0-9]*\])', line).group() | |||||
if l_shift is not None: | |||||
loss_shift.append(float(re.search('(\s[0-9].*[0-9])', l_shift).group())) | |||||
loss = [loss_sim[i] + loss_shift[i] for i in range(len(loss_sim))] | |||||
plt.ylabel("loss") | |||||
plt.xlabel("epoch") | |||||
plt.title("Loss over epochs") | |||||
plt.plot(list(range(1, 101)), loss) | |||||
for idx in range(len(log)): | |||||
log[idx] = log[idx][38:] | |||||
plt.legend(log) | |||||
plt.grid() | |||||
#plt.plot(list(range(1, 101)), loss_sim) | |||||
#plt.plot(list(range(1, 101)), loss_shift) | |||||
plt.show() | |||||
import csv | |||||
import os | |||||
from PIL import Image | |||||
from torchvision import transforms | |||||
from torchvision.utils import save_image | |||||
import torch | |||||
def transform_image(img_in, target_dir, transformation, suffix): | |||||
""" | |||||
Transforms an image according to provided transformation. | |||||
Parameters: | |||||
img_in (path): Image to transform | |||||
target_dir (path): Destination path | |||||
transformation (callable): Transformation to be applied | |||||
suffix (str): Suffix of resulting image. | |||||
Returns: | |||||
binary_sum (str): Binary string of the sum of a and b | |||||
""" | |||||
if suffix == 'rot': | |||||
im = Image.open(img_in) | |||||
im = im.rotate(270) | |||||
tensor = transforms.ToTensor()(im) | |||||
save_image(tensor, target_dir + os.sep + suffix + '.jpg') | |||||
elif suffix == 'sobel': | |||||
im = Image.open(img_in) | |||||
tensor = transforms.ToTensor()(im) | |||||
sobel_filter = torch.tensor([[1., 2., 1.], [0., 0., 0.], [-1., -2., -1.]]) | |||||
f = sobel_filter.expand(1, 3, 3, 3) | |||||
tensor = torch.conv2d(tensor, f, stride=1, padding=1 ) | |||||
save_image(tensor, target_dir + os.sep + suffix + '.jpg') | |||||
elif suffix == 'noise': | |||||
im = Image.open(img_in) | |||||
tensor = transforms.ToTensor()(im) | |||||
tensor = tensor + (torch.randn(tensor.size()) * 0.2 + 0) | |||||
save_image(tensor, target_dir + os.sep + suffix + '.jpg') | |||||
elif suffix == 'cutout': | |||||
print("asd") | |||||
else: | |||||
im = Image.open(img_in) | |||||
im_trans = transformation(im) | |||||
im_trans.save(target_dir + os.sep + suffix + '.jpg') | |||||
def sort_and_rename_images(excel_path: str): | |||||
"""Renames images and sorts them according to csv.""" | |||||
base_dir = excel_path.rsplit(os.sep, 1)[0] | |||||
dir_all = base_dir + os.sep + 'all' | |||||
if not os.path.isdir(dir_all): | |||||
os.mkdir(dir_all) | |||||
dir_hem = base_dir + os.sep + 'hem' | |||||
if not os.path.isdir(dir_hem): | |||||
os.mkdir(dir_hem) | |||||
with open(excel_path, mode='r') as file: | |||||
csv_file = csv.reader(file) | |||||
for lines in csv_file: | |||||
print(lines) | |||||
if lines[2] == '1': | |||||
os.rename(base_dir + os.sep + lines[1], dir_all + os.sep + lines[0]) | |||||
elif lines[2] == '0': | |||||
os.rename(base_dir + os.sep + lines[1], dir_hem + os.sep + lines[0]) | |||||
def drop_color_channels(source_dir, target_dir, rgb): | |||||
"""Rotates all images in in source dir.""" | |||||
if rgb == 0: | |||||
suffix = "red_only" | |||||
drop_1 = 1 | |||||
drop_2 = 2 | |||||
elif rgb == 1: | |||||
suffix = "green_only" | |||||
drop_1 = 0 | |||||
drop_2 = 2 | |||||
elif rgb == 2: | |||||
suffix = "blue_only" | |||||
drop_1 = 0 | |||||
drop_2 = 1 | |||||
elif rgb == 3: | |||||
suffix = "no_red" | |||||
drop_1 = 0 | |||||
elif rgb == 4: | |||||
suffix = "no_green" | |||||
drop_1 = 1 | |||||
elif rgb == 5: | |||||
suffix = "no_blue" | |||||
drop_1 = 2 | |||||
else: | |||||
suffix = "" | |||||
print("Invalid RGB-channel") | |||||
if suffix != "": | |||||
dirs = os.listdir(source_dir) | |||||
for item in dirs: | |||||
if os.path.isfile(source_dir + os.sep + item): | |||||
im = Image.open(source_dir + os.sep + item) | |||||
tensor = transforms.ToTensor()(im) | |||||
tensor[drop_1, :, :] = 0 | |||||
if rgb < 3: | |||||
tensor[drop_2, :, :] = 0 | |||||
save_image(tensor, target_dir + os.sep + item, 'bmp') | |||||
def rotate_images(target_dir, source_dir, rotate, theta): | |||||
"""Rotates all images in in source dir.""" | |||||
dirs = os.listdir(source_dir) | |||||
for item in dirs: | |||||
if os.path.isfile(source_dir + os.sep + item): | |||||
for i in range(0, rotate): | |||||
im = Image.open(source_dir + os.sep + item) | |||||
im = im.rotate(i*theta) | |||||
tensor = transforms.ToTensor()(im) | |||||
save_image(tensor, target_dir + os.sep + str(i) + '_' + item, 'bmp') | |||||
def grayscale_image(source_dir, target_dir): | |||||
"""Grayscale transforms all images in path.""" | |||||
t = transforms.Grayscale() | |||||
dirs = os.listdir(source_dir) | |||||
if not os.path.isdir(target_dir): | |||||
os.mkdir(target_dir) | |||||
for item in dirs: | |||||
if os.path.isfile(source_dir + os.sep + item): | |||||
im = Image.open(source_dir + os.sep + item).convert('RGB') | |||||
im_resize = t(im) | |||||
tensor = transforms.ToTensor()(im_resize) | |||||
padding = torch.zeros(1, tensor.shape[1], tensor.shape[2]) | |||||
tensor = torch.cat((tensor, padding), 0) | |||||
im_resize.save(target_dir + os.sep + item, 'bmp') | |||||
def resize(source_dir): | |||||
"""Rotates all images in in source dir.""" | |||||
t = transforms.Compose([transforms.Resize((128, 128))]) | |||||
dirs = os.listdir(source_dir) | |||||
target_dir = source_dir + os.sep + 'resized' | |||||
if not os.path.isdir(target_dir): | |||||
os.mkdir(target_dir) | |||||
for item in dirs: | |||||
if os.path.isfile(source_dir + os.sep + item): | |||||
im = Image.open(source_dir + os.sep + item) | |||||
im_resize = t(im) | |||||
im_resize.save(source_dir + os.sep + 'resized' + os.sep + item, 'bmp') | |||||
def crop_image(source_dir): | |||||
"""Center Crops all images in path.""" | |||||
t = transforms.CenterCrop((224, 224)) | |||||
dirs = os.listdir(source_dir) | |||||
target_dir = source_dir + os.sep + 'cropped' | |||||
if not os.path.isdir(target_dir): | |||||
os.mkdir(target_dir) | |||||
for item in dirs: | |||||
if os.path.isfile(source_dir + os.sep + item): | |||||
im = Image.open(source_dir + os.sep + item) | |||||
im_resize = t(im, ) | |||||
im_resize.save(source_dir + os.sep + 'cropped' + os.sep + item, 'bmp') | |||||
def mk_dirs(target_dir): | |||||
dir_0 = target_dir + r"\fold_0" | |||||
dir_1 = target_dir + r"\fold_1" | |||||
dir_2 = target_dir + r"\fold_2" | |||||
dir_3 = target_dir + r"\phase2" | |||||
dir_4 = target_dir + r"\phase3" | |||||
dir_0_all = dir_0 + r"\all" | |||||
dir_0_hem = dir_0 + r"\hem" | |||||
dir_1_all = dir_1 + r"\all" | |||||
dir_1_hem = dir_1 + r"\hem" | |||||
dir_2_all = dir_2 + r"\all" | |||||
dir_2_hem = dir_2 + r"\hem" | |||||
if not os.path.isdir(dir_0): | |||||
os.mkdir(dir_0) | |||||
if not os.path.isdir(dir_1): | |||||
os.mkdir(dir_1) | |||||
if not os.path.isdir(dir_2): | |||||
os.mkdir(dir_2) | |||||
if not os.path.isdir(dir_3): | |||||
os.mkdir(dir_3) | |||||
if not os.path.isdir(dir_4): | |||||
os.mkdir(dir_4) | |||||
if not os.path.isdir(dir_0_all): | |||||
os.mkdir(dir_0_all) | |||||
if not os.path.isdir(dir_0_hem): | |||||
os.mkdir(dir_0_hem) | |||||
if not os.path.isdir(dir_1_all): | |||||
os.mkdir(dir_1_all) | |||||
if not os.path.isdir(dir_1_hem): | |||||
os.mkdir(dir_1_hem) | |||||
if not os.path.isdir(dir_2_all): | |||||
os.mkdir(dir_2_all) | |||||
if not os.path.isdir(dir_2_hem): | |||||
os.mkdir(dir_2_hem) | |||||
return dir_0_all, dir_0_hem, dir_1_all, dir_1_hem, dir_2_all, dir_2_hem, dir_3, dir_4 |
from common.eval import * | |||||
def main(): | |||||
model.eval() | |||||
if P.mode == 'test_acc': | |||||
from evals import test_classifier | |||||
with torch.no_grad(): | |||||
error = test_classifier(P, model, test_loader, 0, logger=None) | |||||
elif P.mode == 'test_marginalized_acc': | |||||
from evals import test_classifier | |||||
with torch.no_grad(): | |||||
error = test_classifier(P, model, test_loader, 0, marginal=True, logger=None) | |||||
elif P.mode in ['ood', 'ood_pre']: | |||||
if P.mode == 'ood': | |||||
from evals import eval_ood_detection | |||||
else: | |||||
from evals.ood_pre import eval_ood_detection | |||||
with torch.no_grad(): | |||||
auroc_dict = eval_ood_detection(P, model, test_loader, ood_test_loader, P.ood_score, | |||||
train_loader=train_loader, simclr_aug=simclr_aug) | |||||
if P.one_class_idx is not None: | |||||
mean_dict = dict() | |||||
for ood_score in P.ood_score: | |||||
mean = 0 | |||||
for ood in auroc_dict.keys(): | |||||
mean += auroc_dict[ood][ood_score] | |||||
mean_dict[ood_score] = mean / len(auroc_dict.keys()) | |||||
auroc_dict['one_class_mean'] = mean_dict | |||||
bests = [] | |||||
for ood in auroc_dict.keys(): | |||||
message = '' | |||||
best_auroc = 0 | |||||
for ood_score, auroc in auroc_dict[ood].items(): | |||||
message += '[%s %s %.4f] ' % (ood, ood_score, auroc) | |||||
if auroc > best_auroc: | |||||
best_auroc = auroc | |||||
message += '[%s %s %.4f] ' % (ood, 'best', best_auroc) | |||||
if P.print_score: | |||||
print(message) | |||||
bests.append(best_auroc) | |||||
bests = map('{:.4f}'.format, bests) | |||||
print('\t'.join(bests)) | |||||
else: | |||||
raise NotImplementedError() | |||||
if __name__ == '__main__': | |||||
main() |
from evals.evals import test_classifier, eval_ood_detection |
import time | |||||
import itertools | |||||
import diffdist.functional as distops | |||||
import numpy as np | |||||
import torch | |||||
import torch.distributed as dist | |||||
import torch.nn as nn | |||||
import torch.nn.functional as F | |||||
from sklearn.metrics import roc_auc_score | |||||
import models.transform_layers as TL | |||||
from utils.temperature_scaling import _ECELoss | |||||
from utils.utils import AverageMeter, set_random_seed, normalize | |||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu") | |||||
ece_criterion = _ECELoss().to(device) | |||||
def error_k(output, target, ks=(1,)): | |||||
"""Computes the precision@k for the specified values of k""" | |||||
max_k = max(ks) | |||||
batch_size = target.size(0) | |||||
_, pred = output.topk(max_k, 1, True, True) | |||||
pred = pred.t() | |||||
correct = pred.eq(target.view(1, -1).expand_as(pred)) | |||||
results = [] | |||||
for k in ks: | |||||
correct_k = correct[:k].view(-1).float().sum(0) | |||||
results.append(100.0 - correct_k.mul_(100.0 / batch_size)) | |||||
return results | |||||
def test_classifier(P, model, loader, steps, marginal=False, logger=None): | |||||
error_top1 = AverageMeter() | |||||
error_calibration = AverageMeter() | |||||
if logger is None: | |||||
log_ = print | |||||
else: | |||||
log_ = logger.log | |||||
# Switch to evaluate mode | |||||
mode = model.training | |||||
model.eval() | |||||
for n, (images, labels) in enumerate(loader): | |||||
batch_size = images.size(0) | |||||
images, labels = images.to(device), labels.to(device) | |||||
if marginal: | |||||
outputs = 0 | |||||
for i in range(4): | |||||
rot_images = torch.rot90(images, i, (2, 3)) | |||||
_, outputs_aux = model(rot_images, joint=True) | |||||
outputs += outputs_aux['joint'][:, P.n_classes * i: P.n_classes * (i + 1)] / 4. | |||||
else: | |||||
outputs = model(images) | |||||
top1, = error_k(outputs.data, labels, ks=(1,)) | |||||
error_top1.update(top1.item(), batch_size) | |||||
ece = ece_criterion(outputs, labels) * 100 | |||||
error_calibration.update(ece.item(), batch_size) | |||||
if n % 100 == 0: | |||||
log_('[Test %3d] [Test@1 %.3f] [ECE %.3f]' % | |||||
(n, error_top1.value, error_calibration.value)) | |||||
log_(' * [Error@1 %.3f] [ECE %.3f]' % | |||||
(error_top1.average, error_calibration.average)) | |||||
if logger is not None: | |||||
logger.scalar_summary('eval/clean_error', error_top1.average, steps) | |||||
logger.scalar_summary('eval/ece', error_calibration.average, steps) | |||||
model.train(mode) | |||||
return error_top1.average | |||||
def eval_ood_detection(P, model, id_loader, ood_loaders, ood_scores, train_loader=None, simclr_aug=None): | |||||
auroc_dict = dict() | |||||
for ood in ood_loaders.keys(): | |||||
auroc_dict[ood] = dict() | |||||
for ood_score in ood_scores: | |||||
# compute scores for ID and OOD samples | |||||
score_func = get_ood_score_func(P, model, ood_score, simclr_aug=simclr_aug) | |||||
save_path = f'plot/score_in_{P.dataset}_{ood_score}' | |||||
if P.one_class_idx is not None: | |||||
save_path += f'_{P.one_class_idx}' | |||||
scores_id = get_scores(id_loader, score_func) | |||||
if P.save_score: | |||||
np.save(f'{save_path}.npy', scores_id) | |||||
for ood, ood_loader in ood_loaders.items(): | |||||
if ood == 'interp': | |||||
scores_ood = get_scores_interp(id_loader, score_func) | |||||
auroc_dict['interp'][ood_score] = get_auroc(scores_id, scores_ood) | |||||
else: | |||||
scores_ood = get_scores(ood_loader, score_func) | |||||
auroc_dict[ood][ood_score] = get_auroc(scores_id, scores_ood) | |||||
if P.save_score: | |||||
np.save(f'{save_path}_out_{ood}.npy', scores_ood) | |||||
return auroc_dict | |||||
def get_ood_score_func(P, model, ood_score, simclr_aug=None): | |||||
def score_func(x): | |||||
return compute_ood_score(P, model, ood_score, x, simclr_aug=simclr_aug) | |||||
return score_func | |||||
def get_scores(loader, score_func): | |||||
scores = [] | |||||
for i, (x, _) in enumerate(loader): | |||||
s = score_func(x.to(device)) | |||||
assert s.dim() == 1 and s.size(0) == x.size(0) | |||||
scores.append(s.detach().cpu().numpy()) | |||||
return np.concatenate(scores) | |||||
def get_scores_interp(loader, score_func): | |||||
scores = [] | |||||
for i, (x, _) in enumerate(loader): | |||||
x_interp = (x + last) / 2 if i > 0 else x # omit the first batch, assume batch sizes are equal | |||||
last = x # save the last batch | |||||
s = score_func(x_interp.to(device)) | |||||
assert s.dim() == 1 and s.size(0) == x.size(0) | |||||
scores.append(s.detach().cpu().numpy()) | |||||
return np.concatenate(scores) | |||||
def get_auroc(scores_id, scores_ood): | |||||
scores = np.concatenate([scores_id, scores_ood]) | |||||
labels = np.concatenate([np.ones_like(scores_id), np.zeros_like(scores_ood)]) | |||||
return roc_auc_score(labels, scores) | |||||
def compute_ood_score(P, model, ood_score, x, simclr_aug=None): | |||||
model.eval() | |||||
if ood_score == 'clean_norm': | |||||
_, output_aux = model(x, penultimate=True, simclr=True) | |||||
score = output_aux[P.ood_layer].norm(dim=1) | |||||
return score | |||||
elif ood_score == 'similar': | |||||
assert simclr_aug is not None # require custom simclr augmentation | |||||
sample_num = 2 # fast evaluation | |||||
feats = get_features(model, simclr_aug, x, layer=P.ood_layer, sample_num=sample_num) | |||||
feats_avg = sum(feats) / len(feats) | |||||
scores = [] | |||||
for seed in range(sample_num): | |||||
sim = torch.cosine_similarity(feats[seed], feats_avg) | |||||
scores.append(sim) | |||||
return sum(scores) / len(scores) | |||||
elif ood_score == 'baseline': | |||||
outputs, outputs_aux = model(x, penultimate=True) | |||||
scores = F.softmax(outputs, dim=1).max(dim=1)[0] | |||||
return scores | |||||
elif ood_score == 'baseline_marginalized': | |||||
total_outputs = 0 | |||||
for i in range(4): | |||||
x_rot = torch.rot90(x, i, (2, 3)) | |||||
outputs, outputs_aux = model(x_rot, penultimate=True, joint=True) | |||||
total_outputs += outputs_aux['joint'][:, P.n_classes * i:P.n_classes * (i + 1)] | |||||
scores = F.softmax(total_outputs / 4., dim=1).max(dim=1)[0] | |||||
return scores | |||||
else: | |||||
raise NotImplementedError() | |||||
def get_features(model, simclr_aug, x, layer='simclr', sample_num=1): | |||||
model.eval() | |||||
feats = [] | |||||
for seed in range(sample_num): | |||||
set_random_seed(seed) | |||||
x_t = simclr_aug(x) | |||||
with torch.no_grad(): | |||||
_, output_aux = model(x_t, penultimate=True, simclr=True, shift=True) | |||||
feats.append(output_aux[layer]) | |||||
return feats |
import os | |||||
from copy import deepcopy | |||||
import torch | |||||
import torch.nn as nn | |||||
import torch.nn.functional as F | |||||
import numpy as np | |||||
import models.transform_layers as TL | |||||
from utils.utils import set_random_seed, normalize | |||||
from evals.evals import get_auroc | |||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu") | |||||
hflip = TL.HorizontalFlipLayer().to(device) | |||||
def eval_ood_detection(P, model, id_loader, ood_loaders, ood_scores, train_loader=None, simclr_aug=None): | |||||
auroc_dict = dict() | |||||
for ood in ood_loaders.keys(): | |||||
auroc_dict[ood] = dict() | |||||
assert len(ood_scores) == 1 # assume single ood_score for simplicity | |||||
ood_score = ood_scores[0] | |||||
base_path = os.path.split(P.load_path)[0] # checkpoint directory | |||||
prefix = f'{P.ood_samples}' | |||||
if P.resize_fix: | |||||
prefix += f'_resize_fix_{P.resize_factor}' | |||||
else: | |||||
prefix += f'_resize_range_{P.resize_factor}' | |||||
prefix = os.path.join(base_path, f'feats_{prefix}') | |||||
kwargs = { | |||||
'simclr_aug': simclr_aug, | |||||
'sample_num': P.ood_samples, | |||||
'layers': P.ood_layer, | |||||
} | |||||
print('Pre-compute global statistics...') | |||||
feats_train = get_features(P, f'{P.dataset}_train', model, train_loader, prefix=prefix, **kwargs) # (M, T, d) | |||||
P.axis = [] | |||||
for f in feats_train['simclr'].chunk(P.K_shift, dim=1): | |||||
axis = f.mean(dim=1) # (M, d) | |||||
P.axis.append(normalize(axis, dim=1).to(device)) | |||||
print('axis size: ' + ' '.join(map(lambda x: str(len(x)), P.axis))) | |||||
f_sim = [f.mean(dim=1) for f in feats_train['simclr'].chunk(P.K_shift, dim=1)] # list of (M, d) | |||||
f_shi = [f.mean(dim=1) for f in feats_train['shift'].chunk(P.K_shift, dim=1)] # list of (M, 4) | |||||
weight_sim = [] | |||||
weight_shi = [] | |||||
for shi in range(P.K_shift): | |||||
sim_norm = f_sim[shi].norm(dim=1) # (M) | |||||
shi_mean = f_shi[shi][:, shi] # (M) | |||||
weight_sim.append(1 / sim_norm.mean().item()) | |||||
weight_shi.append(1 / shi_mean.mean().item()) | |||||
if ood_score == 'simclr': | |||||
P.weight_sim = [1] | |||||
P.weight_shi = [0] | |||||
elif ood_score == 'CSI': | |||||
P.weight_sim = weight_sim | |||||
P.weight_shi = weight_shi | |||||
else: | |||||
raise ValueError() | |||||
print(f'weight_sim:\t' + '\t'.join(map('{:.4f}'.format, P.weight_sim))) | |||||
print(f'weight_shi:\t' + '\t'.join(map('{:.4f}'.format, P.weight_shi))) | |||||
print('Pre-compute features...') | |||||
feats_id = get_features(P, P.dataset, model, id_loader, prefix=prefix, **kwargs) # (N, T, d) | |||||
feats_ood = dict() | |||||
for ood, ood_loader in ood_loaders.items(): | |||||
if ood == 'interp': | |||||
feats_ood[ood] = get_features(P, ood, model, id_loader, interp=True, prefix=prefix, **kwargs) | |||||
else: | |||||
feats_ood[ood] = get_features(P, ood, model, ood_loader, prefix=prefix, **kwargs) | |||||
print(f'Compute OOD scores... (score: {ood_score})') | |||||
scores_id = get_scores(P, feats_id, ood_score).numpy() | |||||
scores_ood = dict() | |||||
if P.one_class_idx is not None: | |||||
one_class_score = [] | |||||
for ood, feats in feats_ood.items(): | |||||
scores_ood[ood] = get_scores(P, feats, ood_score).numpy() | |||||
auroc_dict[ood][ood_score] = get_auroc(scores_id, scores_ood[ood]) | |||||
if P.one_class_idx is not None: | |||||
one_class_score.append(scores_ood[ood]) | |||||
if P.one_class_idx is not None: | |||||
one_class_score = np.concatenate(one_class_score) | |||||
one_class_total = get_auroc(scores_id, one_class_score) | |||||
print(f'One_class_real_mean: {one_class_total}') | |||||
if P.print_score: | |||||
print_score(P.dataset, scores_id) | |||||
for ood, scores in scores_ood.items(): | |||||
print_score(ood, scores) | |||||
return auroc_dict | |||||
def get_scores(P, feats_dict, ood_score): | |||||
# convert to gpu tensor | |||||
feats_sim = feats_dict['simclr'].to(device) | |||||
feats_shi = feats_dict['shift'].to(device) | |||||
N = feats_sim.size(0) | |||||
# compute scores | |||||
scores = [] | |||||
for f_sim, f_shi in zip(feats_sim, feats_shi): | |||||
f_sim = [f.mean(dim=0, keepdim=True) for f in f_sim.chunk(P.K_shift)] # list of (1, d) | |||||
f_shi = [f.mean(dim=0, keepdim=True) for f in f_shi.chunk(P.K_shift)] # list of (1, 4) | |||||
score = 0 | |||||
for shi in range(P.K_shift): | |||||
score += (f_sim[shi] * P.axis[shi]).sum(dim=1).max().item() * P.weight_sim[shi] | |||||
score += f_shi[shi][:, shi].item() * P.weight_shi[shi] | |||||
score = score / P.K_shift | |||||
scores.append(score) | |||||
scores = torch.tensor(scores) | |||||
assert scores.dim() == 1 and scores.size(0) == N # (N) | |||||
return scores.cpu() | |||||
def get_features(P, data_name, model, loader, interp=False, prefix='', | |||||
simclr_aug=None, sample_num=1, layers=('simclr', 'shift')): | |||||
if not isinstance(layers, (list, tuple)): | |||||
layers = [layers] | |||||
# load pre-computed features if exists | |||||
feats_dict = dict() | |||||
# for layer in layers: | |||||
# path = prefix + f'_{data_name}_{layer}.pth' | |||||
# if os.path.exists(path): | |||||
# feats_dict[layer] = torch.load(path) | |||||
# pre-compute features and save to the path | |||||
left = [layer for layer in layers if layer not in feats_dict.keys()] | |||||
if len(left) > 0: | |||||
_feats_dict = _get_features(P, model, loader, interp, (P.dataset == 'imagenet' or | |||||
P.dataset == 'CNMC' or | |||||
P.dataset == 'CNMC_grayscale'), | |||||
simclr_aug, sample_num, layers=left) | |||||
for layer, feats in _feats_dict.items(): | |||||
path = prefix + f'_{data_name}_{layer}.pth' | |||||
torch.save(_feats_dict[layer], path) | |||||
feats_dict[layer] = feats # update value | |||||
return feats_dict | |||||
def _get_features(P, model, loader, interp=False, imagenet=False, simclr_aug=None, | |||||
sample_num=1, layers=('simclr', 'shift')): | |||||
if not isinstance(layers, (list, tuple)): | |||||
layers = [layers] | |||||
# check if arguments are valid | |||||
assert simclr_aug is not None | |||||
if imagenet is True: # assume batch_size = 1 for ImageNet | |||||
sample_num = 1 | |||||
# compute features in full dataset | |||||
model.eval() | |||||
feats_all = {layer: [] for layer in layers} # initialize: empty list | |||||
for i, (x, _) in enumerate(loader): | |||||
if interp: | |||||
x_interp = (x + last) / 2 if i > 0 else x # omit the first batch, assume batch sizes are equal | |||||
last = x # save the last batch | |||||
x = x_interp # use interp as current batch | |||||
if imagenet is True: | |||||
x = torch.cat(x[0], dim=0) # augmented list of x | |||||
x = x.to(device) # gpu tensor | |||||
# compute features in one batch | |||||
feats_batch = {layer: [] for layer in layers} # initialize: empty list | |||||
for seed in range(sample_num): | |||||
set_random_seed(seed) | |||||
if P.K_shift > 1: | |||||
x_t = torch.cat([P.shift_trans(hflip(x), k) for k in range(P.K_shift)]) | |||||
else: | |||||
x_t = x # No shifting: SimCLR | |||||
x_t = simclr_aug(x_t) | |||||
# compute augmented features | |||||
with torch.no_grad(): | |||||
kwargs = {layer: True for layer in layers} # only forward selected layers | |||||
_, output_aux = model(x_t, **kwargs) | |||||
# add features in one batch | |||||
for layer in layers: | |||||
feats = output_aux[layer].cpu() | |||||
if imagenet is False: | |||||
feats_batch[layer] += feats.chunk(P.K_shift) | |||||
else: | |||||
feats_batch[layer] += [feats] # (B, d) cpu tensor | |||||
# concatenate features in one batch | |||||
for key, val in feats_batch.items(): | |||||
if imagenet: | |||||
feats_batch[key] = torch.stack(val, dim=0) # (B, T, d) | |||||
else: | |||||
feats_batch[key] = torch.stack(val, dim=1) # (B, T, d) | |||||
# add features in full dataset | |||||
for layer in layers: | |||||
feats_all[layer] += [feats_batch[layer]] | |||||
# concatenate features in full dataset | |||||
for key, val in feats_all.items(): | |||||
feats_all[key] = torch.cat(val, dim=0) # (N, T, d) | |||||
# reshape order | |||||
if imagenet is False: | |||||
# Convert [1,2,3,4, 1,2,3,4] -> [1,1, 2,2, 3,3, 4,4] | |||||
for key, val in feats_all.items(): | |||||
N, T, d = val.size() # T = K * T' | |||||
val = val.view(N, -1, P.K_shift, d) # (N, T', K, d) | |||||
val = val.transpose(2, 1) # (N, 4, T', d) | |||||
val = val.reshape(N, T, d) # (N, T, d) | |||||
feats_all[key] = val | |||||
return feats_all | |||||
def print_score(data_name, scores): | |||||
quantile = np.quantile(scores, np.arange(0, 1.1, 0.1)) | |||||
print('{:18s} '.format(data_name) + | |||||
'{:.4f} +- {:.4f} '.format(np.mean(scores), np.std(scores)) + | |||||
' '.join(['q{:d}: {:.4f}'.format(i * 10, quantile[i]) for i in range(11)])) | |||||
from sys import argv | |||||
from os import system | |||||
from datasets.prepare_data import prep, resize | |||||
import torch | |||||
import os | |||||
from datasets.postprocess_data import postprocess_data | |||||
DATA_BASE_DIR = r'/home/feoktistovar67431/CSI/CSI_local/main.py' | |||||
BASE_DIR = '/home/feoktistovar67431/CSI/CSI_local/' | |||||
def main(): | |||||
for argument in argv: | |||||
if argument == '--proc_step': | |||||
proc_step = argv[argv.index(argument)+1] | |||||
if proc_step == 'eval': | |||||
system("eval.py "+' '.join(argv[1:])) | |||||
if proc_step == 'train': | |||||
system(BASE_DIR + os.sep + "eval.py " + ' '.join(argv[1:])) | |||||
if proc_step == 'plot': | |||||
plot_data() | |||||
elif proc_step == 'post_proc': | |||||
postprocess_data( | |||||
[ | |||||
r'\CNMC_resnet18_unsup_simclr_CSI_shift_cutperm4_one_class_0\log.txt', | |||||
r'\CNMC_resnet18_unsup_simclr_CSI_shift_cutperm4_one_class_0_64px\log.txt', | |||||
r'\CNMC_resnet18_unsup_simclr_CSI_shift_cutperm16_one_class_0_32px\log.txt', | |||||
r'\CNMC_resnet18_unsup_simclr_CSI_shift_cutperm_one_class_0_64px_batch64\log.txt', | |||||
r'\CNMC_resnet18_unsup_simclr_CSI_shift_rotation_one_class_0\log.txt', | |||||
r"\CNMC_resnet18_unsup_simclr_CSI_shift_gauss_one_class_0_32px\log.txt" | |||||
# r'\cifar10_resnet18_unsup_simclr_CSI_shift_rotation_one_class_1\log.txt' | |||||
] | |||||
) | |||||
if __name__ == '__main__': | |||||
main() |
from abc import * | |||||
import torch.nn as nn | |||||
class BaseModel(nn.Module, metaclass=ABCMeta): | |||||
def __init__(self, last_dim, num_classes=10, simclr_dim=128): | |||||
super(BaseModel, self).__init__() | |||||
self.linear = nn.Linear(last_dim, num_classes) | |||||
self.simclr_layer = nn.Sequential( | |||||
nn.Linear(last_dim, last_dim), | |||||
nn.ReLU(), | |||||
nn.Linear(last_dim, simclr_dim), | |||||
) | |||||
self.shift_cls_layer = nn.Linear(last_dim, 2) | |||||
self.joint_distribution_layer = nn.Linear(last_dim, 4 * num_classes) | |||||
@abstractmethod | |||||
def penultimate(self, inputs, all_features=False): | |||||
pass | |||||
def forward(self, inputs, penultimate=False, simclr=False, shift=False, joint=False): | |||||
_aux = {} | |||||
_return_aux = False | |||||
features = self.penultimate(inputs) | |||||
output = self.linear(features) | |||||
if penultimate: | |||||
_return_aux = True | |||||
_aux['penultimate'] = features | |||||
if simclr: | |||||
_return_aux = True | |||||
_aux['simclr'] = self.simclr_layer(features) | |||||
if shift: | |||||
_return_aux = True | |||||
_aux['shift'] = self.shift_cls_layer(features) | |||||
if joint: | |||||
_return_aux = True | |||||
_aux['joint'] = self.joint_distribution_layer(features) | |||||
if _return_aux: | |||||
return output, _aux | |||||
return output |
import torch.nn as nn | |||||
from models.resnet import ResNet18, ResNet34, ResNet50 | |||||
from models.resnet_imagenet import resnet18, resnet50 | |||||
import models.transform_layers as TL | |||||
from torchvision import transforms | |||||
def get_simclr_augmentation(P, image_size): | |||||
""" | |||||
Creates positive data for training. | |||||
:param P: parsed arguments | |||||
:param image_size: size of image | |||||
:return: transformation | |||||
""" | |||||
# parameter for resizecrop | |||||
resize_scale = (P.resize_factor, 1.0) # resize scaling factor | |||||
if P.resize_fix: # if resize_fix is True, use same scale | |||||
resize_scale = (P.resize_factor, P.resize_factor) | |||||
# Align augmentation | |||||
s = P.color_distort | |||||
color_jitter = TL.ColorJitterLayer(brightness=s*0.8, contrast=s*0.8, saturation=s*0.8, hue=s*0.2, p=0.8) | |||||
color_gray = TL.RandomColorGrayLayer(p=0.2) | |||||
resize_crop = TL.RandomResizedCropLayer(scale=resize_scale, size=(image_size[0], image_size[1])) | |||||
#v_flip = transforms.RandomVerticalFlip() | |||||
#h_flip = transforms.RandomHorizontalFlip() | |||||
rand_aff = transforms.RandomAffine(degrees=360, translate=(0.2, 0.2)) | |||||
# Transform define # | |||||
if P.dataset == 'imagenet': # Using RandomResizedCrop at PIL transform | |||||
transform = nn.Sequential( | |||||
color_jitter, | |||||
color_gray, | |||||
) | |||||
elif P.dataset == 'CNMC': | |||||
transform = nn.Sequential( | |||||
color_jitter, | |||||
color_gray, | |||||
resize_crop, | |||||
) | |||||
else: | |||||
transform = nn.Sequential( | |||||
color_jitter, | |||||
color_gray, | |||||
resize_crop, | |||||
) | |||||
return transform | |||||
def get_shift_module(P, eval=False): | |||||
""" | |||||
Creates shift transformation (negative). | |||||
:param P: parsed arguments | |||||
:param eval: whether it is an evaluation step or not | |||||
:return: transformation | |||||
""" | |||||
if P.shift_trans_type == 'rotation': | |||||
shift_transform = TL.Rotation() | |||||
K_shift = 4 | |||||
elif P.shift_trans_type == 'cutperm': | |||||
shift_transform = TL.CutPerm() | |||||
K_shift = 4 | |||||
elif P.shift_trans_type == 'noise': | |||||
shift_transform = TL.GaussNoise(mean=P.noise_mean, std=P.noise_std) | |||||
K_shift = 4 | |||||
elif P.shift_trans_type == 'randpers': | |||||
shift_transform = TL.RandPers(distortion_scale=P.distortion_scale, p=1) | |||||
K_shift = 4 | |||||
elif P.shift_trans_type == 'sharp': | |||||
shift_transform = TL.RandomAdjustSharpness(sharpness_factor=P.sharpness_factor, p=1) | |||||
K_shift = 4 | |||||
elif P.shift_trans_type == 'blur': | |||||
kernel_size = int(int(P.res.replace('px', ''))*0.1) | |||||
if kernel_size%2 == 0: | |||||
kernel_size+=1 | |||||
sigma = (0.1, float(P.blur_sigma)) | |||||
shift_transform = TL.GaussBlur(kernel_size=kernel_size, sigma=sigma) | |||||
K_shift = 4 | |||||
elif P.shift_trans_type == 'blur_randpers': | |||||
kernel_size = int(P.res.replace('px', '')) * 0.1 | |||||
sigma = (0.1, float(P.blur_sigma)) | |||||
shift_transform = TL.BlurRandpers(kernel_size=kernel_size, sigma=sigma, distortion_scale=P.distortion_scale, p=1) | |||||
K_shift = 4 | |||||
elif P.shift_trans_type == 'blur_sharp': | |||||
kernel_size = int(P.res.replace('px', '')) * 0.1 | |||||
sigma = (0.1, float(P.blur_sigma)) | |||||
shift_transform = TL.BlurSharpness(kernel_size=kernel_size, sigma=sigma, sharpness_factor=P.sharpness_factor, p=1) | |||||
K_shift = 4 | |||||
elif P.shift_trans_type == 'randpers_sharp': | |||||
shift_transform = TL.RandpersSharpness(distortion_scale=P.distortion_scale, p=1, sharpness_factor=P.sharpness_factor) | |||||
K_shift = 4 | |||||
elif P.shift_trans_type == 'blur_randpers_sharp': | |||||
kernel_size = int(P.res.replace('px', '')) * 0.1 | |||||
sigma = (0.1, float(P.blur_sigma)) | |||||
shift_transform = TL.BlurRandpersSharpness(kernel_size=kernel_size, sigma=sigma, distortion_scale=P.distortion_scale, p=1, sharpness_factor=P.sharpness_factor) | |||||
K_shift = 4 | |||||
else: | |||||
shift_transform = nn.Identity() | |||||
K_shift = 1 | |||||
if not eval and not ('sup' in P.mode): | |||||
assert P.batch_size == int(128/K_shift) | |||||
return shift_transform, K_shift | |||||
def get_shift_classifer(model, K_shift): | |||||
model.shift_cls_layer = nn.Linear(model.last_dim, K_shift) | |||||
return model | |||||
def get_classifier(mode, n_classes=10): | |||||
if mode == 'resnet18': | |||||
classifier = ResNet18(num_classes=n_classes) | |||||
elif mode == 'resnet34': | |||||
classifier = ResNet34(num_classes=n_classes) | |||||
elif mode == 'resnet50': | |||||
classifier = ResNet50(num_classes=n_classes) | |||||
elif mode == 'resnet18_imagenet': | |||||
classifier = resnet18(num_classes=n_classes) | |||||
elif mode == 'resnet50_imagenet': | |||||
classifier = resnet50(num_classes=n_classes) | |||||
else: | |||||
raise NotImplementedError() | |||||
return classifier | |||||
'''ResNet in PyTorch. | |||||
BasicBlock and Bottleneck module is from the original ResNet paper: | |||||
[1] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun | |||||
Deep Residual Learning for Image Recognition. arXiv:1512.03385 | |||||
PreActBlock and PreActBottleneck module is from the later paper: | |||||
[2] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun | |||||
Identity Mappings in Deep Residual Networks. arXiv:1603.05027 | |||||
''' | |||||
import torch | |||||
import torch.nn as nn | |||||
import torch.nn.functional as F | |||||
from models.base_model import BaseModel | |||||
from models.transform_layers import NormalizeLayer | |||||
from torch.nn.utils import spectral_norm | |||||
def conv3x3(in_planes, out_planes, stride=1): | |||||
return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False) | |||||
class BasicBlock(nn.Module): | |||||
expansion = 1 | |||||
def __init__(self, in_planes, planes, stride=1): | |||||
super(BasicBlock, self).__init__() | |||||
self.conv1 = conv3x3(in_planes, planes, stride) | |||||
self.conv2 = conv3x3(planes, planes) | |||||
self.bn1 = nn.BatchNorm2d(planes) | |||||
self.bn2 = nn.BatchNorm2d(planes) | |||||
self.shortcut = nn.Sequential() | |||||
if stride != 1 or in_planes != self.expansion*planes: | |||||
self.shortcut = nn.Sequential( | |||||
nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False), | |||||
nn.BatchNorm2d(self.expansion*planes) | |||||
) | |||||
def forward(self, x): | |||||
out = F.relu(self.bn1(self.conv1(x))) | |||||
out = self.bn2(self.conv2(out)) | |||||
out += self.shortcut(x) | |||||
out = F.relu(out) | |||||
return out | |||||
class PreActBlock(nn.Module): | |||||
'''Pre-activation version of the BasicBlock.''' | |||||
expansion = 1 | |||||
def __init__(self, in_planes, planes, stride=1): | |||||
super(PreActBlock, self).__init__() | |||||
self.conv1 = conv3x3(in_planes, planes, stride) | |||||
self.conv2 = conv3x3(planes, planes) | |||||
self.bn1 = nn.BatchNorm2d(in_planes) | |||||
self.bn2 = nn.BatchNorm2d(planes) | |||||
self.shortcut = nn.Sequential() | |||||
if stride != 1 or in_planes != self.expansion*planes: | |||||
self.shortcut = nn.Sequential( | |||||
nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False) | |||||
) | |||||
def forward(self, x): | |||||
out = F.relu(self.bn1(x)) | |||||
shortcut = self.shortcut(out) | |||||
out = self.conv1(out) | |||||
out = self.conv2(F.relu(self.bn2(out))) | |||||
out += shortcut | |||||
return out | |||||
class Bottleneck(nn.Module): | |||||
expansion = 4 | |||||
def __init__(self, in_planes, planes, stride=1): | |||||
super(Bottleneck, self).__init__() | |||||
self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False) | |||||
self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) | |||||
self.conv3 = nn.Conv2d(planes, self.expansion*planes, kernel_size=1, bias=False) | |||||
self.bn1 = nn.BatchNorm2d(planes) | |||||
self.bn2 = nn.BatchNorm2d(planes) | |||||
self.bn3 = nn.BatchNorm2d(self.expansion * planes) | |||||
self.shortcut = nn.Sequential() | |||||
if stride != 1 or in_planes != self.expansion*planes: | |||||
self.shortcut = nn.Sequential( | |||||
nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False), | |||||
nn.BatchNorm2d(self.expansion*planes) | |||||
) | |||||
def forward(self, x): | |||||
out = F.relu(self.bn1(self.conv1(x))) | |||||
out = F.relu(self.bn2(self.conv2(out))) | |||||
out = self.bn3(self.conv3(out)) | |||||
out += self.shortcut(x) | |||||
out = F.relu(out) | |||||
return out | |||||
class PreActBottleneck(nn.Module): | |||||
'''Pre-activation version of the original Bottleneck module.''' | |||||
expansion = 4 | |||||
def __init__(self, in_planes, planes, stride=1): | |||||
super(PreActBottleneck, self).__init__() | |||||
self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False) | |||||
self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) | |||||
self.conv3 = nn.Conv2d(planes, self.expansion*planes, kernel_size=1, bias=False) | |||||
self.bn1 = nn.BatchNorm2d(in_planes) | |||||
self.bn2 = nn.BatchNorm2d(planes) | |||||
self.bn3 = nn.BatchNorm2d(planes) | |||||
self.shortcut = nn.Sequential() | |||||
if stride != 1 or in_planes != self.expansion*planes: | |||||
self.shortcut = nn.Sequential( | |||||
nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False) | |||||
) | |||||
def forward(self, x): | |||||
out = F.relu(self.bn1(x)) | |||||
shortcut = self.shortcut(out) | |||||
out = self.conv1(out) | |||||
out = self.conv2(F.relu(self.bn2(out))) | |||||
out = self.conv3(F.relu(self.bn3(out))) | |||||
out += shortcut | |||||
return out | |||||
class ResNet(BaseModel): | |||||
def __init__(self, block, num_blocks, num_classes=10): | |||||
last_dim = 512 * block.expansion | |||||
super(ResNet, self).__init__(last_dim, num_classes) | |||||
self.in_planes = 64 | |||||
self.last_dim = last_dim | |||||
self.normalize = NormalizeLayer() | |||||
self.conv1 = conv3x3(3, 64) | |||||
self.bn1 = nn.BatchNorm2d(64) | |||||
self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1) | |||||
self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2) | |||||
self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2) | |||||
self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2) | |||||
def _make_layer(self, block, planes, num_blocks, stride): | |||||
strides = [stride] + [1]*(num_blocks-1) | |||||
layers = [] | |||||
for stride in strides: | |||||
layers.append(block(self.in_planes, planes, stride)) | |||||
self.in_planes = planes * block.expansion | |||||
return nn.Sequential(*layers) | |||||
def penultimate(self, x, all_features=False): | |||||
out_list = [] | |||||
out = self.normalize(x) | |||||
out = self.conv1(out) | |||||
out = self.bn1(out) | |||||
out = F.relu(out) | |||||
out_list.append(out) | |||||
out = self.layer1(out) | |||||
out_list.append(out) | |||||
out = self.layer2(out) | |||||
out_list.append(out) | |||||
out = self.layer3(out) | |||||
out_list.append(out) | |||||
out = self.layer4(out) | |||||
out_list.append(out) | |||||
out = F.avg_pool2d(out, 4) | |||||
out = out.view(out.size(0), -1) | |||||
if all_features: | |||||
return out, out_list | |||||
else: | |||||
return out | |||||
def ResNet18(num_classes): | |||||
return ResNet(BasicBlock, [2,2,2,2], num_classes=num_classes) | |||||
def ResNet34(num_classes): | |||||
return ResNet(BasicBlock, [3,4,6,3], num_classes=num_classes) | |||||
def ResNet50(num_classes): | |||||
return ResNet(Bottleneck, [3,4,6,3], num_classes=num_classes) |
import torch | |||||
import torch.nn as nn | |||||
from models.base_model import BaseModel | |||||
from models.transform_layers import NormalizeLayer | |||||
def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1): | |||||
"""3x3 convolution with padding""" | |||||
return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, | |||||
padding=dilation, groups=groups, bias=False, dilation=dilation) | |||||
def conv1x1(in_planes, out_planes, stride=1): | |||||
"""1x1 convolution""" | |||||
return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) | |||||
class BasicBlock(nn.Module): | |||||
expansion = 1 | |||||
def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, | |||||
base_width=64, dilation=1, norm_layer=None): | |||||
super(BasicBlock, self).__init__() | |||||
if norm_layer is None: | |||||
norm_layer = nn.BatchNorm2d | |||||
if groups != 1 or base_width != 64: | |||||
raise ValueError('BasicBlock only supports groups=1 and base_width=64') | |||||
if dilation > 1: | |||||
raise NotImplementedError("Dilation > 1 not supported in BasicBlock") | |||||
# Both self.conv1 and self.downsample layers downsample the input when stride != 1 | |||||
self.conv1 = conv3x3(inplanes, planes, stride) | |||||
self.bn1 = norm_layer(planes) | |||||
self.relu = nn.ReLU(inplace=True) | |||||
self.conv2 = conv3x3(planes, planes) | |||||
self.bn2 = norm_layer(planes) | |||||
self.downsample = downsample | |||||
self.stride = stride | |||||
def forward(self, x): | |||||
identity = x | |||||
out = self.conv1(x) | |||||
out = self.bn1(out) | |||||
out = self.relu(out) | |||||
out = self.conv2(out) | |||||
out = self.bn2(out) | |||||
if self.downsample is not None: | |||||
identity = self.downsample(x) | |||||
out += identity | |||||
out = self.relu(out) | |||||
return out | |||||
class Bottleneck(nn.Module): | |||||
# Bottleneck in torchvision places the stride for downsampling at 3x3 convolution(self.conv2) | |||||
# while original implementation places the stride at the first 1x1 convolution(self.conv1) | |||||
# according to "Deep residual learning for image recognition"https://arxiv.org/abs/1512.03385. | |||||
# This variant is also known as ResNet V1.5 and improves accuracy according to | |||||
# https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch. | |||||
expansion = 4 | |||||
def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, | |||||
base_width=64, dilation=1, norm_layer=None): | |||||
super(Bottleneck, self).__init__() | |||||
if norm_layer is None: | |||||
norm_layer = nn.BatchNorm2d | |||||
width = int(planes * (base_width / 64.)) * groups | |||||
# Both self.conv2 and self.downsample layers downsample the input when stride != 1 | |||||
self.conv1 = conv1x1(inplanes, width) | |||||
self.bn1 = norm_layer(width) | |||||
self.conv2 = conv3x3(width, width, stride, groups, dilation) | |||||
self.bn2 = norm_layer(width) | |||||
self.conv3 = conv1x1(width, planes * self.expansion) | |||||
self.bn3 = norm_layer(planes * self.expansion) | |||||
self.relu = nn.ReLU(inplace=True) | |||||
self.downsample = downsample | |||||
self.stride = stride | |||||
def forward(self, x): | |||||
identity = x | |||||
out = self.conv1(x) | |||||
out = self.bn1(out) | |||||
out = self.relu(out) | |||||
out = self.conv2(out) | |||||
out = self.bn2(out) | |||||
out = self.relu(out) | |||||
out = self.conv3(out) | |||||
out = self.bn3(out) | |||||
if self.downsample is not None: | |||||
identity = self.downsample(x) | |||||
out += identity | |||||
out = self.relu(out) | |||||
return out | |||||
class ResNet(BaseModel): | |||||
def __init__(self, block, layers, num_classes=10, | |||||
zero_init_residual=False, groups=1, width_per_group=64, replace_stride_with_dilation=None, | |||||
norm_layer=None): | |||||
last_dim = 512 * block.expansion | |||||
super(ResNet, self).__init__(last_dim, num_classes) | |||||
if norm_layer is None: | |||||
norm_layer = nn.BatchNorm2d | |||||
self._norm_layer = norm_layer | |||||
self.inplanes = 64 | |||||
self.dilation = 1 | |||||
if replace_stride_with_dilation is None: | |||||
# each element in the tuple indicates if we should replace | |||||
# the 2x2 stride with a dilated convolution instead | |||||
replace_stride_with_dilation = [False, False, False] | |||||
if len(replace_stride_with_dilation) != 3: | |||||
raise ValueError("replace_stride_with_dilation should be None " | |||||
"or a 3-element tuple, got {}".format(replace_stride_with_dilation)) | |||||
self.groups = groups | |||||
self.base_width = width_per_group | |||||
self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3, | |||||
bias=False) | |||||
self.bn1 = norm_layer(self.inplanes) | |||||
self.relu = nn.ReLU(inplace=True) | |||||
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) | |||||
self.layer1 = self._make_layer(block, 64, layers[0]) | |||||
self.layer2 = self._make_layer(block, 128, layers[1], stride=2, | |||||
dilate=replace_stride_with_dilation[0]) | |||||
self.layer3 = self._make_layer(block, 256, layers[2], stride=2, | |||||
dilate=replace_stride_with_dilation[1]) | |||||
self.layer4 = self._make_layer(block, 512, layers[3], stride=2, | |||||
dilate=replace_stride_with_dilation[2]) | |||||
self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) | |||||
self.normalize = NormalizeLayer() | |||||
self.last_dim = 512 * block.expansion | |||||
for m in self.modules(): | |||||
if isinstance(m, nn.Conv2d): | |||||
nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') | |||||
elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): | |||||
nn.init.constant_(m.weight, 1) | |||||
nn.init.constant_(m.bias, 0) | |||||
# Zero-initialize the last BN in each residual branch, | |||||
# so that the residual branch starts with zeros, and each residual block behaves like an identity. | |||||
# This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677 | |||||
if zero_init_residual: | |||||
for m in self.modules(): | |||||
if isinstance(m, Bottleneck): | |||||
nn.init.constant_(m.bn3.weight, 0) | |||||
elif isinstance(m, BasicBlock): | |||||
nn.init.constant_(m.bn2.weight, 0) | |||||
def _make_layer(self, block, planes, blocks, stride=1, dilate=False): | |||||
norm_layer = self._norm_layer | |||||
downsample = None | |||||
previous_dilation = self.dilation | |||||
if dilate: | |||||
self.dilation *= stride | |||||
stride = 1 | |||||
if stride != 1 or self.inplanes != planes * block.expansion: | |||||
downsample = nn.Sequential( | |||||
conv1x1(self.inplanes, planes * block.expansion, stride), | |||||
norm_layer(planes * block.expansion), | |||||
) | |||||
layers = [] | |||||
layers.append(block(self.inplanes, planes, stride, downsample, self.groups, | |||||
self.base_width, previous_dilation, norm_layer)) | |||||
self.inplanes = planes * block.expansion | |||||
for _ in range(1, blocks): | |||||
layers.append(block(self.inplanes, planes, groups=self.groups, | |||||
base_width=self.base_width, dilation=self.dilation, | |||||
norm_layer=norm_layer)) | |||||
return nn.Sequential(*layers) | |||||
def penultimate(self, x, all_features=False): | |||||
# See note [TorchScript super()] | |||||
out_list = [] | |||||
x = self.normalize(x) | |||||
x = self.conv1(x) | |||||
x = self.bn1(x) | |||||
x = self.relu(x) | |||||
x = self.maxpool(x) | |||||
out_list.append(x) | |||||
x = self.layer1(x) | |||||
out_list.append(x) | |||||
x = self.layer2(x) | |||||
out_list.append(x) | |||||
x = self.layer3(x) | |||||
out_list.append(x) | |||||
x = self.layer4(x) | |||||
out_list.append(x) | |||||
x = self.avgpool(x) | |||||
x = torch.flatten(x, 1) | |||||
if all_features: | |||||
return x, out_list | |||||
else: | |||||
return x | |||||
def _resnet(arch, block, layers, **kwargs): | |||||
model = ResNet(block, layers, **kwargs) | |||||
return model | |||||
def resnet18(**kwargs): | |||||
r"""ResNet-18 model from | |||||
`"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_ | |||||
""" | |||||
return _resnet('resnet18', BasicBlock, [2, 2, 2, 2], **kwargs) | |||||
def resnet50(**kwargs): | |||||
r"""ResNet-50 model from | |||||
`"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_ | |||||
""" | |||||
return _resnet('resnet50', Bottleneck, [3, 4, 6, 3], **kwargs) |
import math | |||||
import numbers | |||||
import numpy as np | |||||
import torch | |||||
import torch.nn as nn | |||||
import torch.nn.functional as F | |||||
from torch.autograd import Function | |||||
from torchvision import transforms | |||||
if torch.__version__ >= '1.4.0': | |||||
kwargs = {'align_corners': False} | |||||
else: | |||||
kwargs = {} | |||||
def rgb2hsv(rgb): | |||||
"""Convert a 4-d RGB tensor to the HSV counterpart. | |||||
Here, we compute hue using atan2() based on the definition in [1], | |||||
instead of using the common lookup table approach as in [2, 3]. | |||||
Those values agree when the angle is a multiple of 30°, | |||||
otherwise they may differ at most ~1.2°. | |||||
References | |||||
[1] https://en.wikipedia.org/wiki/Hue | |||||
[2] https://www.rapidtables.com/convert/color/rgb-to-hsv.html | |||||
[3] https://github.com/scikit-image/scikit-image/blob/master/skimage/color/colorconv.py#L212 | |||||
""" | |||||
r, g, b = rgb[:, 0, :, :], rgb[:, 1, :, :], rgb[:, 2, :, :] | |||||
Cmax = rgb.max(1)[0] | |||||
Cmin = rgb.min(1)[0] | |||||
delta = Cmax - Cmin | |||||
hue = torch.atan2(math.sqrt(3) * (g - b), 2 * r - g - b) | |||||
hue = (hue % (2 * math.pi)) / (2 * math.pi) | |||||
saturate = delta / Cmax | |||||
value = Cmax | |||||
hsv = torch.stack([hue, saturate, value], dim=1) | |||||
hsv[~torch.isfinite(hsv)] = 0. | |||||
return hsv | |||||
def hsv2rgb(hsv): | |||||
"""Convert a 4-d HSV tensor to the RGB counterpart. | |||||
>>> %timeit hsv2rgb(hsv) | |||||
2.37 ms ± 13.4 µs per loop (mean ± std. dev. of 7 runs, 100 loops each) | |||||
>>> %timeit rgb2hsv_fast(rgb) | |||||
298 µs ± 542 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each) | |||||
>>> torch.allclose(hsv2rgb(hsv), hsv2rgb_fast(hsv), atol=1e-6) | |||||
True | |||||
References | |||||
[1] https://en.wikipedia.org/wiki/HSL_and_HSV#HSV_to_RGB_alternative | |||||
""" | |||||
h, s, v = hsv[:, [0]], hsv[:, [1]], hsv[:, [2]] | |||||
c = v * s | |||||
n = hsv.new_tensor([5, 3, 1]).view(3, 1, 1) | |||||
k = (n + h * 6) % 6 | |||||
t = torch.min(k, 4 - k) | |||||
t = torch.clamp(t, 0, 1) | |||||
return v - c * t | |||||
class RandomResizedCropLayer(nn.Module): | |||||
def __init__(self, size=None, scale=(0.08, 1.0), ratio=(3. / 4., 4. / 3.)): | |||||
''' | |||||
Inception Crop | |||||
size (tuple): size of fowarding image (C, W, H) | |||||
scale (tuple): range of size of the origin size cropped | |||||
ratio (tuple): range of aspect ratio of the origin aspect ratio cropped | |||||
''' | |||||
super(RandomResizedCropLayer, self).__init__() | |||||
_eye = torch.eye(2, 3) | |||||
self.size = size | |||||
self.register_buffer('_eye', _eye) | |||||
self.scale = scale | |||||
self.ratio = ratio | |||||
def forward(self, inputs, whbias=None): | |||||
_device = inputs.device | |||||
N = inputs.size(0) | |||||
_theta = self._eye.repeat(N, 1, 1) | |||||
if whbias is None: | |||||
whbias = self._sample_latent(inputs) | |||||
_theta[:, 0, 0] = whbias[:, 0] | |||||
_theta[:, 1, 1] = whbias[:, 1] | |||||
_theta[:, 0, 2] = whbias[:, 2] | |||||
_theta[:, 1, 2] = whbias[:, 3] | |||||
grid = F.affine_grid(_theta, inputs.size(), **kwargs).to(_device) | |||||
output = F.grid_sample(inputs, grid, padding_mode='reflection', **kwargs) | |||||
if self.size is not None: | |||||
output = F.adaptive_avg_pool2d(output, self.size) | |||||
# output = F.adaptive_avg_pool2d(output, self.size) | |||||
# output = F.adaptive_avg_pool2d(output, (self.size[0], self.size[1])) | |||||
return output | |||||
def _clamp(self, whbias): | |||||
w = whbias[:, 0] | |||||
h = whbias[:, 1] | |||||
w_bias = whbias[:, 2] | |||||
h_bias = whbias[:, 3] | |||||
# Clamp with scale | |||||
w = torch.clamp(w, *self.scale) | |||||
h = torch.clamp(h, *self.scale) | |||||
# Clamp with ratio | |||||
w = self.ratio[0] * h + torch.relu(w - self.ratio[0] * h) | |||||
w = self.ratio[1] * h - torch.relu(self.ratio[1] * h - w) | |||||
# Clamp with bias range: w_bias \in (w - 1, 1 - w), h_bias \in (h - 1, 1 - h) | |||||
w_bias = w - 1 + torch.relu(w_bias - w + 1) | |||||
w_bias = 1 - w - torch.relu(1 - w - w_bias) | |||||
h_bias = h - 1 + torch.relu(h_bias - h + 1) | |||||
h_bias = 1 - h - torch.relu(1 - h - h_bias) | |||||
whbias = torch.stack([w, h, w_bias, h_bias], dim=0).t() | |||||
return whbias | |||||
def _sample_latent(self, inputs): | |||||
_device = inputs.device | |||||
N, _, width, height = inputs.shape | |||||
# N * 10 trial | |||||
area = width * height | |||||
target_area = np.random.uniform(*self.scale, N * 10) * area | |||||
log_ratio = (math.log(self.ratio[0]), math.log(self.ratio[1])) | |||||
aspect_ratio = np.exp(np.random.uniform(*log_ratio, N * 10)) | |||||
# If doesn't satisfy ratio condition, then do central crop | |||||
w = np.round(np.sqrt(target_area * aspect_ratio)) | |||||
h = np.round(np.sqrt(target_area / aspect_ratio)) | |||||
cond = (0 < w) * (w <= width) * (0 < h) * (h <= height) | |||||
w = w[cond] | |||||
h = h[cond] | |||||
cond_len = w.shape[0] | |||||
if cond_len >= N: | |||||
w = w[:N] | |||||
h = h[:N] | |||||
else: | |||||
w = np.concatenate([w, np.ones(N - cond_len) * width]) | |||||
h = np.concatenate([h, np.ones(N - cond_len) * height]) | |||||
w_bias = np.random.randint(w - width, width - w + 1) / width | |||||
h_bias = np.random.randint(h - height, height - h + 1) / height | |||||
w = w / width | |||||
h = h / height | |||||
whbias = np.column_stack([w, h, w_bias, h_bias]) | |||||
whbias = torch.tensor(whbias, device=_device) | |||||
return whbias | |||||
class HorizontalFlipRandomCrop(nn.Module): | |||||
def __init__(self, max_range): | |||||
super(HorizontalFlipRandomCrop, self).__init__() | |||||
self.max_range = max_range | |||||
_eye = torch.eye(2, 3) | |||||
self.register_buffer('_eye', _eye) | |||||
def forward(self, input, sign=None, bias=None, rotation=None): | |||||
_device = input.device | |||||
N = input.size(0) | |||||
_theta = self._eye.repeat(N, 1, 1) | |||||
if sign is None: | |||||
sign = torch.bernoulli(torch.ones(N, device=_device) * 0.5) * 2 - 1 | |||||
if bias is None: | |||||
bias = torch.empty((N, 2), device=_device).uniform_(-self.max_range, self.max_range) | |||||
_theta[:, 0, 0] = sign | |||||
_theta[:, :, 2] = bias | |||||
if rotation is not None: | |||||
_theta[:, 0:2, 0:2] = rotation | |||||
grid = F.affine_grid(_theta, input.size(), **kwargs).to(_device) | |||||
output = F.grid_sample(input, grid, padding_mode='reflection', **kwargs) | |||||
return output | |||||
def _sample_latent(self, N, device=None): | |||||
sign = torch.bernoulli(torch.ones(N, device=device) * 0.5) * 2 - 1 | |||||
bias = torch.empty((N, 2), device=device).uniform_(-self.max_range, self.max_range) | |||||
return sign, bias | |||||
class Rotation(nn.Module): | |||||
def __init__(self, max_range = 4): | |||||
super(Rotation, self).__init__() | |||||
self.max_range = max_range | |||||
self.prob = 0.5 | |||||
def forward(self, input, aug_index=None): | |||||
_device = input.device | |||||
_, _, H, W = input.size() | |||||
if aug_index is None: | |||||
aug_index = np.random.randint(4) | |||||
output = torch.rot90(input, aug_index, (2, 3)) | |||||
_prob = input.new_full((input.size(0),), self.prob) | |||||
_mask = torch.bernoulli(_prob).view(-1, 1, 1, 1) | |||||
output = _mask * input + (1-_mask) * output | |||||
else: | |||||
aug_index = aug_index % self.max_range | |||||
output = torch.rot90(input, aug_index, (2, 3)) | |||||
return output | |||||
class RandomAdjustSharpness(nn.Module): | |||||
def __init__(self, sharpness_factor=0.5, p=0.5): | |||||
super(RandomAdjustSharpness, self).__init__() | |||||
self.sharpness_factor = sharpness_factor | |||||
self.prob = p | |||||
def forward(self, input, aug_index=None): | |||||
_device = input.device | |||||
_, _, H, W = input.size() | |||||
if aug_index == 0: | |||||
output = input | |||||
else: | |||||
output = transforms.RandomAdjustSharpness(sharpness_factor=self.sharpness_factor, p=self.prob)(input) | |||||
return output | |||||
class RandPers(nn.Module): | |||||
def __init__(self, distortion_scale=0.5, p=0.5): | |||||
super(RandPers, self).__init__() | |||||
self.distortion_scale = distortion_scale | |||||
self.prob = p | |||||
def forward(self, input, aug_index=None): | |||||
_device = input.device | |||||
_, _, H, W = input.size() | |||||
if aug_index == 0: | |||||
output = input | |||||
else: | |||||
output = transforms.RandomPerspective(distortion_scale=self.distortion_scale, p=self.prob)(input) | |||||
return output | |||||
class GaussBlur(nn.Module): | |||||
def __init__(self, max_range = 4, kernel_size=3, sigma=(0.1, 2.0)): | |||||
super(GaussBlur, self).__init__() | |||||
self.max_range = max_range | |||||
self.prob = 0.5 | |||||
self.sigma = sigma | |||||
self.kernel_size = kernel_size | |||||
def forward(self, input, aug_index=None): | |||||
_device = input.device | |||||
_, _, H, W = input.size() | |||||
if aug_index is None: | |||||
aug_index = np.random.randint(4) | |||||
output = transforms.GaussianBlur(kernel_size=13, sigma=abs(aug_index)+1)(input) | |||||
_prob = input.new_full((input.size(0),), self.prob) | |||||
_mask = torch.bernoulli(_prob).view(-1, 1, 1, 1) | |||||
output = _mask * input + (1-_mask) * output | |||||
else: | |||||
if aug_index == 0: | |||||
output = input | |||||
else: | |||||
output = transforms.GaussianBlur(kernel_size=self.kernel_size, sigma=self.sigma)(input) | |||||
return output | |||||
class GaussNoise(nn.Module): | |||||
def __init__(self, mean = 0, std = 1): | |||||
super(GaussNoise, self).__init__() | |||||
self.mean = mean | |||||
self.std = std | |||||
def forward(self, input, aug_index=None): | |||||
_device = input.device | |||||
_, _, H, W = input.size() | |||||
if aug_index == 0: | |||||
output = input | |||||
else: | |||||
output = input + (torch.randn(input.size()) * self.std + self.mean).to(_device) | |||||
return output | |||||
class BlurRandpers(nn.Module): | |||||
def __init__(self, max_range=2, kernel_size=3, sigma=(10, 20), distortion_scale=0.6, p=1): | |||||
super(BlurRandpers, self).__init__() | |||||
self.max_range = max_range | |||||
self.sigma = sigma | |||||
self.kernel_size = kernel_size | |||||
self.distortion_scale = distortion_scale | |||||
self.p = p | |||||
self.gauss = GaussBlur(kernel_size=self.kernel_size, sigma=self.sigma) | |||||
self.randpers = RandPers(distortion_scale=self.distortion_scale, p=self.p) | |||||
def forward(self, input, aug_index=None): | |||||
output = self.gauss.forward(input=input, aug_index=aug_index) | |||||
output = self.randpers.forward(input=output, aug_index=aug_index) | |||||
return output | |||||
class BlurSharpness(nn.Module): | |||||
def __init__(self, max_range=2, kernel_size=3, sigma=(10, 20), sharpness_factor=0.6, p=1): | |||||
super(BlurSharpness, self).__init__() | |||||
self.max_range = max_range | |||||
self.sigma = sigma | |||||
self.kernel_size = kernel_size | |||||
self.sharpness_factor = sharpness_factor | |||||
self.p = p | |||||
self.gauss = GaussBlur(kernel_size=self.kernel_size, sigma=self.sigma) | |||||
self.sharp = RandomAdjustSharpness(sharpness_factor=self.sharpness_factor, p=self.p) | |||||
def forward(self, input, aug_index=None): | |||||
output = self.gauss.forward(input=input, aug_index=aug_index) | |||||
output = self.sharp.forward(input=output, aug_index=aug_index) | |||||
return output | |||||
class RandpersSharpness(nn.Module): | |||||
def __init__(self, max_range=2, distortion_scale=0.6, p=1, sharpness_factor=0.6): | |||||
super(RandpersSharpness, self).__init__() | |||||
self.max_range = max_range | |||||
self.distortion_scale = distortion_scale | |||||
self.p = p | |||||
self.sharpness_factor = sharpness_factor | |||||
self.randpers = RandPers(distortion_scale=self.distortion_scale, p=self.p) | |||||
self.sharp = RandomAdjustSharpness(sharpness_factor=self.sharpness_factor, p=self.p) | |||||
def forward(self, input, aug_index=None): | |||||
output = self.randpers.forward(input=input, aug_index=aug_index) | |||||
output = self.sharp.forward(input=output, aug_index=aug_index) | |||||
return output | |||||
class BlurRandpersSharpness(nn.Module): | |||||
def __init__(self, max_range=2, kernel_size=3, sigma=(10, 20), distortion_scale=0.6, p=1, sharpness_factor=0.6): | |||||
super(BlurRandpersSharpness, self).__init__() | |||||
self.max_range = max_range | |||||
self.sigma = sigma | |||||
self.kernel_size = kernel_size | |||||
self.distortion_scale = distortion_scale | |||||
self.p = p | |||||
self.sharpness_factor = sharpness_factor | |||||
self.gauss = GaussBlur(kernel_size=self.kernel_size, sigma=self.sigma) | |||||
self.randpers = RandPers(distortion_scale=self.distortion_scale, p=self.p) | |||||
self.sharp = RandomAdjustSharpness(sharpness_factor=self.sharpness_factor, p=self.p) | |||||
def forward(self, input, aug_index=None): | |||||
output = self.gauss.forward(input=input, aug_index=aug_index) | |||||
output = self.randpers.forward(input=output, aug_index=aug_index) | |||||
output = self.sharp.forward(input=output, aug_index=aug_index) | |||||
return output | |||||
class FourCrop(nn.Module): | |||||
def __init__(self, max_range = 4): | |||||
super(FourCrop, self).__init__() | |||||
self.max_range = max_range | |||||
self.prob = 0.5 | |||||
def forward(self, inputs): | |||||
outputs = inputs | |||||
for i in range(8): | |||||
outputs[i] = self._crop(inputs.size(), inputs[i], i) | |||||
return outputs | |||||
def _crop(self, size, input, i): | |||||
_, _, H, W = size | |||||
h_mid = int(H / 2) | |||||
w_mid = int(W / 2) | |||||
if i == 0 or i == 4: | |||||
corner = input[:, 0:h_mid, 0:w_mid] | |||||
elif i == 1 or i == 5: | |||||
corner = input[:, 0:h_mid, w_mid:] | |||||
elif i == 2 or i == 6: | |||||
corner = input[:, h_mid:, 0:w_mid] | |||||
elif i == 3 or i == 7: | |||||
corner = input[:, h_mid:, w_mid:] | |||||
else: | |||||
corner = input | |||||
corner = transforms.Resize(size=2*h_mid)(corner) | |||||
return corner | |||||
class CutPerm(nn.Module): | |||||
def __init__(self, max_range = 4): | |||||
super(CutPerm, self).__init__() | |||||
self.max_range = max_range | |||||
self.prob = 0.5 | |||||
def forward(self, input, aug_index=None): | |||||
_device = input.device | |||||
_, _, H, W = input.size() | |||||
if aug_index is None: | |||||
aug_index = np.random.randint(4) | |||||
output = self._cutperm(input, aug_index) | |||||
_prob = input.new_full((input.size(0),), self.prob) | |||||
_mask = torch.bernoulli(_prob).view(-1, 1, 1, 1) | |||||
output = _mask * input + (1 - _mask) * output | |||||
else: | |||||
aug_index = aug_index % self.max_range | |||||
output = self._cutperm(input, aug_index) | |||||
return output | |||||
def _cutperm(self, inputs, aug_index): | |||||
_, _, H, W = inputs.size() | |||||
h_mid = int(H / 2) | |||||
w_mid = int(W / 2) | |||||
jigsaw_h = aug_index // 2 | |||||
jigsaw_v = aug_index % 2 | |||||
if jigsaw_h == 1: | |||||
inputs = torch.cat((inputs[:, :, h_mid:, :], inputs[:, :, 0:h_mid, :]), dim=2) | |||||
if jigsaw_v == 1: | |||||
inputs = torch.cat((inputs[:, :, :, w_mid:], inputs[:, :, :, 0:w_mid]), dim=3) | |||||
return inputs | |||||
def assemble(a, b, c, d): | |||||
ab = torch.cat((a, b), dim=2) | |||||
cd = torch.cat((c, d), dim=2) | |||||
output = torch.cat((ab, cd), dim=3) | |||||
return output | |||||
def quarter(inputs): | |||||
_, _, H, W = inputs.size() | |||||
h_mid = int(H / 2) | |||||
w_mid = int(W / 2) | |||||
quarters = [] | |||||
quarters.append(inputs[:, :, 0:h_mid, 0:w_mid]) | |||||
quarters.append(inputs[:, :, 0:h_mid, w_mid:]) | |||||
quarters.append(inputs[:, :, h_mid:, 0:w_mid]) | |||||
quarters.append(inputs[:, :, h_mid:, w_mid:]) | |||||
return quarters | |||||
class HorizontalFlipLayer(nn.Module): | |||||
def __init__(self): | |||||
""" | |||||
img_size : (int, int, int) | |||||
Height and width must be powers of 2. E.g. (32, 32, 1) or | |||||
(64, 128, 3). Last number indicates number of channels, e.g. 1 for | |||||
grayscale or 3 for RGB | |||||
""" | |||||
super(HorizontalFlipLayer, self).__init__() | |||||
_eye = torch.eye(2, 3) | |||||
self.register_buffer('_eye', _eye) | |||||
def forward(self, inputs): | |||||
_device = inputs.device | |||||
N = inputs.size(0) | |||||
_theta = self._eye.repeat(N, 1, 1) | |||||
r_sign = torch.bernoulli(torch.ones(N, device=_device) * 0.5) * 2 - 1 | |||||
_theta[:, 0, 0] = r_sign | |||||
grid = F.affine_grid(_theta, inputs.size(), **kwargs).to(_device) | |||||
inputs = F.grid_sample(inputs, grid, padding_mode='reflection', **kwargs) | |||||
return inputs | |||||
class RandomColorGrayLayer(nn.Module): | |||||
def __init__(self, p): | |||||
super(RandomColorGrayLayer, self).__init__() | |||||
self.prob = p | |||||
_weight = torch.tensor([[0.299, 0.587, 0.114]]) | |||||
self.register_buffer('_weight', _weight.view(1, 3, 1, 1)) | |||||
def forward(self, inputs, aug_index=None): | |||||
if aug_index == 0: | |||||
return inputs | |||||
l = F.conv2d(inputs, self._weight) | |||||
gray = torch.cat([l, l, l], dim=1) | |||||
if aug_index is None: | |||||
_prob = inputs.new_full((inputs.size(0),), self.prob) | |||||
_mask = torch.bernoulli(_prob).view(-1, 1, 1, 1) | |||||
gray = inputs * (1 - _mask) + gray * _mask | |||||
return gray | |||||
class ColorJitterLayer(nn.Module): | |||||
def __init__(self, p, brightness, contrast, saturation, hue): | |||||
super(ColorJitterLayer, self).__init__() | |||||
self.prob = p | |||||
self.brightness = self._check_input(brightness, 'brightness') | |||||
self.contrast = self._check_input(contrast, 'contrast') | |||||
self.saturation = self._check_input(saturation, 'saturation') | |||||
self.hue = self._check_input(hue, 'hue', center=0, bound=(-0.5, 0.5), | |||||
clip_first_on_zero=False) | |||||
def _check_input(self, value, name, center=1, bound=(0, float('inf')), clip_first_on_zero=True): | |||||
if isinstance(value, numbers.Number): | |||||
if value < 0: | |||||
raise ValueError("If {} is a single number, it must be non negative.".format(name)) | |||||
value = [center - value, center + value] | |||||
if clip_first_on_zero: | |||||
value[0] = max(value[0], 0) | |||||
elif isinstance(value, (tuple, list)) and len(value) == 2: | |||||
if not bound[0] <= value[0] <= value[1] <= bound[1]: | |||||
raise ValueError("{} values should be between {}".format(name, bound)) | |||||
else: | |||||
raise TypeError("{} should be a single number or a list/tuple with lenght 2.".format(name)) | |||||
# if value is 0 or (1., 1.) for brightness/contrast/saturation | |||||
# or (0., 0.) for hue, do nothing | |||||
if value[0] == value[1] == center: | |||||
value = None | |||||
return value | |||||
def adjust_contrast(self, x): | |||||
if self.contrast: | |||||
factor = x.new_empty(x.size(0), 1, 1, 1).uniform_(*self.contrast) | |||||
means = torch.mean(x, dim=[2, 3], keepdim=True) | |||||
x = (x - means) * factor + means | |||||
return torch.clamp(x, 0, 1) | |||||
def adjust_hsv(self, x): | |||||
f_h = x.new_zeros(x.size(0), 1, 1) | |||||
f_s = x.new_ones(x.size(0), 1, 1) | |||||
f_v = x.new_ones(x.size(0), 1, 1) | |||||
if self.hue: | |||||
f_h.uniform_(*self.hue) | |||||
if self.saturation: | |||||
f_s = f_s.uniform_(*self.saturation) | |||||
if self.brightness: | |||||
f_v = f_v.uniform_(*self.brightness) | |||||
return RandomHSVFunction.apply(x, f_h, f_s, f_v) | |||||
def transform(self, inputs): | |||||
# Shuffle transform | |||||
if np.random.rand() > 0.5: | |||||
transforms = [self.adjust_contrast, self.adjust_hsv] | |||||
else: | |||||
transforms = [self.adjust_hsv, self.adjust_contrast] | |||||
for t in transforms: | |||||
inputs = t(inputs) | |||||
return inputs | |||||
def forward(self, inputs): | |||||
_prob = inputs.new_full((inputs.size(0),), self.prob) | |||||
_mask = torch.bernoulli(_prob).view(-1, 1, 1, 1) | |||||
return inputs * (1 - _mask) + self.transform(inputs) * _mask | |||||
class RandomHSVFunction(Function): | |||||
@staticmethod | |||||
def forward(ctx, x, f_h, f_s, f_v): | |||||
# ctx is a context object that can be used to stash information | |||||
# for backward computation | |||||
x = rgb2hsv(x) | |||||
h = x[:, 0, :, :] | |||||
h += (f_h * 255. / 360.) | |||||
h = (h % 1) | |||||
x[:, 0, :, :] = h | |||||
x[:, 1, :, :] = x[:, 1, :, :] * f_s | |||||
x[:, 2, :, :] = x[:, 2, :, :] * f_v | |||||
x = torch.clamp(x, 0, 1) | |||||
x = hsv2rgb(x) | |||||
return x | |||||
@staticmethod | |||||
def backward(ctx, grad_output): | |||||
# We return as many input gradients as there were arguments. | |||||
# Gradients of non-Tensor arguments to forward must be None. | |||||
grad_input = None | |||||
if ctx.needs_input_grad[0]: | |||||
grad_input = grad_output.clone() | |||||
return grad_input, None, None, None | |||||
class NormalizeLayer(nn.Module): | |||||
""" | |||||
In order to certify radii in original coordinates rather than standardized coordinates, we | |||||
add the Gaussian noise _before_ standardizing, which is why we have standardization be the first | |||||
layer of the classifier rather than as a part of preprocessing as is typical. | |||||
""" | |||||
def __init__(self): | |||||
super(NormalizeLayer, self).__init__() | |||||
def forward(self, inputs): | |||||
return (inputs - 0.5) / 0.5 | |||||
from utils.utils import Logger | |||||
from utils.utils import save_checkpoint | |||||
from utils.utils import save_linear_checkpoint | |||||
from common.train import * | |||||
from evals import test_classifier | |||||
if 'sup' in P.mode: | |||||
from training.sup import setup | |||||
else: | |||||
from training.unsup import setup | |||||
train, fname = setup(P.mode, P) | |||||
logger = Logger(fname, ask=not resume, local_rank=P.local_rank) | |||||
logger.log(P) | |||||
logger.log(model) | |||||
if P.multi_gpu: | |||||
linear = model.module.linear | |||||
else: | |||||
linear = model.linear | |||||
linear_optim = torch.optim.Adam(linear.parameters(), lr=1e-3, betas=(.9, .999), weight_decay=P.weight_decay) | |||||
# Run experiments | |||||
for epoch in range(start_epoch, P.epochs + 1): | |||||
logger.log_dirname(f"Epoch {epoch}") | |||||
model.train() | |||||
if P.multi_gpu: | |||||
train_sampler.set_epoch(epoch) | |||||
kwargs = {} | |||||
kwargs['linear'] = linear | |||||
kwargs['linear_optim'] = linear_optim | |||||
kwargs['simclr_aug'] = simclr_aug | |||||
train(P, epoch, model, criterion, optimizer, scheduler_warmup, train_loader, logger=logger, **kwargs) | |||||
model.eval() | |||||
if epoch % P.save_step == 0 and P.local_rank == 0: | |||||
if P.multi_gpu: | |||||
save_states = model.module.state_dict() | |||||
else: | |||||
save_states = model.state_dict() | |||||
save_checkpoint(epoch, save_states, optimizer.state_dict(), logger.logdir) | |||||
save_linear_checkpoint(linear_optim.state_dict(), logger.logdir) | |||||
if epoch % P.error_step == 0 and ('sup' in P.mode): | |||||
error = test_classifier(P, model, test_loader, epoch, logger=logger) | |||||
is_best = (best > error) | |||||
if is_best: | |||||
best = error | |||||
logger.scalar_summary('eval/best_error', best, epoch) | |||||
logger.log('[Epoch %3d] [Test %5.2f] [Best %5.2f]' % (epoch, error, best)) |
import torch | |||||
import torch.nn as nn | |||||
import torch.nn.functional as F | |||||
def update_learning_rate(P, optimizer, cur_epoch, n, n_total): | |||||
cur_epoch = cur_epoch - 1 | |||||
lr = P.lr_init | |||||
if P.optimizer == 'sgd' or 'lars': | |||||
DECAY_RATIO = 0.1 | |||||
elif P.optimizer == 'adam': | |||||
DECAY_RATIO = 0.3 | |||||
else: | |||||
raise NotImplementedError() | |||||
if P.warmup > 0: | |||||
cur_iter = cur_epoch * n_total + n | |||||
if cur_iter <= P.warmup: | |||||
lr *= cur_iter / float(P.warmup) | |||||
if cur_epoch >= 0.5 * P.epochs: | |||||
lr *= DECAY_RATIO | |||||
if cur_epoch >= 0.75 * P.epochs: | |||||
lr *= DECAY_RATIO | |||||
for param_group in optimizer.param_groups: | |||||
param_group['lr'] = lr | |||||
return lr | |||||
def _cross_entropy(input, targets, reduction='mean'): | |||||
targets_prob = F.softmax(targets, dim=1) | |||||
xent = (-targets_prob * F.log_softmax(input, dim=1)).sum(1) | |||||
if reduction == 'sum': | |||||
return xent.sum() | |||||
elif reduction == 'mean': | |||||
return xent.mean() | |||||
elif reduction == 'none': | |||||
return xent | |||||
else: | |||||
raise NotImplementedError() | |||||
def _entropy(input, reduction='mean'): | |||||
return _cross_entropy(input, input, reduction) | |||||
def cross_entropy_soft(input, targets, reduction='mean'): | |||||
targets_prob = F.softmax(targets, dim=1) | |||||
xent = (-targets_prob * F.log_softmax(input, dim=1)).sum(1) | |||||
if reduction == 'sum': | |||||
return xent.sum() | |||||
elif reduction == 'mean': | |||||
return xent.mean() | |||||
elif reduction == 'none': | |||||
return xent | |||||
else: | |||||
raise NotImplementedError() | |||||
def kl_div(input, targets, reduction='batchmean'): | |||||
return F.kl_div(F.log_softmax(input, dim=1), F.softmax(targets, dim=1), | |||||
reduction=reduction) | |||||
def target_nll_loss(inputs, targets, reduction='none'): | |||||
inputs_t = -F.nll_loss(inputs, targets, reduction='none') | |||||
logit_diff = inputs - inputs_t.view(-1, 1) | |||||
logit_diff = logit_diff.scatter(1, targets.view(-1, 1), -1e8) | |||||
diff_max = logit_diff.max(1)[0] | |||||
if reduction == 'sum': | |||||
return diff_max.sum() | |||||
elif reduction == 'mean': | |||||
return diff_max.mean() | |||||
elif reduction == 'none': | |||||
return diff_max | |||||
else: | |||||
raise NotImplementedError() | |||||
def target_nll_c(inputs, targets, reduction='none'): | |||||
conf = torch.softmax(inputs, dim=1) | |||||
conf_t = -F.nll_loss(conf, targets, reduction='none') | |||||
conf_diff = conf - conf_t.view(-1, 1) | |||||
conf_diff = conf_diff.scatter(1, targets.view(-1, 1), -1) | |||||
diff_max = conf_diff.max(1)[0] | |||||
if reduction == 'sum': | |||||
return diff_max.sum() | |||||
elif reduction == 'mean': | |||||
return diff_max.mean() | |||||
elif reduction == 'none': | |||||
return diff_max | |||||
else: | |||||
raise NotImplementedError() |
import torch | |||||
import torch.distributed as dist | |||||
import diffdist.functional as distops | |||||
def get_similarity_matrix(outputs, chunk=2, multi_gpu=False): | |||||
''' | |||||
Compute similarity matrix | |||||
- outputs: (B', d) tensor for B' = B * chunk | |||||
- sim_matrix: (B', B') tensor | |||||
''' | |||||
if multi_gpu: | |||||
outputs_gathered = [] | |||||
for out in outputs.chunk(chunk): | |||||
gather_t = [torch.empty_like(out) for _ in range(dist.get_world_size())] | |||||
gather_t = torch.cat(distops.all_gather(gather_t, out)) | |||||
outputs_gathered.append(gather_t) | |||||
outputs = torch.cat(outputs_gathered) | |||||
sim_matrix = torch.mm(outputs, outputs.t()) # (B', d), (d, B') -> (B', B') | |||||
return sim_matrix | |||||
def NT_xent(sim_matrix, temperature=0.5, chunk=2, eps=1e-8): | |||||
''' | |||||
Compute NT_xent loss | |||||
- sim_matrix: (B', B') tensor for B' = B * chunk (first 2B are pos samples) | |||||
''' | |||||
device = sim_matrix.device | |||||
B = sim_matrix.size(0) // chunk # B = B' / chunk | |||||
eye = torch.eye(B * chunk).to(device) # (B', B') | |||||
sim_matrix = torch.exp(sim_matrix / temperature) * (1 - eye) # remove diagonal | |||||
denom = torch.sum(sim_matrix, dim=1, keepdim=True) | |||||
sim_matrix = -torch.log(sim_matrix / (denom + eps) + eps) # loss matrix | |||||
loss = torch.sum(sim_matrix[:B, B:].diag() + sim_matrix[B:, :B].diag()) / (2 * B) | |||||
return loss | |||||
def Supervised_NT_xent(sim_matrix, labels, temperature=0.5, chunk=2, eps=1e-8, multi_gpu=False): | |||||
''' | |||||
Compute NT_xent loss | |||||
- sim_matrix: (B', B') tensor for B' = B * chunk (first 2B are pos samples) | |||||
''' | |||||
device = sim_matrix.device | |||||
if multi_gpu: | |||||
gather_t = [torch.empty_like(labels) for _ in range(dist.get_world_size())] | |||||
labels = torch.cat(distops.all_gather(gather_t, labels)) | |||||
labels = labels.repeat(2) | |||||
logits_max, _ = torch.max(sim_matrix, dim=1, keepdim=True) | |||||
sim_matrix = sim_matrix - logits_max.detach() | |||||
B = sim_matrix.size(0) // chunk # B = B' / chunk | |||||
eye = torch.eye(B * chunk).to(device) # (B', B') | |||||
sim_matrix = torch.exp(sim_matrix / temperature) * (1 - eye) # remove diagonal | |||||
denom = torch.sum(sim_matrix, dim=1, keepdim=True) | |||||
sim_matrix = -torch.log(sim_matrix / (denom + eps) + eps) # loss matrix | |||||
labels = labels.contiguous().view(-1, 1) | |||||
Mask = torch.eq(labels, labels.t()).float().to(device) | |||||
#Mask = eye * torch.stack([labels == labels[i] for i in range(labels.size(0))]).float().to(device) | |||||
Mask = Mask / (Mask.sum(dim=1, keepdim=True) + eps) | |||||
loss = torch.sum(Mask * sim_matrix) / (2 * B) | |||||
return loss | |||||
from torch.optim.lr_scheduler import _LRScheduler | |||||
from torch.optim.lr_scheduler import ReduceLROnPlateau | |||||
class GradualWarmupScheduler(_LRScheduler): | |||||
""" Gradually warm-up(increasing) learning rate in optimizer. | |||||
Proposed in 'Accurate, Large Minibatch SGD: Training ImageNet in 1 Hour'. | |||||
Args: | |||||
optimizer (Optimizer): Wrapped optimizer. | |||||
multiplier: target learning rate = base lr * multiplier if multiplier > 1.0. if multiplier = 1.0, lr starts from 0 and ends up with the base_lr. | |||||
total_epoch: target learning rate is reached at total_epoch, gradually | |||||
after_scheduler: after target_epoch, use this scheduler(eg. ReduceLROnPlateau) | |||||
""" | |||||
def __init__(self, optimizer, multiplier, total_epoch, after_scheduler=None): | |||||
self.multiplier = multiplier | |||||
if self.multiplier < 1.: | |||||
raise ValueError('multiplier should be greater thant or equal to 1.') | |||||
self.total_epoch = total_epoch | |||||
self.after_scheduler = after_scheduler | |||||
self.finished = False | |||||
super(GradualWarmupScheduler, self).__init__(optimizer) | |||||
def get_lr(self): | |||||
if self.last_epoch > self.total_epoch: | |||||
if self.after_scheduler: | |||||
if not self.finished: | |||||
self.after_scheduler.base_lrs = [base_lr * self.multiplier for base_lr in self.base_lrs] | |||||
self.finished = True | |||||
return self.after_scheduler.get_lr() | |||||
return [base_lr * self.multiplier for base_lr in self.base_lrs] | |||||
if self.multiplier == 1.0: | |||||
return [base_lr * (float(self.last_epoch) / self.total_epoch) for base_lr in self.base_lrs] | |||||
else: | |||||
return [base_lr * ((self.multiplier - 1.) * self.last_epoch / self.total_epoch + 1.) for base_lr in self.base_lrs] | |||||
def step_ReduceLROnPlateau(self, metrics, epoch=None): | |||||
if epoch is None: | |||||
epoch = self.last_epoch + 1 | |||||
self.last_epoch = epoch if epoch != 0 else 1 # ReduceLROnPlateau is called at the end of epoch, whereas others are called at beginning | |||||
if self.last_epoch <= self.total_epoch: | |||||
warmup_lr = [base_lr * ((self.multiplier - 1.) * self.last_epoch / self.total_epoch + 1.) for base_lr in self.base_lrs] | |||||
for param_group, lr in zip(self.optimizer.param_groups, warmup_lr): | |||||
param_group['lr'] = lr | |||||
else: | |||||
if epoch is None: | |||||
self.after_scheduler.step(metrics, None) | |||||
else: | |||||
self.after_scheduler.step(metrics, epoch - self.total_epoch) | |||||
def step(self, epoch=None, metrics=None): | |||||
if type(self.after_scheduler) != ReduceLROnPlateau: | |||||
if self.finished and self.after_scheduler: | |||||
if epoch is None: | |||||
self.after_scheduler.step(None) | |||||
else: | |||||
self.after_scheduler.step(epoch - self.total_epoch) | |||||
else: | |||||
return super(GradualWarmupScheduler, self).step(epoch) | |||||
else: | |||||
self.step_ReduceLROnPlateau(metrics, epoch) |
def setup(mode, P): | |||||
fname = f'{P.dataset}_{P.model}_{mode}_{P.res}' | |||||
if mode == 'sup_linear': | |||||
from .sup_linear import train | |||||
elif mode == 'sup_CSI_linear': | |||||
from .sup_CSI_linear import train | |||||
elif mode == 'sup_simclr': | |||||
from .sup_simclr import train | |||||
elif mode == 'sup_simclr_CSI': | |||||
assert P.batch_size == 32 | |||||
# currently only support rotation | |||||
from .sup_simclr_CSI import train | |||||
else: | |||||
raise NotImplementedError() | |||||
if P.suffix is not None: | |||||
fname += f'_{P.suffix}' | |||||
return train, fname | |||||
def update_comp_loss(loss_dict, loss_in, loss_out, loss_diff, batch_size): | |||||
loss_dict['pos'].update(loss_in, batch_size) | |||||
loss_dict['neg'].update(loss_out, batch_size) | |||||
loss_dict['diff'].update(loss_diff, batch_size) | |||||
def summary_comp_loss(logger, tag, loss_dict, epoch): | |||||
logger.scalar_summary(f'{tag}/pos', loss_dict['pos'].average, epoch) | |||||
logger.scalar_summary(f'{tag}/neg', loss_dict['neg'].average, epoch) | |||||
logger.scalar_summary(f'{tag}', loss_dict['diff'].average, epoch) | |||||
import time | |||||
import torch.optim | |||||
import torch.optim.lr_scheduler as lr_scheduler | |||||
import models.transform_layers as TL | |||||
from utils.utils import AverageMeter, normalize | |||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu") | |||||
hflip = TL.HorizontalFlipLayer().to(device) | |||||
def train(P, epoch, model, criterion, optimizer, scheduler, loader, logger=None, | |||||
simclr_aug=None, linear=None, linear_optim=None): | |||||
if P.multi_gpu: | |||||
rotation_linear = model.module.shift_cls_layer | |||||
joint_linear = model.module.joint_distribution_layer | |||||
else: | |||||
rotation_linear = model.shift_cls_layer | |||||
joint_linear = model.joint_distribution_layer | |||||
if epoch == 1: | |||||
# define optimizer and save in P (argument) | |||||
milestones = [int(0.6 * P.epochs), int(0.75 * P.epochs), int(0.9 * P.epochs)] | |||||
linear_optim = torch.optim.SGD(linear.parameters(), | |||||
lr=1e-1, weight_decay=P.weight_decay) | |||||
P.linear_optim = linear_optim | |||||
P.linear_scheduler = lr_scheduler.MultiStepLR(P.linear_optim, gamma=0.1, milestones=milestones) | |||||
rotation_linear_optim = torch.optim.SGD(rotation_linear.parameters(), | |||||
lr=1e-1, weight_decay=P.weight_decay) | |||||
P.rotation_linear_optim = rotation_linear_optim | |||||
P.rot_scheduler = lr_scheduler.MultiStepLR(P.rotation_linear_optim, gamma=0.1, milestones=milestones) | |||||
joint_linear_optim = torch.optim.SGD(joint_linear.parameters(), | |||||
lr=1e-1, weight_decay=P.weight_decay) | |||||
P.joint_linear_optim = joint_linear_optim | |||||
P.joint_scheduler = lr_scheduler.MultiStepLR(P.joint_linear_optim, gamma=0.1, milestones=milestones) | |||||
if logger is None: | |||||
log_ = print | |||||
else: | |||||
log_ = logger.log | |||||
batch_time = AverageMeter() | |||||
data_time = AverageMeter() | |||||
losses = dict() | |||||
losses['cls'] = AverageMeter() | |||||
losses['rot'] = AverageMeter() | |||||
check = time.time() | |||||
for n, (images, labels) in enumerate(loader): | |||||
model.eval() | |||||
count = n * P.n_gpus # number of trained samples | |||||
data_time.update(time.time() - check) | |||||
check = time.time() | |||||
### SimCLR loss ### | |||||
if P.dataset != 'imagenet': | |||||
batch_size = images.size(0) | |||||
images = images.to(device) | |||||
images = hflip(images) # 2B with hflip | |||||
else: | |||||
batch_size = images[0].size(0) | |||||
images = images[0].to(device) | |||||
labels = labels.to(device) | |||||
images = torch.cat([torch.rot90(images, rot, (2, 3)) for rot in range(4)]) # 4B | |||||
rot_labels = torch.cat([torch.ones_like(labels) * k for k in range(4)], 0) # B -> 4B | |||||
joint_labels = torch.cat([labels + P.n_classes * i for i in range(4)], dim=0) | |||||
images = simclr_aug(images) # simclr augmentation | |||||
_, outputs_aux = model(images, penultimate=True) | |||||
penultimate = outputs_aux['penultimate'].detach() | |||||
outputs = linear(penultimate[0:batch_size]) # only use 0 degree samples for linear eval | |||||
outputs_rot = rotation_linear(penultimate) | |||||
outputs_joint = joint_linear(penultimate) | |||||
loss_ce = criterion(outputs, labels) | |||||
loss_rot = criterion(outputs_rot, rot_labels) | |||||
loss_joint = criterion(outputs_joint, joint_labels) | |||||
### CE loss ### | |||||
P.linear_optim.zero_grad() | |||||
loss_ce.backward() | |||||
P.linear_optim.step() | |||||
### Rot loss ### | |||||
P.rotation_linear_optim.zero_grad() | |||||
loss_rot.backward() | |||||
P.rotation_linear_optim.step() | |||||
### Joint loss ### | |||||
P.joint_linear_optim.zero_grad() | |||||
loss_joint.backward() | |||||
P.joint_linear_optim.step() | |||||
### optimizer learning rate ### | |||||
lr = P.linear_optim.param_groups[0]['lr'] | |||||
batch_time.update(time.time() - check) | |||||
### Log losses ### | |||||
losses['cls'].update(loss_ce.item(), batch_size) | |||||
losses['rot'].update(loss_rot.item(), batch_size) | |||||
if count % 50 == 0: | |||||
log_('[Epoch %3d; %3d] [Time %.3f] [Data %.3f] [LR %.5f]\n' | |||||
'[LossC %f] [LossR %f]' % | |||||
(epoch, count, batch_time.value, data_time.value, lr, | |||||
losses['cls'].value, losses['rot'].value)) | |||||
check = time.time() | |||||
P.linear_scheduler.step() | |||||
P.rot_scheduler.step() | |||||
P.joint_scheduler.step() | |||||
log_('[DONE] [Time %.3f] [Data %.3f] [LossC %f] [LossR %f]' % | |||||
(batch_time.average, data_time.average, | |||||
losses['cls'].average, losses['rot'].average)) | |||||
if logger is not None: | |||||
logger.scalar_summary('train/loss_cls', losses['cls'].average, epoch) | |||||
logger.scalar_summary('train/loss_rot', losses['rot'].average, epoch) | |||||
logger.scalar_summary('train/batch_time', batch_time.average, epoch) |
import time | |||||
import torch.optim | |||||
import torch.optim.lr_scheduler as lr_scheduler | |||||
import models.transform_layers as TL | |||||
from utils.utils import AverageMeter, normalize | |||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu") | |||||
hflip = TL.HorizontalFlipLayer().to(device) | |||||
def train(P, epoch, model, criterion, optimizer, scheduler, loader, logger=None, | |||||
simclr_aug=None, linear=None, linear_optim=None): | |||||
if epoch == 1: | |||||
# define optimizer and save in P (argument) | |||||
milestones = [int(0.6 * P.epochs), int(0.75 * P.epochs), int(0.9 * P.epochs)] | |||||
linear_optim = torch.optim.SGD(linear.parameters(), | |||||
lr=1e-1, weight_decay=P.weight_decay) | |||||
P.linear_optim = linear_optim | |||||
P.linear_scheduler = lr_scheduler.MultiStepLR(P.linear_optim, gamma=0.1, milestones=milestones) | |||||
if logger is None: | |||||
log_ = print | |||||
else: | |||||
log_ = logger.log | |||||
batch_time = AverageMeter() | |||||
data_time = AverageMeter() | |||||
losses = dict() | |||||
losses['cls'] = AverageMeter() | |||||
check = time.time() | |||||
for n, (images, labels) in enumerate(loader): | |||||
model.eval() | |||||
count = n * P.n_gpus # number of trained samples | |||||
data_time.update(time.time() - check) | |||||
check = time.time() | |||||
### SimCLR loss ### | |||||
if P.dataset != 'imagenet': | |||||
batch_size = images.size(0) | |||||
images = images.to(device) | |||||
images = hflip(images) # 2B with hflip | |||||
else: | |||||
batch_size = images[0].size(0) | |||||
images = images[0].to(device) | |||||
labels = labels.to(device) | |||||
images = simclr_aug(images) # simclr augmentation | |||||
_, outputs_aux = model(images, penultimate=True) | |||||
penultimate = outputs_aux['penultimate'].detach() | |||||
outputs = linear(penultimate[0:batch_size]) # only use 0 degree samples for linear eval | |||||
loss_ce = criterion(outputs, labels) | |||||
### CE loss ### | |||||
P.linear_optim.zero_grad() | |||||
loss_ce.backward() | |||||
P.linear_optim.step() | |||||
### optimizer learning rate ### | |||||
lr = P.linear_optim.param_groups[0]['lr'] | |||||
batch_time.update(time.time() - check) | |||||
### Log losses ### | |||||
losses['cls'].update(loss_ce.item(), batch_size) | |||||
if count % 50 == 0: | |||||
log_('[Epoch %3d; %3d] [Time %.3f] [Data %.3f] [LR %.5f]\n' | |||||
'[LossC %f]' % | |||||
(epoch, count, batch_time.value, data_time.value, lr, | |||||
losses['cls'].value, )) | |||||
check = time.time() | |||||
P.linear_scheduler.step() | |||||
log_('[DONE] [Time %.3f] [Data %.3f] [LossC %f]' % | |||||
(batch_time.average, data_time.average, | |||||
losses['cls'].average)) | |||||
if logger is not None: | |||||
logger.scalar_summary('train/loss_cls', losses['cls'].average, epoch) | |||||
logger.scalar_summary('train/batch_time', batch_time.average, epoch) |
import time | |||||
import torch.optim | |||||
import models.transform_layers as TL | |||||
from training.contrastive_loss import get_similarity_matrix, Supervised_NT_xent | |||||
from utils.utils import AverageMeter, normalize | |||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu") | |||||
hflip = TL.HorizontalFlipLayer().to(device) | |||||
def train(P, epoch, model, criterion, optimizer, scheduler, loader, logger=None, | |||||
simclr_aug=None, linear=None, linear_optim=None): | |||||
assert simclr_aug is not None | |||||
assert P.sim_lambda == 1.0 | |||||
if logger is None: | |||||
log_ = print | |||||
else: | |||||
log_ = logger.log | |||||
batch_time = AverageMeter() | |||||
data_time = AverageMeter() | |||||
losses = dict() | |||||
losses['cls'] = AverageMeter() | |||||
losses['sim'] = AverageMeter() | |||||
losses['simnorm'] = AverageMeter() | |||||
check = time.time() | |||||
for n, (images, labels) in enumerate(loader): | |||||
model.train() | |||||
count = n * P.n_gpus # number of trained samples | |||||
data_time.update(time.time() - check) | |||||
check = time.time() | |||||
### SimCLR loss ### | |||||
if P.dataset != 'imagenet' and P.dataset != 'CNMC' and P.dataset != 'CNMC_grayscale': | |||||
batch_size = images.size(0) | |||||
images = images.to(device) | |||||
images_pair = hflip(images.repeat(2, 1, 1, 1)) # 2B with hflip | |||||
else: | |||||
batch_size = images[0].size(0) | |||||
images1, images2 = images[0].to(device), images[1].to(device) | |||||
images_pair = torch.cat([images1, images2], dim=0) # 2B | |||||
labels = labels.to(device) | |||||
images_pair = simclr_aug(images_pair) # simclr augmentation | |||||
_, outputs_aux = model(images_pair, simclr=True, penultimate=True) | |||||
simclr = normalize(outputs_aux['simclr']) # normalize | |||||
sim_matrix = get_similarity_matrix(simclr, multi_gpu=P.multi_gpu) | |||||
loss_sim = Supervised_NT_xent(sim_matrix, labels=labels, temperature=0.07, multi_gpu=P.multi_gpu) * P.sim_lambda | |||||
### total loss ### | |||||
loss = loss_sim | |||||
optimizer.zero_grad() | |||||
loss.backward() | |||||
optimizer.step() | |||||
scheduler.step(epoch - 1 + n / len(loader)) | |||||
lr = optimizer.param_groups[0]['lr'] | |||||
batch_time.update(time.time() - check) | |||||
### Post-processing stuffs ### | |||||
simclr_norm = outputs_aux['simclr'].norm(dim=1).mean() | |||||
### Linear evaluation ### | |||||
outputs_linear_eval = linear(outputs_aux['penultimate'].detach()) | |||||
loss_linear = criterion(outputs_linear_eval, labels.repeat(2)) | |||||
linear_optim.zero_grad() | |||||
loss_linear.backward() | |||||
linear_optim.step() | |||||
### Log losses ### | |||||
losses['cls'].update(0, batch_size) | |||||
losses['sim'].update(loss_sim.item(), batch_size) | |||||
losses['simnorm'].update(simclr_norm.item(), batch_size) | |||||
if count % 50 == 0: | |||||
log_('[Epoch %3d; %3d] [Time %.3f] [Data %.3f] [LR %.5f]\n' | |||||
'[LossC %f] [LossSim %f] [SimNorm %f]' % | |||||
(epoch, count, batch_time.value, data_time.value, lr, | |||||
losses['cls'].value, losses['sim'].value, losses['simnorm'].value)) | |||||
check = time.time() | |||||
log_('[DONE] [Time %.3f] [Data %.3f] [LossC %f] [LossSim %f] [SimNorm %f]' % | |||||
(batch_time.average, data_time.average, | |||||
losses['cls'].average, losses['sim'].average, losses['simnorm'].average)) | |||||
if logger is not None: | |||||
logger.scalar_summary('train/loss_cls', losses['cls'].average, epoch) | |||||
logger.scalar_summary('train/loss_sim', losses['sim'].average, epoch) | |||||
logger.scalar_summary('train/batch_time', batch_time.average, epoch) | |||||
logger.scalar_summary('train/simclr_norm', losses['simnorm'].average, epoch) |
import time | |||||
import torch.optim | |||||
import models.transform_layers as TL | |||||
from training.contrastive_loss import get_similarity_matrix, Supervised_NT_xent | |||||
from utils.utils import AverageMeter, normalize | |||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu") | |||||
hflip = TL.HorizontalFlipLayer().to(device) | |||||
def train(P, epoch, model, criterion, optimizer, scheduler, loader, logger=None, | |||||
simclr_aug=None, linear=None, linear_optim=None): | |||||
# currently only support rotation shifting augmentation | |||||
assert simclr_aug is not None | |||||
assert P.sim_lambda == 1.0 | |||||
if logger is None: | |||||
log_ = print | |||||
else: | |||||
log_ = logger.log | |||||
batch_time = AverageMeter() | |||||
data_time = AverageMeter() | |||||
losses = dict() | |||||
losses['cls'] = AverageMeter() | |||||
losses['sim'] = AverageMeter() | |||||
check = time.time() | |||||
for n, (images, labels) in enumerate(loader): | |||||
model.train() | |||||
count = n * P.n_gpus # number of trained samples | |||||
data_time.update(time.time() - check) | |||||
check = time.time() | |||||
### SimCLR loss ### | |||||
if P.dataset != 'imagenet' and P.dataset != 'CNMC' and P.dataset != 'CNMC_grayscale': | |||||
batch_size = images.size(0) | |||||
images = images.to(device) | |||||
images1, images2 = hflip(images.repeat(2, 1, 1, 1)).chunk(2) # hflip | |||||
else: | |||||
batch_size = images[0].size(0) | |||||
images1, images2 = images[0].to(device), images[1].to(device) | |||||
#print("\nImages" + str(images.shape) + "\n") | |||||
images1 = torch.cat([torch.rot90(images1, rot, (2, 3)) for rot in range(4)]) # 4B | |||||
images2 = torch.cat([torch.rot90(images2, rot, (2, 3)) for rot in range(4)]) # 4B | |||||
images_pair = torch.cat([images1, images2], dim=0) # 8B | |||||
labels = labels.to(device) | |||||
rot_sim_labels = torch.cat([labels + P.n_classes * i for i in range(4)], dim=0) | |||||
rot_sim_labels = rot_sim_labels.to(device) | |||||
images_pair = simclr_aug(images_pair) # simclr augment | |||||
_, outputs_aux = model(images_pair, simclr=True, penultimate=True) | |||||
simclr = normalize(outputs_aux['simclr']) # normalize | |||||
sim_matrix = get_similarity_matrix(simclr, multi_gpu=P.multi_gpu) | |||||
loss_sim = Supervised_NT_xent(sim_matrix, labels=rot_sim_labels, | |||||
temperature=0.07, multi_gpu=P.multi_gpu) * P.sim_lambda | |||||
### total loss ### | |||||
loss = loss_sim | |||||
optimizer.zero_grad() | |||||
loss.backward() | |||||
optimizer.step() | |||||
scheduler.step(epoch - 1 + n / len(loader)) | |||||
lr = optimizer.param_groups[0]['lr'] | |||||
batch_time.update(time.time() - check) | |||||
### Post-processing stuffs ### | |||||
penul_1 = outputs_aux['penultimate'][:batch_size] | |||||
penul_2 = outputs_aux['penultimate'][4 * batch_size: 5 * batch_size] | |||||
outputs_aux['penultimate'] = torch.cat([penul_1, penul_2]) # only use original rotation | |||||
### Linear evaluation ### | |||||
outputs_linear_eval = linear(outputs_aux['penultimate'].detach()) | |||||
loss_linear = criterion(outputs_linear_eval, labels.repeat(2)) | |||||
linear_optim.zero_grad() | |||||
loss_linear.backward() | |||||
linear_optim.step() | |||||
### Log losses ### | |||||
losses['cls'].update(0, batch_size) | |||||
losses['sim'].update(loss_sim.item(), batch_size) | |||||
if count % 50 == 0: | |||||
log_('[Epoch %3d; %3d] [Time %.3f] [Data %.3f] [LR %.5f]\n' | |||||
'[LossC %f] [LossSim %f]' % | |||||
(epoch, count, batch_time.value, data_time.value, lr, | |||||
losses['cls'].value, losses['sim'].value)) | |||||
check = time.time() | |||||
log_('[DONE] [Time %.3f] [Data %.3f] [LossC %f] [LossSim %f]' % | |||||
(batch_time.average, data_time.average, | |||||
losses['cls'].average, losses['sim'].average)) | |||||
if logger is not None: | |||||
logger.scalar_summary('train/loss_cls', losses['cls'].average, epoch) | |||||
logger.scalar_summary('train/loss_sim', losses['sim'].average, epoch) | |||||
logger.scalar_summary('train/batch_time', batch_time.average, epoch) | |||||
def setup(mode, P): | |||||
fname = f'{P.dataset}_{P.model}_unsup_{mode}_{P.res}' | |||||
if mode == 'simclr': | |||||
from .simclr import train | |||||
elif mode == 'simclr_CSI': | |||||
from .simclr_CSI import train | |||||
fname += f'_shift_{P.shift_trans_type}_resize_factor{P.resize_factor}_color_dist{P.color_distort}' | |||||
if P.shift_trans_type == 'gauss': | |||||
fname += f'_gauss_sigma{P.gauss_sigma}' | |||||
elif P.shift_trans_type == 'randpers': | |||||
fname += f'_distortion_scale{P.distortion_scale}' | |||||
elif P.shift_trans_type == 'sharp': | |||||
fname += f'_sharpness_factor{P.sharpness_factor}' | |||||
elif P.shift_trans_type == 'sharp': | |||||
fname += f'_nmean_{P.noise_mean}_nstd_{P.noise_std}' | |||||
else: | |||||
raise NotImplementedError() | |||||
if P.one_class_idx is not None: | |||||
fname += f'_one_class_{P.one_class_idx}' | |||||
if P.suffix is not None: | |||||
fname += f'_{P.suffix}' | |||||
return train, fname | |||||
def update_comp_loss(loss_dict, loss_in, loss_out, loss_diff, batch_size): | |||||
loss_dict['pos'].update(loss_in, batch_size) | |||||
loss_dict['neg'].update(loss_out, batch_size) | |||||
loss_dict['diff'].update(loss_diff, batch_size) | |||||
def summary_comp_loss(logger, tag, loss_dict, epoch): | |||||
logger.scalar_summary(f'{tag}/pos', loss_dict['pos'].average, epoch) | |||||
logger.scalar_summary(f'{tag}/neg', loss_dict['neg'].average, epoch) | |||||
logger.scalar_summary(f'{tag}', loss_dict['diff'].average, epoch) | |||||
import time | |||||
import torch.optim | |||||
import models.transform_layers as TL | |||||
from training.contrastive_loss import get_similarity_matrix, NT_xent | |||||
from utils.utils import AverageMeter, normalize | |||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu") | |||||
hflip = TL.HorizontalFlipLayer().to(device) | |||||
def train(P, epoch, model, criterion, optimizer, scheduler, loader, logger=None, | |||||
simclr_aug=None, linear=None, linear_optim=None): | |||||
assert simclr_aug is not None | |||||
assert P.sim_lambda == 1.0 | |||||
if logger is None: | |||||
log_ = print | |||||
else: | |||||
log_ = logger.log | |||||
batch_time = AverageMeter() | |||||
data_time = AverageMeter() | |||||
losses = dict() | |||||
losses['cls'] = AverageMeter() | |||||
losses['sim'] = AverageMeter() | |||||
check = time.time() | |||||
for n, (images, labels) in enumerate(loader): | |||||
model.train() | |||||
count = n * P.n_gpus # number of trained samples | |||||
data_time.update(time.time() - check) | |||||
check = time.time() | |||||
### SimCLR loss ### | |||||
if P.dataset != 'imagenet': | |||||
batch_size = images.size(0) | |||||
images = images.to(device) | |||||
images_pair = hflip(images.repeat(2, 1, 1, 1)) # 2B with hflip | |||||
else: | |||||
batch_size = images[0].size(0) | |||||
images1, images2 = images[0].to(device), images[1].to(device) | |||||
images_pair = torch.cat([images1, images2], dim=0) # 2B | |||||
labels = labels.to(device) | |||||
images_pair = simclr_aug(images_pair) # transform | |||||
_, outputs_aux = model(images_pair, simclr=True, penultimate=True) | |||||
simclr = normalize(outputs_aux['simclr']) # normalize | |||||
sim_matrix = get_similarity_matrix(simclr, multi_gpu=P.multi_gpu) | |||||
loss_sim = NT_xent(sim_matrix, temperature=0.5) * P.sim_lambda | |||||
### total loss ### | |||||
loss = loss_sim | |||||
optimizer.zero_grad() | |||||
loss.backward() | |||||
optimizer.step() | |||||
scheduler.step(epoch - 1 + n / len(loader)) | |||||
lr = optimizer.param_groups[0]['lr'] | |||||
batch_time.update(time.time() - check) | |||||
### Post-processing stuffs ### | |||||
simclr_norm = outputs_aux['simclr'].norm(dim=1).mean() | |||||
### Linear evaluation ### | |||||
outputs_linear_eval = linear(outputs_aux['penultimate'].detach()) | |||||
loss_linear = criterion(outputs_linear_eval, labels.repeat(2)) | |||||
linear_optim.zero_grad() | |||||
loss_linear.backward() | |||||
linear_optim.step() | |||||
### Log losses ### | |||||
losses['cls'].update(0, batch_size) | |||||
losses['sim'].update(loss_sim.item(), batch_size) | |||||
if count % 50 == 0: | |||||
log_('[Epoch %3d; %3d] [Time %.3f] [Data %.3f] [LR %.5f]\n' | |||||
'[LossC %f] [LossSim %f]' % | |||||
(epoch, count, batch_time.value, data_time.value, lr, | |||||
losses['cls'].value, losses['sim'].value)) | |||||
check = time.time() | |||||
log_('[DONE] [Time %.3f] [Data %.3f] [LossC %f] [LossSim %f]' % | |||||
(batch_time.average, data_time.average, | |||||
losses['cls'].average, losses['sim'].average)) | |||||
if logger is not None: | |||||
logger.scalar_summary('train/loss_cls', losses['cls'].average, epoch) | |||||
logger.scalar_summary('train/loss_sim', losses['sim'].average, epoch) | |||||
logger.scalar_summary('train/batch_time', batch_time.average, epoch) |
import time | |||||
import torch.optim | |||||
import models.transform_layers as TL | |||||
from training.contrastive_loss import get_similarity_matrix, NT_xent | |||||
from utils.utils import AverageMeter, normalize | |||||
device = torch.device(f"cuda" if torch.cuda.is_available() else "cpu") | |||||
hflip = TL.HorizontalFlipLayer().to(device) | |||||
def train(P, epoch, model, criterion, optimizer, scheduler, loader, logger=None, | |||||
simclr_aug=None, linear=None, linear_optim=None): | |||||
assert simclr_aug is not None | |||||
assert P.sim_lambda == 1.0 # to avoid mistake | |||||
assert P.K_shift > 1 | |||||
if logger is None: | |||||
log_ = print | |||||
else: | |||||
log_ = logger.log | |||||
batch_time = AverageMeter() | |||||
data_time = AverageMeter() | |||||
losses = dict() | |||||
losses['cls'] = AverageMeter() | |||||
losses['sim'] = AverageMeter() | |||||
losses['shift'] = AverageMeter() | |||||
check = time.time() | |||||
for n, (images, labels) in enumerate(loader): | |||||
model.train() | |||||
count = n * P.n_gpus # number of trained samples | |||||
data_time.update(time.time() - check) | |||||
check = time.time() | |||||
### SimCLR loss ### | |||||
if P.dataset != 'imagenet' and P.dataset != 'CNMC' and P.dataset != 'CNMC_grayscale': | |||||
batch_size = images.size(0) | |||||
images = images.to(device) | |||||
images1, images2 = hflip(images.repeat(2, 1, 1, 1)).chunk(2) # hflip | |||||
else: | |||||
batch_size = images[0].size(0) | |||||
images1, images2 = images[0].to(device), images[1].to(device) | |||||
labels = labels.to(device) | |||||
images1 = torch.cat([P.shift_trans(images1, k) for k in range(P.K_shift)]) | |||||
images2 = torch.cat([P.shift_trans(images2, k) for k in range(P.K_shift)]) | |||||
shift_labels = torch.cat([torch.ones_like(labels) * k for k in range(P.K_shift)], 0) # B -> 4B | |||||
shift_labels = shift_labels.repeat(2) | |||||
images_pair = torch.cat([images1, images2], dim=0) # 8B | |||||
images_pair = simclr_aug(images_pair) # transform | |||||
_, outputs_aux = model(images_pair, simclr=True, penultimate=True, shift=True) | |||||
simclr = normalize(outputs_aux['simclr']) # normalize | |||||
sim_matrix = get_similarity_matrix(simclr, multi_gpu=P.multi_gpu) | |||||
loss_sim = NT_xent(sim_matrix, temperature=0.5) * P.sim_lambda | |||||
loss_shift = criterion(outputs_aux['shift'], shift_labels) | |||||
### total loss ### | |||||
loss = loss_sim + loss_shift | |||||
optimizer.zero_grad() | |||||
loss.backward() | |||||
optimizer.step() | |||||
scheduler.step(epoch - 1 + n / len(loader)) | |||||
lr = optimizer.param_groups[0]['lr'] | |||||
batch_time.update(time.time() - check) | |||||
### Post-processing stuffs ### | |||||
simclr_norm = outputs_aux['simclr'].norm(dim=1).mean() | |||||
penul_1 = outputs_aux['penultimate'][:batch_size] | |||||
penul_2 = outputs_aux['penultimate'][P.K_shift * batch_size: (P.K_shift + 1) * batch_size] | |||||
outputs_aux['penultimate'] = torch.cat([penul_1, penul_2]) # only use original rotation | |||||
### Linear evaluation ### | |||||
outputs_linear_eval = linear(outputs_aux['penultimate'].detach()) | |||||
loss_linear = criterion(outputs_linear_eval, labels.repeat(2)) | |||||
linear_optim.zero_grad() | |||||
loss_linear.backward() | |||||
linear_optim.step() | |||||
losses['cls'].update(0, batch_size) | |||||
losses['sim'].update(loss_sim.item(), batch_size) | |||||
losses['shift'].update(loss_shift.item(), batch_size) | |||||
if count % 50 == 0: | |||||
log_('[Epoch %3d; %3d] [Time %.3f] [Data %.3f] [LR %.5f]\n' | |||||
'[LossC %f] [LossSim %f] [LossShift %f]' % | |||||
(epoch, count, batch_time.value, data_time.value, lr, | |||||
losses['cls'].value, losses['sim'].value, losses['shift'].value)) | |||||
log_('[DONE] [Time %.3f] [Data %.3f] [LossC %f] [LossSim %f] [LossShift %f]' % | |||||
(batch_time.average, data_time.average, | |||||
losses['cls'].average, losses['sim'].average, losses['shift'].average)) | |||||
if logger is not None: | |||||
logger.scalar_summary('train/loss_cls', losses['cls'].average, epoch) | |||||
logger.scalar_summary('train/loss_sim', losses['sim'].average, epoch) | |||||
logger.scalar_summary('train/loss_shift', losses['shift'].average, epoch) | |||||
logger.scalar_summary('train/batch_time', batch_time.average, epoch) | |||||