Artur Feoktistov 2022-04-29 19:26:47 +02:00
commit d1ce7b933f
110 changed files with 17469 additions and 0 deletions

File diff suppressed because it is too large.

File diff suppressed because it is too large.

README.md (new file, 176 lines)

@@ -0,0 +1,176 @@
# CSI: Novelty Detection via Contrastive Learning on Distributionally Shifted Instances
Official PyTorch implementation of
["**CSI: Novelty Detection via Contrastive Learning on Distributionally Shifted Instances**"](
https://arxiv.org/abs/2007.08176) (NeurIPS 2020) by
[Jihoon Tack*](https://jihoontack.github.io),
[Sangwoo Mo*](https://sites.google.com/view/sangwoomo),
[Jongheon Jeong](https://sites.google.com/view/jongheonj),
and [Jinwoo Shin](http://alinlab.kaist.ac.kr/shin.html).
<p align="center">
<img src=figures/shifting_transformations.png width="900">
</p>
## 1. Requirements
### Environments
Currently, the following packages are required:
- python 3.6+
- torch 1.4+
- torchvision 0.5+
- CUDA 10.1+
- scikit-learn 0.22+
- tensorboard 2.0+
- [torchlars](https://github.com/kakaobrain/torchlars) == 0.1.2
- [pytorch-gradual-warmup-lr](https://github.com/ildoonet/pytorch-gradual-warmup-lr)
- [apex](https://github.com/NVIDIA/apex) == 0.1
- [diffdist](https://github.com/ag14774/diffdist) == 0.1
### Datasets
For CIFAR, please download the following datasets to `~/data`.
* [LSUN_resize](https://www.dropbox.com/s/moqh2wh8696c3yl/LSUN_resize.tar.gz),
[ImageNet_resize](https://www.dropbox.com/s/kp3my3412u5k9rl/Imagenet_resize.tar.gz)
* [LSUN_fix](https://drive.google.com/file/d/1KVWj9xpHfVwGcErH5huVujk9snhEGOxE/view?usp=sharing),
[ImageNet_fix](https://drive.google.com/file/d/1sO_-noq10mmziB1ECDyNhD5T4u5otyKA/view?usp=sharing)
For ImageNet-30, please download the following datasets to `~/data`.
* [ImageNet-30-train](https://drive.google.com/file/d/1B5c39Fc3haOPzlehzmpTLz6xLtGyKEy4/view),
[ImageNet-30-test](https://drive.google.com/file/d/13xzVuQMEhSnBRZr-YaaO08coLU2dxAUq/view)
* [CUB-200](http://www.vision.caltech.edu/visipedia/CUB-200-2011.html),
[Stanford Dogs](http://vision.stanford.edu/aditya86/ImageNetDogs/),
[Oxford Pets](https://www.robots.ox.ac.uk/~vgg/data/pets/),
[Oxford flowers](https://www.robots.ox.ac.uk/~vgg/data/flowers/),
[Food-101](https://www.kaggle.com/dansbecker/food-101),
[Places-365](http://data.csail.mit.edu/places/places365/val_256.tar),
[Caltech-256](https://www.kaggle.com/jessicali9530/caltech256),
[DTD](https://www.robots.ox.ac.uk/~vgg/data/dtd/)
For Food-101, remove the hotdog class to avoid overlap.
## 2. Training
Currently, all code examples assume a distributed launch with 4 GPUs.
To run the code with a single GPU, remove `-m torch.distributed.launch --nproc_per_node=4`.
### Unlabeled one-class & multi-class
To train unlabeled one-class & multi-class models in the paper, run this command:
```train
CUDA_VISIBLE_DEVICES=0,1,2,3 python -m torch.distributed.launch --nproc_per_node=4 train.py --dataset <DATASET> --model <NETWORK> --mode simclr_CSI --shift_trans_type rotation --batch_size 32 --one_class_idx <One-Class-Index>
```
> The `--one_class_idx` option denotes the in-distribution class of one-class training.
> For multi-class training, set `--one_class_idx` to None.
> To run SimCLR, simply change `--mode` to `simclr`.
> The total batch size should be 512 = 4 (GPUs) * 32 (per-GPU `--batch_size`) * 4 (cardinality of the shifted transformation set).
### Labeled multi-class
To train the labeled multi-class model (confidence-calibrated classifier) in the paper, run this command:
```train
# Representation train
CUDA_VISIBLE_DEVICES=0,1,2,3 python -m torch.distributed.launch --nproc_per_node=4 train.py --dataset <DATASET> --model <NETWORK> --mode sup_simclr_CSI --shift_trans_type rotation --batch_size 32 --epoch 700
# Linear layer train
python train.py --mode sup_CSI_linear --dataset <DATASET> --model <NETWORK> --batch_size 32 --epoch 100 --shift_trans_type rotation --load_path <MODEL_PATH>
```
> To run SupCLR, simply change `--mode` to `sup_simclr` for representation training and to `sup_linear` for linear-layer training.
> The total batch size should be the same as above. Currently, only rotation is supported as the shifted transformation.
## 3. Evaluation
We provide checkpoints of CSI pre-trained models. Download the checkpoints from the following links:
- One-class CIFAR-10: [ResNet-18](https://drive.google.com/drive/folders/1z02i0G_lzrZe0NwpH-tnjpO8pYHV7mE9?usp=sharing)
- Unlabeled (multi-class) CIFAR-10: [ResNet-18](https://drive.google.com/file/d/1yUq6Si6hWaMa1uYyLDHk0A4BrPIa8ECV/view?usp=sharing)
- Unlabeled (multi-class) ImageNet-30: [ResNet-18](https://drive.google.com/file/d/1KucQWSik8RyoJgU-fz8XLmCWhvMOP7fT/view?usp=sharing)
- Labeled (multi-class) CIFAR-10: [ResNet-18](https://drive.google.com/file/d/1rW2-0MJEzPHLb_PAW-LvCivHt-TkDpRO/view?usp=sharing)
### Unlabeled one-class & multi-class
To evaluate the model in the unlabeled one-class & multi-class out-of-distribution (OOD) detection setting, run this command:
```eval
python eval.py --mode ood_pre --dataset <DATASET> --model <NETWORK> --ood_score CSI --shift_trans_type rotation --print_score --ood_samples 10 --resize_factor 0.54 --resize_fix --one_class_idx <One-Class-Index> --load_path <MODEL_PATH>
```
> The `--one_class_idx` option denotes the in-distribution class of one-class evaluation.
> For multi-class evaluation, set `--one_class_idx` to None.
> The `--resize_factor` and `--resize_fix` options fix the cropping scale of `RandomResizedCrop()`.
> For SimCLR evaluation, change `--ood_score` to `simclr`.
### Labeled multi-class
To evaluate the model on labeled multi-class accuracy, ECE, and OOD detection, run these commands:
```eval
# OOD AUROC
python eval.py --mode ood --ood_score baseline_marginalized --print_score --dataset <DATASET> --model <NETWORK> --shift_trans_type rotation --load_path <MODEL_PATH>
# Accuracy & ECE
python eval.py --mode test_marginalized_acc --dataset <DATASET> --model <NETWORK> --shift_trans_type rotation --load_path <MODEL_PATH>
```
> These commands use marginalized inference.
> For single inference (also used for SupCLR), change `--ood_score` to `baseline` in the first command,
> and `--mode` to `test_acc` in the second command.
## 4. Results
Our model achieves the following performance on:
### One-Class Out-of-Distribution Detection
| Method | Dataset | AUROC (Mean) |
| --------------|------------------ | --------------|
| SimCLR | CIFAR-10-OC | 87.9% |
| Rot+Trans | CIFAR-10-OC | 90.0% |
| CSI (ours) | CIFAR-10-OC | 94.3% |
We only show the CIFAR-10 one-class results in this repo. For other settings, please see our paper.
### Unlabeled Multi-Class Out-of-Distribution Detection
| Method | Dataset | OOD Dataset | AUROC (Mean) |
| --------------|------------------ |---------------|--------------|
| Rot+Trans | CIFAR-10 | CIFAR-100 | 82.5% |
| CSI (ours) | CIFAR-10 | CIFAR-100 | 89.3% |
We only show the CIFAR-10 vs. CIFAR-100 OOD detection results in this repo. For other OOD dataset results, please see our paper.
### Labeled Multi-Class Result
| Method | Dataset | OOD Dataset | Acc | ECE | AUROC (Mean) |
| ---------------- |------------------ |---------------|-------|-------|--------------|
| SupCLR | CIFAR-10 | CIFAR-100 | 93.9% | 5.54% | 88.3% |
| CSI (ours) | CIFAR-10 | CIFAR-100 | 94.8% | 4.24% | 90.6% |
| CSI-ensem (ours) | CIFAR-10 | CIFAR-100 | 96.0% | 3.64% | 92.3% |
We only show CIFAR-10 with CIFAR-100 as OOD in this repo. For other dataset results, please see our paper.
## 5. New OOD dataset
<p align="center">
<img src=figures/fixed_ood_benchmarks.png width="600">
</p>
We find that current benchmark datasets for OOD detection are visually far from in-distribution datasets (e.g., CIFAR).
To address this issue, we provide new datasets for OOD detection evaluation:
[LSUN_fix](https://drive.google.com/file/d/1KVWj9xpHfVwGcErH5huVujk9snhEGOxE/view?usp=sharing),
[ImageNet_fix](https://drive.google.com/file/d/1sO_-noq10mmziB1ECDyNhD5T4u5otyKA/view?usp=sharing).
See the figure above for a visualization of the current benchmarks and our datasets.
To generate the OOD datasets, run the following scripts inside the `./datasets` folder:
```OOD dataset generation
# ImageNet FIX generation code
python imagenet_fix_preprocess.py
# LSUN FIX generation code
python lsun_fix_preprocess.py
```
## Citation
```
@inproceedings{tack2020csi,
title={CSI: Novelty Detection via Contrastive Learning on Distributionally Shifted Instances},
author={Jihoon Tack and Sangwoo Mo and Jongheon Jeong and Jinwoo Shin},
booktitle={Advances in Neural Information Processing Systems},
year={2020}
}
```

common/LARS.py (new file, 119 lines)

@@ -0,0 +1,119 @@
"""
References:
- https://github.com/PyTorchLightning/PyTorch-Lightning-Bolts/blob/master/pl_bolts/optimizers/lars_scheduling.py
- https://github.com/NVIDIA/apex/blob/master/apex/parallel/LARC.py
- https://arxiv.org/pdf/1708.03888.pdf
- https://github.com/noahgolmant/pytorch-lars/blob/master/lars.py
"""
import torch
from .wrapper import OptimWrapper
# from torchlars._adaptive_lr import compute_adaptive_lr # Impossible to build extensions
__all__ = ["LARS"]
class LARS(OptimWrapper):
"""Implements 'LARS (Layer-wise Adaptive Rate Scaling)'__ as Optimizer a
:class:`~torch.optim.Optimizer` wrapper.
__ : https://arxiv.org/abs/1708.03888
Wraps an arbitrary optimizer like :class:`torch.optim.SGD` to use LARS. If
you want to the same performance obtained with small-batch training when
you use large-batch training, LARS will be helpful::
Args:
optimizer (Optimizer):
optimizer to wrap
eps (float, optional):
epsilon to help with numerical stability while calculating the
adaptive learning rate
trust_coef (float, optional):
trust coefficient for calculating the adaptive learning rate
Example::
base_optimizer = optim.SGD(model.parameters(), lr=0.1)
optimizer = LARS(optimizer=base_optimizer)
output = model(input)
loss = loss_fn(output, target)
loss.backward()
optimizer.step()
"""
def __init__(self, optimizer, trust_coef=0.02, clip=True, eps=1e-8):
if eps < 0.0:
raise ValueError("invalid epsilon value: , %f" % eps)
if trust_coef < 0.0:
raise ValueError("invalid trust coefficient: %f" % trust_coef)
self.optim = optimizer
self.eps = eps
self.trust_coef = trust_coef
self.clip = clip
    def __getstate__(self):
lars_dict = {}
lars_dict["trust_coef"] = self.trust_coef
lars_dict["clip"] = self.clip
lars_dict["eps"] = self.eps
return (self.optim, lars_dict)
def __setstate__(self, state):
self.optim, lars_dict = state
self.trust_coef = lars_dict["trust_coef"]
self.clip = lars_dict["clip"]
self.eps = lars_dict["eps"]
@torch.no_grad()
def step(self, closure=None):
weight_decays = []
for group in self.optim.param_groups:
weight_decay = group.get("weight_decay", 0)
weight_decays.append(weight_decay)
# reset weight decay
group["weight_decay"] = 0
# update the parameters
for p in group["params"]:
if p.grad is not None:
self.update_p(p, group, weight_decay)
# update the optimizer
self.optim.step(closure=closure)
# return weight decay control to optimizer
for group_idx, group in enumerate(self.optim.param_groups):
group["weight_decay"] = weight_decays[group_idx]
def update_p(self, p, group, weight_decay):
# calculate new norms
p_norm = torch.norm(p.data)
g_norm = torch.norm(p.grad.data)
if p_norm != 0 and g_norm != 0:
# calculate new lr
divisor = g_norm + p_norm * weight_decay + self.eps
adaptive_lr = (self.trust_coef * p_norm) / divisor
# clip lr
if self.clip:
adaptive_lr = min(adaptive_lr / group["lr"], 1)
# update params with clipped lr
p.grad.data += weight_decay * p.data
p.grad.data *= adaptive_lr
from torch.optim import SGD
from pylot.util import delegates, separate_kwargs
class SGDLARS(LARS):
@delegates(to=LARS.__init__)
@delegates(to=SGD.__init__, keep=True, but=["eps", "trust_coef"])
def __init__(self, params, **kwargs):
sgd_kwargs, lars_kwargs = separate_kwargs(kwargs, SGD.__init__)
optim = SGD(params, **sgd_kwargs)
super().__init__(optim, **lars_kwargs)
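The `Example::` block in the docstring above is abbreviated, so here is a minimal, self-contained usage sketch of the `LARS` wrapper around `torch.optim.SGD`. The toy model and data are placeholders (not part of the repo), and importing `common.LARS` assumes the `pylot` dependency used by `SGDLARS` at the bottom of the file is installed:

```python
import torch
import torch.nn as nn
import torch.optim as optim

from common.LARS import LARS

model = nn.Linear(10, 2)  # toy model, stands in for ResNet-18
base_optimizer = optim.SGD(model.parameters(), lr=0.1, weight_decay=1e-6)
optimizer = LARS(base_optimizer, trust_coef=0.02)

x, y = torch.randn(4, 10), torch.randint(0, 2, (4,))
loss = nn.CrossEntropyLoss()(model(x), y)
loss.backward()
optimizer.step()            # rescales each layer's gradient, then runs SGD
base_optimizer.zero_grad()  # zero the grads on the wrapped optimizer
```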

common/__init__.py (new file, empty)

BIN common/__init__.pyc (new binary file, contents not shown)

Nine additional binary files not shown.

common/common.py (new file, 114 lines)

@@ -0,0 +1,114 @@
from argparse import ArgumentParser
def parse_args(default=False):
"""Command-line argument parser for training."""
    parser = ArgumentParser(description='PyTorch implementation of CSI')
parser.add_argument('--dataset', help='Dataset',
choices=['cifar10', 'cifar100', 'imagenet', 'CNMC', 'CNMC_grayscale'], type=str)
parser.add_argument('--one_class_idx', help='None: multi-class, Not None: one-class',
default=None, type=int)
parser.add_argument('--model', help='Model',
choices=['resnet18', 'resnet18_imagenet'], type=str)
parser.add_argument('--mode', help='Training mode',
default='simclr', type=str)
parser.add_argument('--simclr_dim', help='Dimension of simclr layer',
default=128, type=int)
parser.add_argument('--shift_trans_type', help='shifting transformation type', default='none',
choices=['rotation', 'cutperm', 'blur', 'randpers', 'sharp', 'blur_randpers',
'blur_sharp', 'randpers_sharp', 'blur_randpers_sharp', 'noise', 'none'], type=str)
parser.add_argument("--local_rank", type=int,
default=0, help='Local rank for distributed learning')
parser.add_argument('--resume_path', help='Path to the resume checkpoint',
default=None, type=str)
parser.add_argument('--load_path', help='Path to the loading checkpoint',
default=None, type=str)
parser.add_argument("--no_strict", help='Do not strictly load state_dicts',
action='store_true')
parser.add_argument('--suffix', help='Suffix for the log dir',
default=None, type=str)
parser.add_argument('--error_step', help='Epoch steps to compute errors',
default=5, type=int)
parser.add_argument('--save_step', help='Epoch steps to save models',
default=10, type=int)
##### Training Configurations #####
parser.add_argument('--epochs', help='Epochs',
default=1000, type=int)
parser.add_argument('--optimizer', help='Optimizer',
choices=['sgd', 'lars'],
default='lars', type=str)
parser.add_argument('--lr_scheduler', help='Learning rate scheduler',
choices=['step_decay', 'cosine'],
default='cosine', type=str)
parser.add_argument('--warmup', help='Warm-up epochs',
default=10, type=int)
parser.add_argument('--lr_init', help='Initial learning rate',
default=1e-1, type=float)
parser.add_argument('--weight_decay', help='Weight decay',
default=1e-6, type=float)
parser.add_argument('--batch_size', help='Batch size',
default=128, type=int)
parser.add_argument('--test_batch_size', help='Batch size for test loader',
default=100, type=int)
parser.add_argument('--blur_sigma', help='Distortion grade',
default=2.0, type=float)
parser.add_argument('--color_distort', help='Color distortion grade',
default=0.5, type=float)
parser.add_argument('--distortion_scale', help='Perspective distortion grade',
default=0.6, type=float)
parser.add_argument('--sharpness_factor', help='Sharpening or blurring factor of image. '
'Can be any non negative number. 0 gives a blurred image, '
'1 gives the original image while 2 increases the sharpness '
'by a factor of 2.',
default=2, type=float)
parser.add_argument('--noise_mean', help='mean',
default=0, type=float)
parser.add_argument('--noise_std', help='std',
default=0.3, type=float)
##### Objective Configurations #####
parser.add_argument('--sim_lambda', help='Weight for SimCLR loss',
default=1.0, type=float)
parser.add_argument('--temperature', help='Temperature for similarity',
default=0.5, type=float)
##### Evaluation Configurations #####
parser.add_argument("--ood_dataset", help='Datasets for OOD detection',
default=None, nargs="*", type=str)
parser.add_argument("--ood_score", help='score function for OOD detection',
default=['norm_mean'], nargs="+", type=str)
parser.add_argument("--ood_layer", help='layer for OOD scores',
choices=['penultimate', 'simclr', 'shift'],
default=['simclr', 'shift'], nargs="+", type=str)
parser.add_argument("--ood_samples", help='number of samples to compute OOD score',
default=1, type=int)
parser.add_argument("--ood_batch_size", help='batch size to compute OOD score',
default=100, type=int)
parser.add_argument("--resize_factor", help='resize scale is sampled from [resize_factor, 1.0]',
default=0.08, type=float)
parser.add_argument("--resize_fix", help='resize scale is fixed to resize_factor (not (resize_factor, 1.0])',
action='store_true')
parser.add_argument("--print_score", help='print quantiles of ood score',
action='store_true')
parser.add_argument("--save_score", help='save ood score for plotting histogram',
action='store_true')
##### Process configuration option #####
parser.add_argument("--proc_step", help='choose process to initiate.',
choices=['eval', 'train'],
default=None, type=str)
parser.add_argument("--res", help='resolution of dataset',
default="32px", type=str)
if default:
return parser.parse_args('') # empty string
else:
return parser.parse_args()
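Several scripts in this repo build their configuration object `P` from this parser; a quick sketch of what that looks like (the overridden values are illustrative, mirroring the README commands):

```python
from common.common import parse_args

P = parse_args(default=True)   # parse an empty argv, i.e. take all defaults
P.dataset = 'cifar10'          # as if --dataset cifar10 had been passed
P.mode = 'simclr_CSI'          # as if --mode simclr_CSI
P.one_class_idx = 0            # as if --one_class_idx 0

print(P.batch_size, P.optimizer, P.lr_scheduler)  # -> 128 lars cosine
```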

common/eval.py (new file, 81 lines)

@@ -0,0 +1,81 @@
from copy import deepcopy
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from common.common import parse_args
import models.classifier as C
from datasets import get_dataset, get_superclass_list, get_subclass_dataset
P = parse_args()
### Set torch device ###
P.n_gpus = torch.cuda.device_count()
assert P.n_gpus <= 1 # no multi GPU
P.multi_gpu = False
if torch.cuda.is_available():
torch.cuda.set_device(P.local_rank)
device = torch.device(f"cuda" if torch.cuda.is_available() else "cpu")
### Initialize dataset ###
ood_eval = P.mode == 'ood_pre'
if ood_eval and P.dataset in ('imagenet', 'CNMC', 'CNMC_grayscale'):
P.batch_size = 1
P.test_batch_size = 1
train_set, test_set, image_size, n_classes = get_dataset(P, dataset=P.dataset, eval=ood_eval)
P.image_size = image_size
P.n_classes = n_classes
if P.one_class_idx is not None:
cls_list = get_superclass_list(P.dataset)
P.n_superclasses = len(cls_list)
full_test_set = deepcopy(test_set) # test set of full classes
train_set = get_subclass_dataset(train_set, classes=cls_list[P.one_class_idx])
test_set = get_subclass_dataset(test_set, classes=cls_list[P.one_class_idx])
kwargs = {'pin_memory': False, 'num_workers': 2}
train_loader = DataLoader(train_set, shuffle=True, batch_size=P.batch_size, **kwargs)
test_loader = DataLoader(test_set, shuffle=False, batch_size=P.test_batch_size, **kwargs)
if P.ood_dataset is None:
if P.one_class_idx is not None:
P.ood_dataset = list(range(P.n_superclasses))
P.ood_dataset.pop(P.one_class_idx)
elif P.dataset == 'cifar10':
P.ood_dataset = ['svhn', 'lsun_resize', 'imagenet_resize', 'lsun_fix', 'imagenet_fix', 'cifar100', 'interp']
elif P.dataset == 'imagenet':
P.ood_dataset = ['cub', 'stanford_dogs', 'flowers102', 'places365', 'food_101', 'caltech_256', 'dtd', 'pets']
ood_test_loader = dict()
for ood in P.ood_dataset:
if ood == 'interp':
ood_test_loader[ood] = None # dummy loader
continue
if P.one_class_idx is not None:
ood_test_set = get_subclass_dataset(full_test_set, classes=cls_list[ood])
ood = f'one_class_{ood}' # change save name
else:
ood_test_set = get_dataset(P, dataset=ood, test_only=True, image_size=P.image_size, eval=ood_eval)
ood_test_loader[ood] = DataLoader(ood_test_set, shuffle=False, batch_size=P.test_batch_size, **kwargs)
### Initialize model ###
simclr_aug = C.get_simclr_augmentation(P, image_size=P.image_size).to(device)
P.shift_trans, P.K_shift = C.get_shift_module(P, eval=True)
P.shift_trans = P.shift_trans.to(device)
model = C.get_classifier(P.model, n_classes=P.n_classes).to(device)
model = C.get_shift_classifer(model, P.K_shift).to(device)
criterion = nn.CrossEntropyLoss().to(device)
if P.load_path is not None:
checkpoint = torch.load(P.load_path)
model.load_state_dict(checkpoint, strict=not P.no_strict)
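`eval.py` below does `from common.eval import *`, so all of the module-level globals set up here (`P`, `model`, `criterion`, `train_loader`, `test_loader`, `ood_test_loader`, `simclr_aug`) become available to it. A hedged sketch of inspecting that state, assuming the datasets are in place and the script is launched with the usual CLI flags:

```python
# Hypothetical inspection snippet; run with the same flags as eval.py, e.g.
#   --mode ood_pre --dataset cifar10 --model resnet18 --load_path <MODEL_PATH>
from common.eval import *

print(P.mode, P.n_classes, P.image_size)
for name, loader in ood_test_loader.items():
    n = 'built on the fly' if loader is None else len(loader.dataset)
    print(f'OOD set {name}: {n} samples')
```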

common/train.py (new file, 148 lines)

@@ -0,0 +1,148 @@
from copy import deepcopy
import torch
import torch.nn as nn
import torch.optim as optim
import torch.optim.lr_scheduler as lr_scheduler
from torch.utils.data import DataLoader
from common.common import parse_args
import models.classifier as C
from datasets import get_dataset, get_superclass_list, get_subclass_dataset
from utils.utils import load_checkpoint
P = parse_args()
### Set torch device ###
if torch.cuda.is_available():
torch.cuda.set_device(P.local_rank)
device = torch.device(f"cuda" if torch.cuda.is_available() else "cpu")
P.n_gpus = torch.cuda.device_count()
if P.n_gpus > 1:
import apex
import torch.distributed as dist
from torch.utils.data.distributed import DistributedSampler
P.multi_gpu = True
torch.distributed.init_process_group(
'nccl',
init_method='env://',
world_size=P.n_gpus,
rank=P.local_rank,
)
else:
P.multi_gpu = False
### only use one ood_layer while training
P.ood_layer = P.ood_layer[0]
### Initialize dataset ###
train_set, test_set, image_size, n_classes = get_dataset(P, dataset=P.dataset)
P.image_size = image_size
P.n_classes = n_classes
if P.one_class_idx is not None:
cls_list = get_superclass_list(P.dataset)
P.n_superclasses = len(cls_list)
full_test_set = deepcopy(test_set) # test set of full classes
train_set = get_subclass_dataset(train_set, classes=cls_list[P.one_class_idx])
test_set = get_subclass_dataset(test_set, classes=cls_list[P.one_class_idx])
kwargs = {'pin_memory': False, 'num_workers': 2}
if P.multi_gpu:
train_sampler = DistributedSampler(train_set, num_replicas=P.n_gpus, rank=P.local_rank)
test_sampler = DistributedSampler(test_set, num_replicas=P.n_gpus, rank=P.local_rank)
train_loader = DataLoader(train_set, sampler=train_sampler, batch_size=P.batch_size, **kwargs)
test_loader = DataLoader(test_set, sampler=test_sampler, batch_size=P.test_batch_size, **kwargs)
else:
train_loader = DataLoader(train_set, shuffle=True, batch_size=P.batch_size, **kwargs)
test_loader = DataLoader(test_set, shuffle=False, batch_size=P.test_batch_size, **kwargs)
if P.ood_dataset is None:
if P.one_class_idx is not None:
P.ood_dataset = list(range(P.n_superclasses))
P.ood_dataset.pop(P.one_class_idx)
elif P.dataset == 'cifar10':
P.ood_dataset = ['svhn', 'lsun_resize', 'imagenet_resize', 'lsun_fix', 'imagenet_fix', 'cifar100', 'interp']
elif P.dataset == 'imagenet':
P.ood_dataset = ['cub', 'stanford_dogs', 'flowers102']
ood_test_loader = dict()
for ood in P.ood_dataset:
if ood == 'interp':
ood_test_loader[ood] = None # dummy loader
continue
if P.one_class_idx is not None:
ood_test_set = get_subclass_dataset(full_test_set, classes=cls_list[ood])
ood = f'one_class_{ood}' # change save name
else:
ood_test_set = get_dataset(P, dataset=ood, test_only=True, image_size=P.image_size)
if P.multi_gpu:
ood_sampler = DistributedSampler(ood_test_set, num_replicas=P.n_gpus, rank=P.local_rank)
ood_test_loader[ood] = DataLoader(ood_test_set, sampler=ood_sampler, batch_size=P.test_batch_size, **kwargs)
else:
ood_test_loader[ood] = DataLoader(ood_test_set, shuffle=False, batch_size=P.test_batch_size, **kwargs)
### Initialize model ###
simclr_aug = C.get_simclr_augmentation(P, image_size=P.image_size).to(device)
P.shift_trans, P.K_shift = C.get_shift_module(P, eval=True)
P.shift_trans = P.shift_trans.to(device)
model = C.get_classifier(P.model, n_classes=P.n_classes).to(device)
model = C.get_shift_classifer(model, P.K_shift).to(device)
criterion = nn.CrossEntropyLoss().to(device)
if P.optimizer == 'sgd':
optimizer = optim.SGD(model.parameters(), lr=P.lr_init, momentum=0.9, weight_decay=P.weight_decay)
lr_decay_gamma = 0.1
elif P.optimizer == 'lars':
from torchlars import LARS
base_optimizer = optim.SGD(model.parameters(), lr=P.lr_init, momentum=0.9, weight_decay=P.weight_decay)
optimizer = LARS(base_optimizer, eps=1e-8, trust_coef=0.001)
lr_decay_gamma = 0.1
else:
raise NotImplementedError()
if P.lr_scheduler == 'cosine':
scheduler = lr_scheduler.CosineAnnealingLR(optimizer, P.epochs)
elif P.lr_scheduler == 'step_decay':
milestones = [int(0.5 * P.epochs), int(0.75 * P.epochs)]
scheduler = lr_scheduler.MultiStepLR(optimizer, gamma=lr_decay_gamma, milestones=milestones)
else:
raise NotImplementedError()
from training.scheduler import GradualWarmupScheduler
scheduler_warmup = GradualWarmupScheduler(optimizer, multiplier=10.0, total_epoch=P.warmup, after_scheduler=scheduler)
if P.resume_path is not None:
resume = True
model_state, optim_state, config = load_checkpoint(P.resume_path, mode='last')
model.load_state_dict(model_state, strict=not P.no_strict)
optimizer.load_state_dict(optim_state)
start_epoch = config['epoch']
best = config['best']
error = 100.0
else:
resume = False
start_epoch = 1
best = 100.0
error = 100.0
if P.mode == 'sup_linear' or P.mode == 'sup_CSI_linear':
assert P.load_path is not None
checkpoint = torch.load(P.load_path)
model.load_state_dict(checkpoint, strict=not P.no_strict)
if P.multi_gpu:
simclr_aug = apex.parallel.DistributedDataParallel(simclr_aug, delay_allreduce=True)
model = apex.parallel.convert_syncbn_model(model)
model = apex.parallel.DistributedDataParallel(model, delay_allreduce=True)
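This module only sets up the training state; the per-epoch loop lives in `train.py` and the `training/` package. As a rough sketch of how the pieces above fit together (the `train_one_epoch` call is a placeholder, not a function from this repo):

```python
# Hypothetical epoch loop, continuing from the variables defined above.
for epoch in range(start_epoch, P.epochs + 1):
    if P.multi_gpu:
        train_sampler.set_epoch(epoch)   # reshuffle the distributed shards
    # train_one_epoch(P, model, criterion, optimizer, train_loader, simclr_aug)
    scheduler_warmup.step()              # 10x warm-up for P.warmup epochs, then cosine
```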

BIN data/ImageNet_FIX.tar.gz (new binary file, contents not shown)
BIN data/Imagenet_resize.tar.gz (new binary file, contents not shown)
BIN data/LSUN_FIX.tar.gz (new binary file, contents not shown)
BIN data/LSUN_resize.tar.gz (new binary file, contents not shown)

datasets/__init__.py (new file, 2 lines)

@@ -0,0 +1,2 @@
from datasets.datasets import get_dataset, get_superclass_list, get_subclass_dataset

Eight binary files not shown.

datasets/datasets.py (new file, 361 lines)

@@ -0,0 +1,361 @@
import os
import numpy as np
import torch
from torch.utils.data.dataset import Subset
from torchvision import datasets, transforms
from utils.utils import set_random_seed
DATA_PATH = '~/data/'
IMAGENET_PATH = '~/data/ImageNet'
CNMC_PATH = r'~/data/CSI/CNMC_orig'
CNMC_GRAY_PATH = r'~/data/CSI/CNMC_orig_gray'
CNMC_ROT4_PATH = r'~/data/CSI/CNMC_rotated_4'
CIFAR10_SUPERCLASS = list(range(10)) # one class
IMAGENET_SUPERCLASS = list(range(30)) # one class
CNMC_SUPERCLASS = list(range(2)) # one class
STD_RES = 450
STD_CENTER_CROP = 300
CIFAR100_SUPERCLASS = [
[4, 31, 55, 72, 95],
[1, 33, 67, 73, 91],
[54, 62, 70, 82, 92],
[9, 10, 16, 29, 61],
[0, 51, 53, 57, 83],
[22, 25, 40, 86, 87],
[5, 20, 26, 84, 94],
[6, 7, 14, 18, 24],
[3, 42, 43, 88, 97],
[12, 17, 38, 68, 76],
[23, 34, 49, 60, 71],
[15, 19, 21, 32, 39],
[35, 63, 64, 66, 75],
[27, 45, 77, 79, 99],
[2, 11, 36, 46, 98],
[28, 30, 44, 78, 93],
[37, 50, 65, 74, 80],
[47, 52, 56, 59, 96],
[8, 13, 48, 58, 90],
[41, 69, 81, 85, 89],
]
class MultiDataTransform(object):
def __init__(self, transform):
self.transform1 = transform
self.transform2 = transform
def __call__(self, sample):
x1 = self.transform1(sample)
x2 = self.transform2(sample)
return x1, x2
class MultiDataTransformList(object):
    def __init__(self, transform, clean_transform, sample_num):
        self.transform = transform
        self.clean_transform = clean_transform
self.sample_num = sample_num
def __call__(self, sample):
set_random_seed(0)
sample_list = []
for i in range(self.sample_num):
sample_list.append(self.transform(sample))
return sample_list, self.clean_transform(sample)
def get_transform(image_size=None):
# Note: data augmentation is implemented in the layers
# Hence, we only define the identity transformation here
if image_size: # use pre-specified image size
train_transform = transforms.Compose([
transforms.Resize((image_size[0], image_size[1])),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
])
test_transform = transforms.Compose([
transforms.Resize((image_size[0], image_size[1])),
transforms.ToTensor(),
])
else: # use default image size
train_transform = transforms.Compose([
transforms.ToTensor(),
])
test_transform = transforms.ToTensor()
return train_transform, test_transform
def get_subset_with_len(dataset, length, shuffle=False):
set_random_seed(0)
dataset_size = len(dataset)
index = np.arange(dataset_size)
if shuffle:
np.random.shuffle(index)
index = torch.from_numpy(index[0:length])
subset = Subset(dataset, index)
assert len(subset) == length
return subset
def get_transform_imagenet():
train_transform = transforms.Compose([
transforms.Resize(256),
transforms.RandomResizedCrop(224),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
])
test_transform = transforms.Compose([
transforms.Resize(256),
transforms.CenterCrop(224),
transforms.ToTensor(),
])
train_transform = MultiDataTransform(train_transform)
return train_transform, test_transform
def get_transform_cnmc(res, center_crop_size):
train_transform = transforms.Compose([
transforms.Resize(res),
transforms.CenterCrop(center_crop_size),
transforms.RandomVerticalFlip(),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
])
test_transform = transforms.Compose([
transforms.Resize(res),
transforms.CenterCrop(center_crop_size),
transforms.ToTensor(),
])
train_transform = MultiDataTransform(train_transform)
return train_transform, test_transform
def get_dataset(P, dataset, test_only=False, image_size=None, download=False, eval=False):
if P.res != '':
res = int(P.res.replace('px', ''))
size_factor = int(STD_RES/res) # always remove same portion
center_crop_size = int(STD_CENTER_CROP/size_factor) # remove black border
if dataset in ['CNMC', 'CNMC_grayscale', 'CNMC_ROT4_PATH']:
if eval:
train_transform, test_transform = get_simclr_eval_transform_cnmc(P.ood_samples,
P.resize_factor, P.resize_fix, res, center_crop_size)
else:
train_transform, test_transform = get_transform_cnmc(res, center_crop_size)
elif dataset in ['imagenet', 'cub', 'stanford_dogs', 'flowers102',
'places365', 'food_101', 'caltech_256', 'dtd', 'pets']:
if eval:
train_transform, test_transform = get_simclr_eval_transform_imagenet(P.ood_samples,
P.resize_factor, P.resize_fix)
else:
train_transform, test_transform = get_transform_imagenet()
else:
train_transform, test_transform = get_transform(image_size=image_size)
if dataset == 'CNMC':
image_size = (center_crop_size, center_crop_size, 3) #original 450,450,3
n_classes = 2
train_dir = os.path.join(CNMC_PATH, '0_training')
test_dir = os.path.join(CNMC_PATH, '1_validation')
train_set = datasets.ImageFolder(train_dir, transform=train_transform)
test_set = datasets.ImageFolder(test_dir, transform=test_transform)
elif dataset == 'CNMC_grayscale':
image_size = (center_crop_size, center_crop_size, 3) #original 450,450,3
n_classes = 2
train_dir = os.path.join(CNMC_GRAY_PATH, '0_training')
test_dir = os.path.join(CNMC_GRAY_PATH, '1_validation')
train_set = datasets.ImageFolder(train_dir, transform=train_transform)
test_set = datasets.ImageFolder(test_dir, transform=test_transform)
elif dataset == 'cifar10':
image_size = (32, 32, 3)
n_classes = 10
train_set = datasets.CIFAR10(DATA_PATH, train=True, download=download, transform=train_transform)
test_set = datasets.CIFAR10(DATA_PATH, train=False, download=download, transform=test_transform)
elif dataset == 'cifar100':
image_size = (32, 32, 3)
n_classes = 100
train_set = datasets.CIFAR100(DATA_PATH, train=True, download=download, transform=train_transform)
test_set = datasets.CIFAR100(DATA_PATH, train=False, download=download, transform=test_transform)
elif dataset == 'svhn':
assert test_only and image_size is not None
test_set = datasets.SVHN(DATA_PATH, split='test', download=download, transform=test_transform)
elif dataset == 'lsun_resize':
assert test_only and image_size is not None
test_dir = os.path.join(DATA_PATH, 'LSUN_resize')
test_set = datasets.ImageFolder(test_dir, transform=test_transform)
elif dataset == 'lsun_fix':
assert test_only and image_size is not None
test_dir = os.path.join(DATA_PATH, 'LSUN_fix')
test_set = datasets.ImageFolder(test_dir, transform=test_transform)
elif dataset == 'imagenet_resize':
assert test_only and image_size is not None
test_dir = os.path.join(DATA_PATH, 'Imagenet_resize')
test_set = datasets.ImageFolder(test_dir, transform=test_transform)
elif dataset == 'imagenet_fix':
assert test_only and image_size is not None
test_dir = os.path.join(DATA_PATH, 'Imagenet_fix')
test_set = datasets.ImageFolder(test_dir, transform=test_transform)
elif dataset == 'imagenet':
image_size = (224, 224, 3)
n_classes = 30
train_dir = os.path.join(IMAGENET_PATH, 'one_class_train')
test_dir = os.path.join(IMAGENET_PATH, 'one_class_test')
train_set = datasets.ImageFolder(train_dir, transform=train_transform)
test_set = datasets.ImageFolder(test_dir, transform=test_transform)
elif dataset == 'stanford_dogs':
assert test_only and image_size is not None
test_dir = os.path.join(DATA_PATH, 'stanford_dogs')
test_set = datasets.ImageFolder(test_dir, transform=test_transform)
test_set = get_subset_with_len(test_set, length=3000, shuffle=True)
elif dataset == 'cub':
assert test_only and image_size is not None
test_dir = os.path.join(DATA_PATH, 'cub200')
test_set = datasets.ImageFolder(test_dir, transform=test_transform)
test_set = get_subset_with_len(test_set, length=3000, shuffle=True)
elif dataset == 'flowers102':
assert test_only and image_size is not None
test_dir = os.path.join(DATA_PATH, 'flowers102')
test_set = datasets.ImageFolder(test_dir, transform=test_transform)
test_set = get_subset_with_len(test_set, length=3000, shuffle=True)
elif dataset == 'places365':
assert test_only and image_size is not None
test_dir = os.path.join(DATA_PATH, 'places365')
test_set = datasets.ImageFolder(test_dir, transform=test_transform)
test_set = get_subset_with_len(test_set, length=3000, shuffle=True)
elif dataset == 'food_101':
assert test_only and image_size is not None
test_dir = os.path.join(DATA_PATH, 'food-101', 'images')
test_set = datasets.ImageFolder(test_dir, transform=test_transform)
test_set = get_subset_with_len(test_set, length=3000, shuffle=True)
elif dataset == 'caltech_256':
assert test_only and image_size is not None
test_dir = os.path.join(DATA_PATH, 'caltech-256')
test_set = datasets.ImageFolder(test_dir, transform=test_transform)
test_set = get_subset_with_len(test_set, length=3000, shuffle=True)
elif dataset == 'dtd':
assert test_only and image_size is not None
test_dir = os.path.join(DATA_PATH, 'dtd', 'images')
test_set = datasets.ImageFolder(test_dir, transform=test_transform)
test_set = get_subset_with_len(test_set, length=3000, shuffle=True)
elif dataset == 'pets':
assert test_only and image_size is not None
test_dir = os.path.join(DATA_PATH, 'pets')
test_set = datasets.ImageFolder(test_dir, transform=test_transform)
test_set = get_subset_with_len(test_set, length=3000, shuffle=True)
else:
raise NotImplementedError()
if test_only:
return test_set
else:
return train_set, test_set, image_size, n_classes
def get_superclass_list(dataset):
    if dataset in ('CNMC', 'CNMC_grayscale'):
        return CNMC_SUPERCLASS
elif dataset == 'cifar10':
return CIFAR10_SUPERCLASS
elif dataset == 'cifar100':
return CIFAR100_SUPERCLASS
elif dataset == 'imagenet':
return IMAGENET_SUPERCLASS
else:
raise NotImplementedError()
def get_subclass_dataset(dataset, classes):
if not isinstance(classes, list):
classes = [classes]
indices = []
for idx, tgt in enumerate(dataset.targets):
if tgt in classes:
indices.append(idx)
dataset = Subset(dataset, indices)
return dataset
def get_simclr_eval_transform_imagenet(sample_num, resize_factor, resize_fix):
resize_scale = (resize_factor, 1.0) # resize scaling factor
if resize_fix: # if resize_fix is True, use same scale
resize_scale = (resize_factor, resize_factor)
transform = transforms.Compose([
transforms.Resize(256),
transforms.RandomResizedCrop(224, scale=resize_scale),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
])
    clean_transform = transforms.Compose([
transforms.Resize(256),
transforms.CenterCrop(224),
transforms.ToTensor(),
])
    transform = MultiDataTransformList(transform, clean_transform, sample_num)
return transform, transform
def get_simclr_eval_transform_cnmc(sample_num, resize_factor, resize_fix, res, center_crop_size):
resize_scale = (resize_factor, 1.0) # resize scaling factor
if resize_fix: # if resize_fix is True, use same scale
resize_scale = (resize_factor, resize_factor)
transform = transforms.Compose([
transforms.Resize(res),
transforms.CenterCrop(center_crop_size),
transforms.RandomVerticalFlip(),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
])
    clean_transform = transforms.Compose([
transforms.Resize(res),
transforms.CenterCrop(center_crop_size),
transforms.ToTensor(),
])
    transform = MultiDataTransformList(transform, clean_transform, sample_num)
return transform, transform
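As a hedged usage sketch of the helpers above, here is how a one-class CIFAR-10 split could be pulled (assumes CIFAR-10 can be downloaded to `~/data`; `P` comes from the repo's own parser):

```python
from common.common import parse_args
from datasets.datasets import get_dataset, get_superclass_list, get_subclass_dataset

P = parse_args(default=True)
P.dataset = 'cifar10'
train_set, test_set, image_size, n_classes = get_dataset(P, dataset='cifar10', download=True)

cls_list = get_superclass_list('cifar10')                    # [0, 1, ..., 9]
one_class_train = get_subclass_dataset(train_set, classes=cls_list[0])
print(len(one_class_train), image_size, n_classes)           # -> 5000 (32, 32, 3) 10
```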

datasets/imagenet_fix_preprocess.py (new file, 66 lines)

@@ -0,0 +1,66 @@
import os
import time
import random
import cv2
import numpy as np
import torch
import torch.nn.functional as F
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from torchvision.utils import save_image
from datasets import get_subclass_dataset
def set_random_seed(seed):
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
IMAGENET_PATH = '~/data/ImageNet'
check = time.time()
transform = transforms.Compose([
transforms.Resize(256),
transforms.CenterCrop(256),
transforms.Resize(32),
transforms.ToTensor(),
])
# remove airliner(1), ambulance(2), parking_meter(18), schooner(22) since similar class exist in CIFAR-10
class_idx_list = list(range(30))
remove_idx_list = [1, 2, 18, 22]
for remove_idx in remove_idx_list:
class_idx_list.remove(remove_idx)
set_random_seed(0)
train_dir = os.path.join(IMAGENET_PATH, 'one_class_train')
Imagenet_set = datasets.ImageFolder(train_dir, transform=transform)
Imagenet_set = get_subclass_dataset(Imagenet_set, class_idx_list)
Imagenet_dataloader = DataLoader(Imagenet_set, batch_size=100, shuffle=True, pin_memory=False)
total_test_image = None
for n, (test_image, target) in enumerate(Imagenet_dataloader):
if n == 0:
total_test_image = test_image
else:
total_test_image = torch.cat((total_test_image, test_image), dim=0).cpu()
if total_test_image.size(0) >= 10000:
break
print(f'Preprocessing time {time.time()-check}')
if not os.path.exists('./Imagenet_fix'):
os.mkdir('./Imagenet_fix')
check = time.time()
for i in range(10000):
save_image(total_test_image[i], f'Imagenet_fix/correct_resize_{i}.png')
print(f'Saving time {time.time()-check}')

datasets/lsun_fix_preprocess.py (new file, 61 lines)

@@ -0,0 +1,61 @@
import os
import time
import random
import numpy as np
import torch
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from torchvision.utils import save_image
def set_random_seed(seed):
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
check = time.time()
transform = transforms.Compose([
transforms.Resize(256),
transforms.CenterCrop(256),
transforms.Resize(32),
transforms.ToTensor(),
])
set_random_seed(0)
LSUN_class_list = ['bedroom', 'bridge', 'church_outdoor', 'classroom',
'conference_room', 'dining_room', 'kitchen', 'living_room', 'restaurant', 'tower']
total_test_image_all_class = []
for LSUN_class in LSUN_class_list:
LSUN_set = datasets.LSUN('~/data/lsun/', classes=LSUN_class + '_train', transform=transform)
LSUN_loader = DataLoader(LSUN_set, batch_size=100, shuffle=True, pin_memory=False)
total_test_image = None
for n, (test_image, _) in enumerate(LSUN_loader):
if n == 0:
total_test_image = test_image
else:
total_test_image = torch.cat((total_test_image, test_image), dim=0).cpu()
if total_test_image.size(0) >= 1000:
break
total_test_image_all_class.append(total_test_image)
total_test_image_all_class = torch.cat(total_test_image_all_class, dim=0)
print(f'Preprocessing time {time.time()-check}')
if not os.path.exists('./LSUN_fix'):
os.mkdir('./LSUN_fix')
check = time.time()
for i in range(10000):
save_image(total_test_image_all_class[i], f'LSUN_fix/correct_resize_{i}.png')
print(f'Saving time {time.time()-check}')

(file name not shown; new file, 37 lines)

@@ -0,0 +1,37 @@
import re
import matplotlib.pyplot as plt
PATH = r'C:\Users\feokt\PycharmProjects\CSI\CSI\logs'
def postprocess_data(log: list):
for pth in log:
loss_sim = []
loss_shift = []
with open(PATH + pth) as f:
lines = f.readlines()
        for line in lines:
            # line = '[2022-01-31 20:40:23.947855] [DONE] [Time 0.179] [Data 0.583] [LossC 0.000000] [LossSim 4.024234] [LossShift 0.065126]'
            part = re.search(r'\[DONE\]', line)
            if part is not None:
                # search first, then call .group() only on a successful match
                l_sim = re.search(r'(\[LossSim.[0-9]*.[0-9]*\])', line)
                if l_sim is not None:
                    loss_sim.append(float(re.search(r'(\s[0-9].*[0-9])', l_sim.group()).group()))
                l_shift = re.search(r'(\[LossShift.[0-9]*.[0-9]*\])', line)
                if l_shift is not None:
                    loss_shift.append(float(re.search(r'(\s[0-9].*[0-9])', l_shift.group()).group()))
loss = [loss_sim[i] + loss_shift[i] for i in range(len(loss_sim))]
plt.ylabel("loss")
plt.xlabel("epoch")
plt.title("Loss over epochs")
plt.plot(list(range(1, 101)), loss)
for idx in range(len(log)):
log[idx] = log[idx][38:]
plt.legend(log)
plt.grid()
#plt.plot(list(range(1, 101)), loss_sim)
#plt.plot(list(range(1, 101)), loss_shift)
plt.show()
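Since `PATH + pth` concatenates strings directly, each entry passed to `postprocess_data` must begin with a path separator; the run names below are hypothetical, and the script assumes 100 logged epochs:

```python
# Hypothetical call (appended in this same file): overlay the summed
# LossSim + LossShift curves of two runs found under PATH.
postprocess_data([
    r'\cifar10_resnet18_unsup_simclr_CSI_shift_rotation_one_class_0\log.txt',
    r'\cifar10_resnet18_unsup_simclr_CSI_shift_rotation_one_class_1\log.txt',
])
```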

datasets/prepare_data.py (new file, 196 lines)

@@ -0,0 +1,196 @@
import csv
import os
from PIL import Image
from torchvision import transforms
from torchvision.utils import save_image
import torch
def transform_image(img_in, target_dir, transformation, suffix):
"""
Transforms an image according to provided transformation.
Parameters:
img_in (path): Image to transform
target_dir (path): Destination path
transformation (callable): Transformation to be applied
suffix (str): Suffix of resulting image.
Returns:
binary_sum (str): Binary string of the sum of a and b
"""
if suffix == 'rot':
im = Image.open(img_in)
im = im.rotate(270)
tensor = transforms.ToTensor()(im)
save_image(tensor, target_dir + os.sep + suffix + '.jpg')
elif suffix == 'sobel':
im = Image.open(img_in)
tensor = transforms.ToTensor()(im)
sobel_filter = torch.tensor([[1., 2., 1.], [0., 0., 0.], [-1., -2., -1.]])
f = sobel_filter.expand(1, 3, 3, 3)
        tensor = torch.conv2d(tensor.unsqueeze(0), f, stride=1, padding=1).squeeze(0)  # conv2d expects a batch dimension
save_image(tensor, target_dir + os.sep + suffix + '.jpg')
elif suffix == 'noise':
im = Image.open(img_in)
tensor = transforms.ToTensor()(im)
tensor = tensor + (torch.randn(tensor.size()) * 0.2 + 0)
save_image(tensor, target_dir + os.sep + suffix + '.jpg')
    elif suffix == 'cutout':
        pass  # cutout is not implemented yet
else:
im = Image.open(img_in)
im_trans = transformation(im)
im_trans.save(target_dir + os.sep + suffix + '.jpg')
def sort_and_rename_images(excel_path: str):
"""Renames images and sorts them according to csv."""
base_dir = excel_path.rsplit(os.sep, 1)[0]
dir_all = base_dir + os.sep + 'all'
if not os.path.isdir(dir_all):
os.mkdir(dir_all)
dir_hem = base_dir + os.sep + 'hem'
if not os.path.isdir(dir_hem):
os.mkdir(dir_hem)
with open(excel_path, mode='r') as file:
csv_file = csv.reader(file)
for lines in csv_file:
print(lines)
if lines[2] == '1':
os.rename(base_dir + os.sep + lines[1], dir_all + os.sep + lines[0])
elif lines[2] == '0':
os.rename(base_dir + os.sep + lines[1], dir_hem + os.sep + lines[0])
def drop_color_channels(source_dir, target_dir, rgb):
"""Rotates all images in in source dir."""
if rgb == 0:
suffix = "red_only"
drop_1 = 1
drop_2 = 2
elif rgb == 1:
suffix = "green_only"
drop_1 = 0
drop_2 = 2
elif rgb == 2:
suffix = "blue_only"
drop_1 = 0
drop_2 = 1
elif rgb == 3:
suffix = "no_red"
drop_1 = 0
elif rgb == 4:
suffix = "no_green"
drop_1 = 1
elif rgb == 5:
suffix = "no_blue"
drop_1 = 2
else:
suffix = ""
print("Invalid RGB-channel")
if suffix != "":
dirs = os.listdir(source_dir)
for item in dirs:
if os.path.isfile(source_dir + os.sep + item):
im = Image.open(source_dir + os.sep + item)
tensor = transforms.ToTensor()(im)
tensor[drop_1, :, :] = 0
if rgb < 3:
tensor[drop_2, :, :] = 0
save_image(tensor, target_dir + os.sep + item, 'bmp')
def rotate_images(target_dir, source_dir, rotate, theta):
"""Rotates all images in in source dir."""
dirs = os.listdir(source_dir)
for item in dirs:
if os.path.isfile(source_dir + os.sep + item):
for i in range(0, rotate):
im = Image.open(source_dir + os.sep + item)
im = im.rotate(i*theta)
tensor = transforms.ToTensor()(im)
save_image(tensor, target_dir + os.sep + str(i) + '_' + item, 'bmp')
def grayscale_image(source_dir, target_dir):
"""Grayscale transforms all images in path."""
t = transforms.Grayscale()
dirs = os.listdir(source_dir)
if not os.path.isdir(target_dir):
os.mkdir(target_dir)
for item in dirs:
if os.path.isfile(source_dir + os.sep + item):
im = Image.open(source_dir + os.sep + item).convert('RGB')
im_resize = t(im)
tensor = transforms.ToTensor()(im_resize)
padding = torch.zeros(1, tensor.shape[1], tensor.shape[2])
tensor = torch.cat((tensor, padding), 0)
im_resize.save(target_dir + os.sep + item, 'bmp')
def resize(source_dir):
"""Rotates all images in in source dir."""
t = transforms.Compose([transforms.Resize((128, 128))])
dirs = os.listdir(source_dir)
target_dir = source_dir + os.sep + 'resized'
if not os.path.isdir(target_dir):
os.mkdir(target_dir)
for item in dirs:
if os.path.isfile(source_dir + os.sep + item):
im = Image.open(source_dir + os.sep + item)
im_resize = t(im)
im_resize.save(source_dir + os.sep + 'resized' + os.sep + item, 'bmp')
def crop_image(source_dir):
"""Center Crops all images in path."""
t = transforms.CenterCrop((224, 224))
dirs = os.listdir(source_dir)
target_dir = source_dir + os.sep + 'cropped'
if not os.path.isdir(target_dir):
os.mkdir(target_dir)
for item in dirs:
if os.path.isfile(source_dir + os.sep + item):
im = Image.open(source_dir + os.sep + item)
            im_resize = t(im)
im_resize.save(source_dir + os.sep + 'cropped' + os.sep + item, 'bmp')
def mk_dirs(target_dir):
dir_0 = target_dir + r"\fold_0"
dir_1 = target_dir + r"\fold_1"
dir_2 = target_dir + r"\fold_2"
dir_3 = target_dir + r"\phase2"
dir_4 = target_dir + r"\phase3"
dir_0_all = dir_0 + r"\all"
dir_0_hem = dir_0 + r"\hem"
dir_1_all = dir_1 + r"\all"
dir_1_hem = dir_1 + r"\hem"
dir_2_all = dir_2 + r"\all"
dir_2_hem = dir_2 + r"\hem"
if not os.path.isdir(dir_0):
os.mkdir(dir_0)
if not os.path.isdir(dir_1):
os.mkdir(dir_1)
if not os.path.isdir(dir_2):
os.mkdir(dir_2)
if not os.path.isdir(dir_3):
os.mkdir(dir_3)
if not os.path.isdir(dir_4):
os.mkdir(dir_4)
if not os.path.isdir(dir_0_all):
os.mkdir(dir_0_all)
if not os.path.isdir(dir_0_hem):
os.mkdir(dir_0_hem)
if not os.path.isdir(dir_1_all):
os.mkdir(dir_1_all)
if not os.path.isdir(dir_1_hem):
os.mkdir(dir_1_hem)
if not os.path.isdir(dir_2_all):
os.mkdir(dir_2_all)
if not os.path.isdir(dir_2_hem):
os.mkdir(dir_2_hem)
return dir_0_all, dir_0_hem, dir_1_all, dir_1_hem, dir_2_all, dir_2_hem, dir_3, dir_4
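A hedged sketch of driving two of the helpers above (all paths are hypothetical placeholders and must exist beforehand):

```python
from datasets.prepare_data import rotate_images, transform_image

# Save a noisy copy of one image; the 'noise' suffix selects the
# Gaussian-noise branch and writes target_dir/noise.jpg.
transform_image('cells/img001.bmp', 'cells_noisy', transformation=None, suffix='noise')

# Write 4 copies of every image in 'cells', rotated by 0/90/180/270 degrees.
rotate_images(target_dir='cells_rot4', source_dir='cells', rotate=4, theta=90)
```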

eval.ipynb (new file, 4691 lines; diff suppressed because it is too large)

eval.py (new file, 57 lines)

@@ -0,0 +1,57 @@
from common.eval import *
def main():
model.eval()
if P.mode == 'test_acc':
from evals import test_classifier
with torch.no_grad():
error = test_classifier(P, model, test_loader, 0, logger=None)
elif P.mode == 'test_marginalized_acc':
from evals import test_classifier
with torch.no_grad():
error = test_classifier(P, model, test_loader, 0, marginal=True, logger=None)
elif P.mode in ['ood', 'ood_pre']:
if P.mode == 'ood':
from evals import eval_ood_detection
else:
from evals.ood_pre import eval_ood_detection
with torch.no_grad():
auroc_dict = eval_ood_detection(P, model, test_loader, ood_test_loader, P.ood_score,
train_loader=train_loader, simclr_aug=simclr_aug)
if P.one_class_idx is not None:
mean_dict = dict()
for ood_score in P.ood_score:
mean = 0
for ood in auroc_dict.keys():
mean += auroc_dict[ood][ood_score]
mean_dict[ood_score] = mean / len(auroc_dict.keys())
auroc_dict['one_class_mean'] = mean_dict
bests = []
for ood in auroc_dict.keys():
message = ''
best_auroc = 0
for ood_score, auroc in auroc_dict[ood].items():
message += '[%s %s %.4f] ' % (ood, ood_score, auroc)
if auroc > best_auroc:
best_auroc = auroc
message += '[%s %s %.4f] ' % (ood, 'best', best_auroc)
if P.print_score:
print(message)
bests.append(best_auroc)
bests = map('{:.4f}'.format, bests)
print('\t'.join(bests))
else:
raise NotImplementedError()
if __name__ == '__main__':
main()

evals/__init__.py (new file, 1 line)

@@ -0,0 +1 @@
from evals.evals import test_classifier, eval_ood_detection

Six binary files not shown.

evals/evals.py (new file, 201 lines)

@@ -0,0 +1,201 @@
import time
import itertools
import diffdist.functional as distops
import numpy as np
import torch
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F
from sklearn.metrics import roc_auc_score
import models.transform_layers as TL
from utils.temperature_scaling import _ECELoss
from utils.utils import AverageMeter, set_random_seed, normalize
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
ece_criterion = _ECELoss().to(device)
def error_k(output, target, ks=(1,)):
"""Computes the precision@k for the specified values of k"""
max_k = max(ks)
batch_size = target.size(0)
_, pred = output.topk(max_k, 1, True, True)
pred = pred.t()
correct = pred.eq(target.view(1, -1).expand_as(pred))
results = []
for k in ks:
        correct_k = correct[:k].reshape(-1).float().sum(0)  # reshape, not view: the slice may be non-contiguous
results.append(100.0 - correct_k.mul_(100.0 / batch_size))
return results
def test_classifier(P, model, loader, steps, marginal=False, logger=None):
error_top1 = AverageMeter()
error_calibration = AverageMeter()
if logger is None:
log_ = print
else:
log_ = logger.log
# Switch to evaluate mode
mode = model.training
model.eval()
for n, (images, labels) in enumerate(loader):
batch_size = images.size(0)
images, labels = images.to(device), labels.to(device)
if marginal:
outputs = 0
for i in range(4):
rot_images = torch.rot90(images, i, (2, 3))
_, outputs_aux = model(rot_images, joint=True)
outputs += outputs_aux['joint'][:, P.n_classes * i: P.n_classes * (i + 1)] / 4.
else:
outputs = model(images)
top1, = error_k(outputs.data, labels, ks=(1,))
error_top1.update(top1.item(), batch_size)
ece = ece_criterion(outputs, labels) * 100
error_calibration.update(ece.item(), batch_size)
if n % 100 == 0:
log_('[Test %3d] [Test@1 %.3f] [ECE %.3f]' %
(n, error_top1.value, error_calibration.value))
log_(' * [Error@1 %.3f] [ECE %.3f]' %
(error_top1.average, error_calibration.average))
if logger is not None:
logger.scalar_summary('eval/clean_error', error_top1.average, steps)
logger.scalar_summary('eval/ece', error_calibration.average, steps)
model.train(mode)
return error_top1.average
def eval_ood_detection(P, model, id_loader, ood_loaders, ood_scores, train_loader=None, simclr_aug=None):
auroc_dict = dict()
for ood in ood_loaders.keys():
auroc_dict[ood] = dict()
for ood_score in ood_scores:
# compute scores for ID and OOD samples
score_func = get_ood_score_func(P, model, ood_score, simclr_aug=simclr_aug)
save_path = f'plot/score_in_{P.dataset}_{ood_score}'
if P.one_class_idx is not None:
save_path += f'_{P.one_class_idx}'
scores_id = get_scores(id_loader, score_func)
if P.save_score:
np.save(f'{save_path}.npy', scores_id)
for ood, ood_loader in ood_loaders.items():
if ood == 'interp':
scores_ood = get_scores_interp(id_loader, score_func)
auroc_dict['interp'][ood_score] = get_auroc(scores_id, scores_ood)
else:
scores_ood = get_scores(ood_loader, score_func)
auroc_dict[ood][ood_score] = get_auroc(scores_id, scores_ood)
if P.save_score:
np.save(f'{save_path}_out_{ood}.npy', scores_ood)
return auroc_dict
def get_ood_score_func(P, model, ood_score, simclr_aug=None):
def score_func(x):
return compute_ood_score(P, model, ood_score, x, simclr_aug=simclr_aug)
return score_func
def get_scores(loader, score_func):
scores = []
for i, (x, _) in enumerate(loader):
s = score_func(x.to(device))
assert s.dim() == 1 and s.size(0) == x.size(0)
scores.append(s.detach().cpu().numpy())
return np.concatenate(scores)
def get_scores_interp(loader, score_func):
scores = []
for i, (x, _) in enumerate(loader):
x_interp = (x + last) / 2 if i > 0 else x # omit the first batch, assume batch sizes are equal
last = x # save the last batch
s = score_func(x_interp.to(device))
assert s.dim() == 1 and s.size(0) == x.size(0)
scores.append(s.detach().cpu().numpy())
return np.concatenate(scores)
def get_auroc(scores_id, scores_ood):
scores = np.concatenate([scores_id, scores_ood])
labels = np.concatenate([np.ones_like(scores_id), np.zeros_like(scores_ood)])
return roc_auc_score(labels, scores)
def compute_ood_score(P, model, ood_score, x, simclr_aug=None):
model.eval()
if ood_score == 'clean_norm':
_, output_aux = model(x, penultimate=True, simclr=True)
score = output_aux[P.ood_layer].norm(dim=1)
return score
elif ood_score == 'similar':
assert simclr_aug is not None # require custom simclr augmentation
sample_num = 2 # fast evaluation
feats = get_features(model, simclr_aug, x, layer=P.ood_layer, sample_num=sample_num)
feats_avg = sum(feats) / len(feats)
scores = []
for seed in range(sample_num):
sim = torch.cosine_similarity(feats[seed], feats_avg)
scores.append(sim)
return sum(scores) / len(scores)
elif ood_score == 'baseline':
outputs, outputs_aux = model(x, penultimate=True)
scores = F.softmax(outputs, dim=1).max(dim=1)[0]
return scores
elif ood_score == 'baseline_marginalized':
total_outputs = 0
for i in range(4):
x_rot = torch.rot90(x, i, (2, 3))
outputs, outputs_aux = model(x_rot, penultimate=True, joint=True)
total_outputs += outputs_aux['joint'][:, P.n_classes * i:P.n_classes * (i + 1)]
scores = F.softmax(total_outputs / 4., dim=1).max(dim=1)[0]
return scores
else:
raise NotImplementedError()
def get_features(model, simclr_aug, x, layer='simclr', sample_num=1):
model.eval()
feats = []
for seed in range(sample_num):
set_random_seed(seed)
x_t = simclr_aug(x)
with torch.no_grad():
_, output_aux = model(x_t, penultimate=True, simclr=True, shift=True)
feats.append(output_aux[layer])
return feats
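For intuition on the AUROC helper above: `get_auroc` labels in-distribution samples 1 and OOD samples 0, so scores must follow the convention "higher = more in-distribution". A tiny sketch with made-up scores:

```python
import numpy as np
from evals.evals import get_auroc

scores_id = np.array([0.9, 0.8, 0.7])    # hypothetical ID scores
scores_ood = np.array([0.5, 0.3, 0.2])   # hypothetical OOD scores
print(get_auroc(scores_id, scores_ood))  # -> 1.0, perfect separation
```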

evals/ood_pre.py (new file, 242 lines)

@@ -0,0 +1,242 @@
import os
from copy import deepcopy
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import models.transform_layers as TL
from utils.utils import set_random_seed, normalize
from evals.evals import get_auroc
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
hflip = TL.HorizontalFlipLayer().to(device)
def eval_ood_detection(P, model, id_loader, ood_loaders, ood_scores, train_loader=None, simclr_aug=None):
auroc_dict = dict()
for ood in ood_loaders.keys():
auroc_dict[ood] = dict()
assert len(ood_scores) == 1 # assume single ood_score for simplicity
ood_score = ood_scores[0]
base_path = os.path.split(P.load_path)[0] # checkpoint directory
prefix = f'{P.ood_samples}'
if P.resize_fix:
prefix += f'_resize_fix_{P.resize_factor}'
else:
prefix += f'_resize_range_{P.resize_factor}'
prefix = os.path.join(base_path, f'feats_{prefix}')
kwargs = {
'simclr_aug': simclr_aug,
'sample_num': P.ood_samples,
'layers': P.ood_layer,
}
print('Pre-compute global statistics...')
feats_train = get_features(P, f'{P.dataset}_train', model, train_loader, prefix=prefix, **kwargs) # (M, T, d)
P.axis = []
for f in feats_train['simclr'].chunk(P.K_shift, dim=1):
axis = f.mean(dim=1) # (M, d)
P.axis.append(normalize(axis, dim=1).to(device))
print('axis size: ' + ' '.join(map(lambda x: str(len(x)), P.axis)))
f_sim = [f.mean(dim=1) for f in feats_train['simclr'].chunk(P.K_shift, dim=1)] # list of (M, d)
f_shi = [f.mean(dim=1) for f in feats_train['shift'].chunk(P.K_shift, dim=1)] # list of (M, 4)
weight_sim = []
weight_shi = []
for shi in range(P.K_shift):
sim_norm = f_sim[shi].norm(dim=1) # (M)
shi_mean = f_shi[shi][:, shi] # (M)
weight_sim.append(1 / sim_norm.mean().item())
weight_shi.append(1 / shi_mean.mean().item())
if ood_score == 'simclr':
P.weight_sim = [1]
P.weight_shi = [0]
elif ood_score == 'CSI':
P.weight_sim = weight_sim
P.weight_shi = weight_shi
else:
raise ValueError()
print(f'weight_sim:\t' + '\t'.join(map('{:.4f}'.format, P.weight_sim)))
print(f'weight_shi:\t' + '\t'.join(map('{:.4f}'.format, P.weight_shi)))
print('Pre-compute features...')
feats_id = get_features(P, P.dataset, model, id_loader, prefix=prefix, **kwargs) # (N, T, d)
feats_ood = dict()
for ood, ood_loader in ood_loaders.items():
if ood == 'interp':
feats_ood[ood] = get_features(P, ood, model, id_loader, interp=True, prefix=prefix, **kwargs)
else:
feats_ood[ood] = get_features(P, ood, model, ood_loader, prefix=prefix, **kwargs)
print(f'Compute OOD scores... (score: {ood_score})')
scores_id = get_scores(P, feats_id, ood_score).numpy()
scores_ood = dict()
if P.one_class_idx is not None:
one_class_score = []
for ood, feats in feats_ood.items():
scores_ood[ood] = get_scores(P, feats, ood_score).numpy()
auroc_dict[ood][ood_score] = get_auroc(scores_id, scores_ood[ood])
if P.one_class_idx is not None:
one_class_score.append(scores_ood[ood])
if P.one_class_idx is not None:
one_class_score = np.concatenate(one_class_score)
one_class_total = get_auroc(scores_id, one_class_score)
print(f'One_class_real_mean: {one_class_total}')
if P.print_score:
print_score(P.dataset, scores_id)
for ood, scores in scores_ood.items():
print_score(ood, scores)
return auroc_dict
def get_scores(P, feats_dict, ood_score):
# convert to gpu tensor
feats_sim = feats_dict['simclr'].to(device)
feats_shi = feats_dict['shift'].to(device)
N = feats_sim.size(0)
# compute scores
scores = []
for f_sim, f_shi in zip(feats_sim, feats_shi):
f_sim = [f.mean(dim=0, keepdim=True) for f in f_sim.chunk(P.K_shift)] # list of (1, d)
f_shi = [f.mean(dim=0, keepdim=True) for f in f_shi.chunk(P.K_shift)] # list of (1, 4)
score = 0
for shi in range(P.K_shift):
score += (f_sim[shi] * P.axis[shi]).sum(dim=1).max().item() * P.weight_sim[shi]
score += f_shi[shi][:, shi].item() * P.weight_shi[shi]
score = score / P.K_shift
scores.append(score)
scores = torch.tensor(scores)
assert scores.dim() == 1 and scores.size(0) == N # (N)
return scores.cpu()
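
# --- Usage sketch (not part of the original file) ----------------------------
# Minimal illustration of the score computed above: for each shifted view k,
# cosine similarity to the closest training feature (weighted by weight_sim)
# plus the shift classifier's confidence for the true shift k (weighted by
# weight_shi). All shapes and the random inputs are assumptions for
# illustration only; here both weights are taken as 1.
def _demo_csi_score(K=4, M=100, d=128):
    axis = [F.normalize(torch.randn(M, d), dim=1) for _ in range(K)]   # train axes
    f_sim = [F.normalize(torch.randn(1, d), dim=1) for _ in range(K)]  # simclr feats
    f_shi = [torch.randn(1, K) for _ in range(K)]                      # shift logits
    score = 0.
    for k in range(K):
        score += (f_sim[k] * axis[k]).sum(dim=1).max().item()  # weight_sim[k] = 1
        score += f_shi[k][:, k].item()                         # weight_shi[k] = 1
    return score / K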
def get_features(P, data_name, model, loader, interp=False, prefix='',
simclr_aug=None, sample_num=1, layers=('simclr', 'shift')):
if not isinstance(layers, (list, tuple)):
layers = [layers]
    # Loading of cached features is disabled here; features are always
    # re-computed below and saved under the `prefix`-based paths.
    feats_dict = dict()
    # for layer in layers:
    #     path = prefix + f'_{data_name}_{layer}.pth'
    #     if os.path.exists(path):
    #         feats_dict[layer] = torch.load(path)
# pre-compute features and save to the path
left = [layer for layer in layers if layer not in feats_dict.keys()]
if len(left) > 0:
_feats_dict = _get_features(P, model, loader, interp, (P.dataset == 'imagenet' or
P.dataset == 'CNMC' or
P.dataset == 'CNMC_grayscale'),
simclr_aug, sample_num, layers=left)
for layer, feats in _feats_dict.items():
path = prefix + f'_{data_name}_{layer}.pth'
torch.save(_feats_dict[layer], path)
feats_dict[layer] = feats # update value
return feats_dict
def _get_features(P, model, loader, interp=False, imagenet=False, simclr_aug=None,
sample_num=1, layers=('simclr', 'shift')):
if not isinstance(layers, (list, tuple)):
layers = [layers]
# check if arguments are valid
assert simclr_aug is not None
if imagenet is True: # assume batch_size = 1 for ImageNet
sample_num = 1
# compute features in full dataset
model.eval()
feats_all = {layer: [] for layer in layers} # initialize: empty list
for i, (x, _) in enumerate(loader):
if interp:
x_interp = (x + last) / 2 if i > 0 else x # omit the first batch, assume batch sizes are equal
last = x # save the last batch
x = x_interp # use interp as current batch
if imagenet is True:
x = torch.cat(x[0], dim=0) # augmented list of x
x = x.to(device) # gpu tensor
# compute features in one batch
feats_batch = {layer: [] for layer in layers} # initialize: empty list
for seed in range(sample_num):
set_random_seed(seed)
if P.K_shift > 1:
x_t = torch.cat([P.shift_trans(hflip(x), k) for k in range(P.K_shift)])
else:
x_t = x # No shifting: SimCLR
x_t = simclr_aug(x_t)
# compute augmented features
with torch.no_grad():
kwargs = {layer: True for layer in layers} # only forward selected layers
_, output_aux = model(x_t, **kwargs)
# add features in one batch
for layer in layers:
feats = output_aux[layer].cpu()
if imagenet is False:
feats_batch[layer] += feats.chunk(P.K_shift)
else:
feats_batch[layer] += [feats] # (B, d) cpu tensor
# concatenate features in one batch
for key, val in feats_batch.items():
if imagenet:
feats_batch[key] = torch.stack(val, dim=0) # (B, T, d)
else:
feats_batch[key] = torch.stack(val, dim=1) # (B, T, d)
# add features in full dataset
for layer in layers:
feats_all[layer] += [feats_batch[layer]]
# concatenate features in full dataset
for key, val in feats_all.items():
feats_all[key] = torch.cat(val, dim=0) # (N, T, d)
# reshape order
if imagenet is False:
# Convert [1,2,3,4, 1,2,3,4] -> [1,1, 2,2, 3,3, 4,4]
for key, val in feats_all.items():
N, T, d = val.size() # T = K * T'
val = val.view(N, -1, P.K_shift, d) # (N, T', K, d)
            val = val.transpose(2, 1)  # (N, K, T', d)
val = val.reshape(N, T, d) # (N, T, d)
feats_all[key] = val
return feats_all
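
# --- Usage sketch (not part of the original file) ----------------------------
# The reshape at the end of _get_features interleaves per-shift blocks. For
# K_shift = 2 and T' = 2 augmentation samples, the T axis is reordered from
# [shift0@t0, shift1@t0, shift0@t1, shift1@t1] to
# [shift0@t0, shift0@t1, shift1@t0, shift1@t1]:
def _demo_reshape_order():
    val = torch.tensor([[[0.], [1.], [2.], [3.]]])  # (N=1, T=4, d=1), T = K * T'
    N, T, d = val.size()
    val = val.view(N, -1, 2, d).transpose(2, 1).reshape(N, T, d)
    return val.squeeze().tolist()  # [0.0, 2.0, 1.0, 3.0]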
def print_score(data_name, scores):
quantile = np.quantile(scores, np.arange(0, 1.1, 0.1))
print('{:18s} '.format(data_name) +
'{:.4f} +- {:.4f} '.format(np.mean(scores), np.std(scores)) +
' '.join(['q{:d}: {:.4f}'.format(i * 10, quantile[i]) for i in range(11)]))

BIN  figures/CSI_teaser.png  Normal file (binary image, 400 KiB; not shown)
BIN  (two further binary image files, 3.3 MiB and 2.3 MiB; not shown)

37  main.py  Normal file
@ -0,0 +1,37 @@
from sys import argv
from os import system
from datasets.prepare_data import prep, resize
import torch
import os
from datasets.postprocess_data import postprocess_data
# NOTE: DATA_BASE_DIR points at main.py itself; a data directory was presumably intended.
DATA_BASE_DIR = r'/home/feoktistovar67431/CSI/CSI_local/main.py'
BASE_DIR = '/home/feoktistovar67431/CSI/CSI_local/'
def main():
    for argument in argv:
        if argument == '--proc_step':
            proc_step = argv[argv.index(argument) + 1]
            if proc_step == 'eval':
                system("eval.py " + ' '.join(argv[1:]))
            elif proc_step == 'train':
                # NOTE: the original code dispatches the 'train' step to eval.py as well
                system(BASE_DIR + os.sep + "eval.py " + ' '.join(argv[1:]))
            elif proc_step == 'plot':
                plot_data()  # NOTE: plot_data is neither defined nor imported in this module
            elif proc_step == 'post_proc':
                postprocess_data(
                    [
                        r'\CNMC_resnet18_unsup_simclr_CSI_shift_cutperm4_one_class_0\log.txt',
                        r'\CNMC_resnet18_unsup_simclr_CSI_shift_cutperm4_one_class_0_64px\log.txt',
                        r'\CNMC_resnet18_unsup_simclr_CSI_shift_cutperm16_one_class_0_32px\log.txt',
                        r'\CNMC_resnet18_unsup_simclr_CSI_shift_cutperm_one_class_0_64px_batch64\log.txt',
                        r'\CNMC_resnet18_unsup_simclr_CSI_shift_rotation_one_class_0\log.txt',
                        r'\CNMC_resnet18_unsup_simclr_CSI_shift_gauss_one_class_0_32px\log.txt'
                        # r'\cifar10_resnet18_unsup_simclr_CSI_shift_rotation_one_class_1\log.txt'
                    ]
                )

if __name__ == '__main__':
    main()

0  models/__init__.py  Normal file (empty)
BIN  (twelve binary files, not shown)

48  models/base_model.py  Normal file
@ -0,0 +1,48 @@
from abc import *
import torch.nn as nn
class BaseModel(nn.Module, metaclass=ABCMeta):
def __init__(self, last_dim, num_classes=10, simclr_dim=128):
super(BaseModel, self).__init__()
self.linear = nn.Linear(last_dim, num_classes)
self.simclr_layer = nn.Sequential(
nn.Linear(last_dim, last_dim),
nn.ReLU(),
nn.Linear(last_dim, simclr_dim),
)
self.shift_cls_layer = nn.Linear(last_dim, 2)
self.joint_distribution_layer = nn.Linear(last_dim, 4 * num_classes)
@abstractmethod
def penultimate(self, inputs, all_features=False):
pass
def forward(self, inputs, penultimate=False, simclr=False, shift=False, joint=False):
_aux = {}
_return_aux = False
features = self.penultimate(inputs)
output = self.linear(features)
if penultimate:
_return_aux = True
_aux['penultimate'] = features
if simclr:
_return_aux = True
_aux['simclr'] = self.simclr_layer(features)
if shift:
_return_aux = True
_aux['shift'] = self.shift_cls_layer(features)
if joint:
_return_aux = True
_aux['joint'] = self.joint_distribution_layer(features)
if _return_aux:
return output, _aux
return output
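
# --- Usage sketch (not part of the original file) ----------------------------
# BaseModel.forward returns (class logits, aux dict) once any auxiliary head is
# requested. A toy subclass for CIFAR-sized inputs; all sizes are assumptions.
import torch

class _ToyModel(BaseModel):
    def __init__(self):
        super(_ToyModel, self).__init__(last_dim=512, num_classes=10)
        self.backbone = nn.Sequential(nn.Flatten(), nn.Linear(3 * 32 * 32, 512))

    def penultimate(self, inputs, all_features=False):
        return self.backbone(inputs)

def _demo_base_model():
    logits, aux = _ToyModel()(torch.randn(4, 3, 32, 32), simclr=True, shift=True)
    return logits.shape, aux['simclr'].shape, aux['shift'].shape  # (4,10), (4,128), (4,2)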

135  models/classifier.py  Normal file
@ -0,0 +1,135 @@
import torch.nn as nn
from models.resnet import ResNet18, ResNet34, ResNet50
from models.resnet_imagenet import resnet18, resnet50
import models.transform_layers as TL
from torchvision import transforms
def get_simclr_augmentation(P, image_size):
    """
    Builds the SimCLR augmentation pipeline used to generate positive pairs.
    :param P: parsed arguments
    :param image_size: image size tuple; only the first two entries (height, width) are used
    :return: an nn.Sequential of augmentation layers
    """
# parameter for resizecrop
resize_scale = (P.resize_factor, 1.0) # resize scaling factor
if P.resize_fix: # if resize_fix is True, use same scale
resize_scale = (P.resize_factor, P.resize_factor)
# Align augmentation
s = P.color_distort
color_jitter = TL.ColorJitterLayer(brightness=s*0.8, contrast=s*0.8, saturation=s*0.8, hue=s*0.2, p=0.8)
color_gray = TL.RandomColorGrayLayer(p=0.2)
resize_crop = TL.RandomResizedCropLayer(scale=resize_scale, size=(image_size[0], image_size[1]))
#v_flip = transforms.RandomVerticalFlip()
#h_flip = transforms.RandomHorizontalFlip()
rand_aff = transforms.RandomAffine(degrees=360, translate=(0.2, 0.2))
# Transform define #
if P.dataset == 'imagenet': # Using RandomResizedCrop at PIL transform
transform = nn.Sequential(
color_jitter,
color_gray,
)
elif P.dataset == 'CNMC':
transform = nn.Sequential(
color_jitter,
color_gray,
resize_crop,
)
else:
transform = nn.Sequential(
color_jitter,
color_gray,
resize_crop,
)
return transform
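
# --- Usage sketch (not part of the original file) ----------------------------
# get_simclr_augmentation only reads a handful of attributes from P, so a tiny
# namespace is enough for experimentation; all attribute values are assumptions.
def _demo_simclr_aug():
    import torch
    from types import SimpleNamespace
    P_demo = SimpleNamespace(resize_factor=0.08, resize_fix=False,
                             color_distort=0.5, dataset='CNMC')
    aug = get_simclr_augmentation(P_demo, image_size=(64, 64, 3))
    x = torch.rand(8, 3, 64, 64)
    return aug(x)  # one randomly augmented view; call twice for a positive pair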
def get_shift_module(P, eval=False):
"""
Creates shift transformation (negative).
:param P: parsed arguments
:param eval: whether it is an evaluation step or not
:return: transformation
"""
if P.shift_trans_type == 'rotation':
shift_transform = TL.Rotation()
K_shift = 4
elif P.shift_trans_type == 'cutperm':
shift_transform = TL.CutPerm()
K_shift = 4
elif P.shift_trans_type == 'noise':
shift_transform = TL.GaussNoise(mean=P.noise_mean, std=P.noise_std)
K_shift = 4
elif P.shift_trans_type == 'randpers':
shift_transform = TL.RandPers(distortion_scale=P.distortion_scale, p=1)
K_shift = 4
elif P.shift_trans_type == 'sharp':
shift_transform = TL.RandomAdjustSharpness(sharpness_factor=P.sharpness_factor, p=1)
K_shift = 4
elif P.shift_trans_type == 'blur':
kernel_size = int(int(P.res.replace('px', ''))*0.1)
if kernel_size%2 == 0:
kernel_size+=1
sigma = (0.1, float(P.blur_sigma))
shift_transform = TL.GaussBlur(kernel_size=kernel_size, sigma=sigma)
K_shift = 4
elif P.shift_trans_type == 'blur_randpers':
        kernel_size = int(int(P.res.replace('px', '')) * 0.1) // 2 * 2 + 1  # force an odd integer kernel
sigma = (0.1, float(P.blur_sigma))
shift_transform = TL.BlurRandpers(kernel_size=kernel_size, sigma=sigma, distortion_scale=P.distortion_scale, p=1)
K_shift = 4
elif P.shift_trans_type == 'blur_sharp':
        kernel_size = int(int(P.res.replace('px', '')) * 0.1) // 2 * 2 + 1  # force an odd integer kernel
sigma = (0.1, float(P.blur_sigma))
shift_transform = TL.BlurSharpness(kernel_size=kernel_size, sigma=sigma, sharpness_factor=P.sharpness_factor, p=1)
K_shift = 4
elif P.shift_trans_type == 'randpers_sharp':
shift_transform = TL.RandpersSharpness(distortion_scale=P.distortion_scale, p=1, sharpness_factor=P.sharpness_factor)
K_shift = 4
elif P.shift_trans_type == 'blur_randpers_sharp':
        kernel_size = int(int(P.res.replace('px', '')) * 0.1) // 2 * 2 + 1  # force an odd integer kernel
sigma = (0.1, float(P.blur_sigma))
shift_transform = TL.BlurRandpersSharpness(kernel_size=kernel_size, sigma=sigma, distortion_scale=P.distortion_scale, p=1, sharpness_factor=P.sharpness_factor)
K_shift = 4
else:
shift_transform = nn.Identity()
K_shift = 1
if not eval and not ('sup' in P.mode):
assert P.batch_size == int(128/K_shift)
return shift_transform, K_shift
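
# --- Usage sketch (not part of the original file) ----------------------------
# With K_shift = 4 rotations, a batch of B images is expanded to 4B shifted
# images ordered [rot0 | rot90 | rot180 | rot270], mirroring the training code:
def _demo_shift_batch():
    import torch
    shift_transform, K_shift = TL.Rotation(), 4
    x = torch.rand(2, 3, 32, 32)
    return torch.cat([shift_transform(x, k) for k in range(K_shift)])  # (8, 3, 32, 32)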
def get_shift_classifer(model, K_shift):
model.shift_cls_layer = nn.Linear(model.last_dim, K_shift)
return model
def get_classifier(mode, n_classes=10):
if mode == 'resnet18':
classifier = ResNet18(num_classes=n_classes)
elif mode == 'resnet34':
classifier = ResNet34(num_classes=n_classes)
elif mode == 'resnet50':
classifier = ResNet50(num_classes=n_classes)
elif mode == 'resnet18_imagenet':
classifier = resnet18(num_classes=n_classes)
elif mode == 'resnet50_imagenet':
classifier = resnet50(num_classes=n_classes)
else:
raise NotImplementedError()
return classifier

189  models/resnet.py  Normal file
@ -0,0 +1,189 @@
'''ResNet in PyTorch.
The BasicBlock and Bottleneck modules follow the original ResNet paper:
[1] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun
    Deep Residual Learning for Image Recognition. arXiv:1512.03385
The PreActBlock and PreActBottleneck modules follow the later paper:
[2] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun
    Identity Mappings in Deep Residual Networks. arXiv:1603.05027
'''
import torch
import torch.nn as nn
import torch.nn.functional as F
from models.base_model import BaseModel
from models.transform_layers import NormalizeLayer
from torch.nn.utils import spectral_norm
def conv3x3(in_planes, out_planes, stride=1):
return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False)
class BasicBlock(nn.Module):
expansion = 1
def __init__(self, in_planes, planes, stride=1):
super(BasicBlock, self).__init__()
self.conv1 = conv3x3(in_planes, planes, stride)
self.conv2 = conv3x3(planes, planes)
self.bn1 = nn.BatchNorm2d(planes)
self.bn2 = nn.BatchNorm2d(planes)
self.shortcut = nn.Sequential()
if stride != 1 or in_planes != self.expansion*planes:
self.shortcut = nn.Sequential(
nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False),
nn.BatchNorm2d(self.expansion*planes)
)
def forward(self, x):
out = F.relu(self.bn1(self.conv1(x)))
out = self.bn2(self.conv2(out))
out += self.shortcut(x)
out = F.relu(out)
return out
class PreActBlock(nn.Module):
'''Pre-activation version of the BasicBlock.'''
expansion = 1
def __init__(self, in_planes, planes, stride=1):
super(PreActBlock, self).__init__()
self.conv1 = conv3x3(in_planes, planes, stride)
self.conv2 = conv3x3(planes, planes)
self.bn1 = nn.BatchNorm2d(in_planes)
self.bn2 = nn.BatchNorm2d(planes)
self.shortcut = nn.Sequential()
if stride != 1 or in_planes != self.expansion*planes:
self.shortcut = nn.Sequential(
nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False)
)
def forward(self, x):
out = F.relu(self.bn1(x))
shortcut = self.shortcut(out)
out = self.conv1(out)
out = self.conv2(F.relu(self.bn2(out)))
out += shortcut
return out
class Bottleneck(nn.Module):
expansion = 4
def __init__(self, in_planes, planes, stride=1):
super(Bottleneck, self).__init__()
self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False)
self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
self.conv3 = nn.Conv2d(planes, self.expansion*planes, kernel_size=1, bias=False)
self.bn1 = nn.BatchNorm2d(planes)
self.bn2 = nn.BatchNorm2d(planes)
self.bn3 = nn.BatchNorm2d(self.expansion * planes)
self.shortcut = nn.Sequential()
if stride != 1 or in_planes != self.expansion*planes:
self.shortcut = nn.Sequential(
nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False),
nn.BatchNorm2d(self.expansion*planes)
)
def forward(self, x):
out = F.relu(self.bn1(self.conv1(x)))
out = F.relu(self.bn2(self.conv2(out)))
out = self.bn3(self.conv3(out))
out += self.shortcut(x)
out = F.relu(out)
return out
class PreActBottleneck(nn.Module):
'''Pre-activation version of the original Bottleneck module.'''
expansion = 4
def __init__(self, in_planes, planes, stride=1):
super(PreActBottleneck, self).__init__()
self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False)
self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
self.conv3 = nn.Conv2d(planes, self.expansion*planes, kernel_size=1, bias=False)
self.bn1 = nn.BatchNorm2d(in_planes)
self.bn2 = nn.BatchNorm2d(planes)
self.bn3 = nn.BatchNorm2d(planes)
self.shortcut = nn.Sequential()
if stride != 1 or in_planes != self.expansion*planes:
self.shortcut = nn.Sequential(
nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False)
)
def forward(self, x):
out = F.relu(self.bn1(x))
shortcut = self.shortcut(out)
out = self.conv1(out)
out = self.conv2(F.relu(self.bn2(out)))
out = self.conv3(F.relu(self.bn3(out)))
out += shortcut
return out
class ResNet(BaseModel):
def __init__(self, block, num_blocks, num_classes=10):
last_dim = 512 * block.expansion
super(ResNet, self).__init__(last_dim, num_classes)
self.in_planes = 64
self.last_dim = last_dim
self.normalize = NormalizeLayer()
self.conv1 = conv3x3(3, 64)
self.bn1 = nn.BatchNorm2d(64)
self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
def _make_layer(self, block, planes, num_blocks, stride):
strides = [stride] + [1]*(num_blocks-1)
layers = []
for stride in strides:
layers.append(block(self.in_planes, planes, stride))
self.in_planes = planes * block.expansion
return nn.Sequential(*layers)
def penultimate(self, x, all_features=False):
out_list = []
out = self.normalize(x)
out = self.conv1(out)
out = self.bn1(out)
out = F.relu(out)
out_list.append(out)
out = self.layer1(out)
out_list.append(out)
out = self.layer2(out)
out_list.append(out)
out = self.layer3(out)
out_list.append(out)
out = self.layer4(out)
out_list.append(out)
out = F.avg_pool2d(out, 4)
out = out.view(out.size(0), -1)
if all_features:
return out, out_list
else:
return out
def ResNet18(num_classes):
return ResNet(BasicBlock, [2,2,2,2], num_classes=num_classes)
def ResNet34(num_classes):
return ResNet(BasicBlock, [3,4,6,3], num_classes=num_classes)
def ResNet50(num_classes):
return ResNet(Bottleneck, [3,4,6,3], num_classes=num_classes)
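
# --- Usage sketch (not part of the original file) ----------------------------
# All heads are attached by BaseModel, so a single forward pass can return the
# class logits together with any requested auxiliary output:
def _demo_resnet():
    net = ResNet18(num_classes=10)
    logits, aux = net(torch.randn(2, 3, 32, 32), penultimate=True, simclr=True)
    return logits.shape, aux['penultimate'].shape, aux['simclr'].shape  # (2,10), (2,512), (2,128)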

231  models/resnet_imagenet.py  Normal file
@ -0,0 +1,231 @@
import torch
import torch.nn as nn
from models.base_model import BaseModel
from models.transform_layers import NormalizeLayer
def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
"""3x3 convolution with padding"""
return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
padding=dilation, groups=groups, bias=False, dilation=dilation)
def conv1x1(in_planes, out_planes, stride=1):
"""1x1 convolution"""
return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
class BasicBlock(nn.Module):
expansion = 1
def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
base_width=64, dilation=1, norm_layer=None):
super(BasicBlock, self).__init__()
if norm_layer is None:
norm_layer = nn.BatchNorm2d
if groups != 1 or base_width != 64:
raise ValueError('BasicBlock only supports groups=1 and base_width=64')
if dilation > 1:
raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
# Both self.conv1 and self.downsample layers downsample the input when stride != 1
self.conv1 = conv3x3(inplanes, planes, stride)
self.bn1 = norm_layer(planes)
self.relu = nn.ReLU(inplace=True)
self.conv2 = conv3x3(planes, planes)
self.bn2 = norm_layer(planes)
self.downsample = downsample
self.stride = stride
def forward(self, x):
identity = x
out = self.conv1(x)
out = self.bn1(out)
out = self.relu(out)
out = self.conv2(out)
out = self.bn2(out)
if self.downsample is not None:
identity = self.downsample(x)
out += identity
out = self.relu(out)
return out
class Bottleneck(nn.Module):
# Bottleneck in torchvision places the stride for downsampling at 3x3 convolution(self.conv2)
# while original implementation places the stride at the first 1x1 convolution(self.conv1)
# according to "Deep residual learning for image recognition"https://arxiv.org/abs/1512.03385.
# This variant is also known as ResNet V1.5 and improves accuracy according to
# https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch.
expansion = 4
def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
base_width=64, dilation=1, norm_layer=None):
super(Bottleneck, self).__init__()
if norm_layer is None:
norm_layer = nn.BatchNorm2d
width = int(planes * (base_width / 64.)) * groups
# Both self.conv2 and self.downsample layers downsample the input when stride != 1
self.conv1 = conv1x1(inplanes, width)
self.bn1 = norm_layer(width)
self.conv2 = conv3x3(width, width, stride, groups, dilation)
self.bn2 = norm_layer(width)
self.conv3 = conv1x1(width, planes * self.expansion)
self.bn3 = norm_layer(planes * self.expansion)
self.relu = nn.ReLU(inplace=True)
self.downsample = downsample
self.stride = stride
def forward(self, x):
identity = x
out = self.conv1(x)
out = self.bn1(out)
out = self.relu(out)
out = self.conv2(out)
out = self.bn2(out)
out = self.relu(out)
out = self.conv3(out)
out = self.bn3(out)
if self.downsample is not None:
identity = self.downsample(x)
out += identity
out = self.relu(out)
return out
class ResNet(BaseModel):
def __init__(self, block, layers, num_classes=10,
zero_init_residual=False, groups=1, width_per_group=64, replace_stride_with_dilation=None,
norm_layer=None):
last_dim = 512 * block.expansion
super(ResNet, self).__init__(last_dim, num_classes)
if norm_layer is None:
norm_layer = nn.BatchNorm2d
self._norm_layer = norm_layer
self.inplanes = 64
self.dilation = 1
if replace_stride_with_dilation is None:
# each element in the tuple indicates if we should replace
# the 2x2 stride with a dilated convolution instead
replace_stride_with_dilation = [False, False, False]
if len(replace_stride_with_dilation) != 3:
raise ValueError("replace_stride_with_dilation should be None "
"or a 3-element tuple, got {}".format(replace_stride_with_dilation))
self.groups = groups
self.base_width = width_per_group
self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3,
bias=False)
self.bn1 = norm_layer(self.inplanes)
self.relu = nn.ReLU(inplace=True)
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
self.layer1 = self._make_layer(block, 64, layers[0])
self.layer2 = self._make_layer(block, 128, layers[1], stride=2,
dilate=replace_stride_with_dilation[0])
self.layer3 = self._make_layer(block, 256, layers[2], stride=2,
dilate=replace_stride_with_dilation[1])
self.layer4 = self._make_layer(block, 512, layers[3], stride=2,
dilate=replace_stride_with_dilation[2])
self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
self.normalize = NormalizeLayer()
self.last_dim = 512 * block.expansion
for m in self.modules():
if isinstance(m, nn.Conv2d):
nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
nn.init.constant_(m.weight, 1)
nn.init.constant_(m.bias, 0)
# Zero-initialize the last BN in each residual branch,
# so that the residual branch starts with zeros, and each residual block behaves like an identity.
# This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
if zero_init_residual:
for m in self.modules():
if isinstance(m, Bottleneck):
nn.init.constant_(m.bn3.weight, 0)
elif isinstance(m, BasicBlock):
nn.init.constant_(m.bn2.weight, 0)
def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
norm_layer = self._norm_layer
downsample = None
previous_dilation = self.dilation
if dilate:
self.dilation *= stride
stride = 1
if stride != 1 or self.inplanes != planes * block.expansion:
downsample = nn.Sequential(
conv1x1(self.inplanes, planes * block.expansion, stride),
norm_layer(planes * block.expansion),
)
layers = []
layers.append(block(self.inplanes, planes, stride, downsample, self.groups,
self.base_width, previous_dilation, norm_layer))
self.inplanes = planes * block.expansion
for _ in range(1, blocks):
layers.append(block(self.inplanes, planes, groups=self.groups,
base_width=self.base_width, dilation=self.dilation,
norm_layer=norm_layer))
return nn.Sequential(*layers)
def penultimate(self, x, all_features=False):
# See note [TorchScript super()]
out_list = []
x = self.normalize(x)
x = self.conv1(x)
x = self.bn1(x)
x = self.relu(x)
x = self.maxpool(x)
out_list.append(x)
x = self.layer1(x)
out_list.append(x)
x = self.layer2(x)
out_list.append(x)
x = self.layer3(x)
out_list.append(x)
x = self.layer4(x)
out_list.append(x)
x = self.avgpool(x)
x = torch.flatten(x, 1)
if all_features:
return x, out_list
else:
return x
def _resnet(arch, block, layers, **kwargs):
model = ResNet(block, layers, **kwargs)
return model
def resnet18(**kwargs):
r"""ResNet-18 model from
`"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_
"""
return _resnet('resnet18', BasicBlock, [2, 2, 2, 2], **kwargs)
def resnet50(**kwargs):
r"""ResNet-50 model from
`"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_
"""
return _resnet('resnet50', Bottleneck, [3, 4, 6, 3], **kwargs)

643  models/transform_layers.py  Normal file
@ -0,0 +1,643 @@
import math
import numbers
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Function
from torchvision import transforms
if torch.__version__ >= '1.4.0':
kwargs = {'align_corners': False}
else:
kwargs = {}
def rgb2hsv(rgb):
"""Convert a 4-d RGB tensor to the HSV counterpart.
Here, we compute hue using atan2() based on the definition in [1],
instead of using the common lookup table approach as in [2, 3].
Those values agree when the angle is a multiple of 30°,
otherwise they may differ at most ~1.2°.
References
[1] https://en.wikipedia.org/wiki/Hue
[2] https://www.rapidtables.com/convert/color/rgb-to-hsv.html
[3] https://github.com/scikit-image/scikit-image/blob/master/skimage/color/colorconv.py#L212
"""
r, g, b = rgb[:, 0, :, :], rgb[:, 1, :, :], rgb[:, 2, :, :]
Cmax = rgb.max(1)[0]
Cmin = rgb.min(1)[0]
delta = Cmax - Cmin
hue = torch.atan2(math.sqrt(3) * (g - b), 2 * r - g - b)
hue = (hue % (2 * math.pi)) / (2 * math.pi)
saturate = delta / Cmax
value = Cmax
hsv = torch.stack([hue, saturate, value], dim=1)
hsv[~torch.isfinite(hsv)] = 0.
return hsv
def hsv2rgb(hsv):
"""Convert a 4-d HSV tensor to the RGB counterpart.
>>> %timeit hsv2rgb(hsv)
2.37 ms ± 13.4 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
>>> %timeit rgb2hsv_fast(rgb)
298 µs ± 542 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each)
>>> torch.allclose(hsv2rgb(hsv), hsv2rgb_fast(hsv), atol=1e-6)
True
References
[1] https://en.wikipedia.org/wiki/HSL_and_HSV#HSV_to_RGB_alternative
"""
h, s, v = hsv[:, [0]], hsv[:, [1]], hsv[:, [2]]
c = v * s
n = hsv.new_tensor([5, 3, 1]).view(3, 1, 1)
k = (n + h * 6) % 6
t = torch.min(k, 4 - k)
t = torch.clamp(t, 0, 1)
return v - c * t
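
# --- Usage sketch (not part of the original file) ----------------------------
# rgb2hsv/hsv2rgb are approximate inverses; given the <=~1.2 degree hue
# discrepancy noted above, the round-trip error stays small:
def _demo_hsv_roundtrip():
    rgb = torch.rand(4, 3, 16, 16)
    err = (hsv2rgb(rgb2hsv(rgb)) - rgb).abs().max().item()
    return err  # expected to stay below ~0.05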
class RandomResizedCropLayer(nn.Module):
def __init__(self, size=None, scale=(0.08, 1.0), ratio=(3. / 4., 4. / 3.)):
        '''
        Inception-style random resized crop.
        size (tuple): size of the forwarded image (H, W)
        scale (tuple): range of the area fraction of the original image that is cropped
        ratio (tuple): range of aspect ratios of the crop
        '''
super(RandomResizedCropLayer, self).__init__()
_eye = torch.eye(2, 3)
self.size = size
self.register_buffer('_eye', _eye)
self.scale = scale
self.ratio = ratio
def forward(self, inputs, whbias=None):
_device = inputs.device
N = inputs.size(0)
_theta = self._eye.repeat(N, 1, 1)
if whbias is None:
whbias = self._sample_latent(inputs)
_theta[:, 0, 0] = whbias[:, 0]
_theta[:, 1, 1] = whbias[:, 1]
_theta[:, 0, 2] = whbias[:, 2]
_theta[:, 1, 2] = whbias[:, 3]
grid = F.affine_grid(_theta, inputs.size(), **kwargs).to(_device)
output = F.grid_sample(inputs, grid, padding_mode='reflection', **kwargs)
if self.size is not None:
output = F.adaptive_avg_pool2d(output, self.size)
# output = F.adaptive_avg_pool2d(output, self.size)
# output = F.adaptive_avg_pool2d(output, (self.size[0], self.size[1]))
return output
def _clamp(self, whbias):
w = whbias[:, 0]
h = whbias[:, 1]
w_bias = whbias[:, 2]
h_bias = whbias[:, 3]
# Clamp with scale
w = torch.clamp(w, *self.scale)
h = torch.clamp(h, *self.scale)
# Clamp with ratio
w = self.ratio[0] * h + torch.relu(w - self.ratio[0] * h)
w = self.ratio[1] * h - torch.relu(self.ratio[1] * h - w)
# Clamp with bias range: w_bias \in (w - 1, 1 - w), h_bias \in (h - 1, 1 - h)
w_bias = w - 1 + torch.relu(w_bias - w + 1)
w_bias = 1 - w - torch.relu(1 - w - w_bias)
h_bias = h - 1 + torch.relu(h_bias - h + 1)
h_bias = 1 - h - torch.relu(1 - h - h_bias)
whbias = torch.stack([w, h, w_bias, h_bias], dim=0).t()
return whbias
def _sample_latent(self, inputs):
_device = inputs.device
N, _, width, height = inputs.shape
# N * 10 trial
area = width * height
target_area = np.random.uniform(*self.scale, N * 10) * area
log_ratio = (math.log(self.ratio[0]), math.log(self.ratio[1]))
aspect_ratio = np.exp(np.random.uniform(*log_ratio, N * 10))
        # If the ratio condition is not satisfied, fall back to a central (full-image) crop
w = np.round(np.sqrt(target_area * aspect_ratio))
h = np.round(np.sqrt(target_area / aspect_ratio))
cond = (0 < w) * (w <= width) * (0 < h) * (h <= height)
w = w[cond]
h = h[cond]
cond_len = w.shape[0]
if cond_len >= N:
w = w[:N]
h = h[:N]
else:
w = np.concatenate([w, np.ones(N - cond_len) * width])
h = np.concatenate([h, np.ones(N - cond_len) * height])
w_bias = np.random.randint(w - width, width - w + 1) / width
h_bias = np.random.randint(h - height, height - h + 1) / height
w = w / width
h = h / height
whbias = np.column_stack([w, h, w_bias, h_bias])
whbias = torch.tensor(whbias, device=_device)
return whbias
class HorizontalFlipRandomCrop(nn.Module):
def __init__(self, max_range):
super(HorizontalFlipRandomCrop, self).__init__()
self.max_range = max_range
_eye = torch.eye(2, 3)
self.register_buffer('_eye', _eye)
def forward(self, input, sign=None, bias=None, rotation=None):
_device = input.device
N = input.size(0)
_theta = self._eye.repeat(N, 1, 1)
if sign is None:
sign = torch.bernoulli(torch.ones(N, device=_device) * 0.5) * 2 - 1
if bias is None:
bias = torch.empty((N, 2), device=_device).uniform_(-self.max_range, self.max_range)
_theta[:, 0, 0] = sign
_theta[:, :, 2] = bias
if rotation is not None:
_theta[:, 0:2, 0:2] = rotation
grid = F.affine_grid(_theta, input.size(), **kwargs).to(_device)
output = F.grid_sample(input, grid, padding_mode='reflection', **kwargs)
return output
def _sample_latent(self, N, device=None):
sign = torch.bernoulli(torch.ones(N, device=device) * 0.5) * 2 - 1
bias = torch.empty((N, 2), device=device).uniform_(-self.max_range, self.max_range)
return sign, bias
class Rotation(nn.Module):
def __init__(self, max_range = 4):
super(Rotation, self).__init__()
self.max_range = max_range
self.prob = 0.5
def forward(self, input, aug_index=None):
_device = input.device
_, _, H, W = input.size()
if aug_index is None:
aug_index = np.random.randint(4)
output = torch.rot90(input, aug_index, (2, 3))
_prob = input.new_full((input.size(0),), self.prob)
_mask = torch.bernoulli(_prob).view(-1, 1, 1, 1)
output = _mask * input + (1-_mask) * output
else:
aug_index = aug_index % self.max_range
output = torch.rot90(input, aug_index, (2, 3))
return output
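
# --- Usage sketch (not part of the original file) ----------------------------
# With an explicit aug_index, Rotation is deterministic (k * 90 degrees); with
# aug_index=None it rotates by a random multiple of 90 degrees and keeps each
# image unchanged with probability self.prob:
def _demo_rotation():
    rot = Rotation()
    x = torch.rand(2, 3, 32, 32)
    assert torch.allclose(rot(x, aug_index=1), torch.rot90(x, 1, (2, 3)))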
class RandomAdjustSharpness(nn.Module):
def __init__(self, sharpness_factor=0.5, p=0.5):
super(RandomAdjustSharpness, self).__init__()
self.sharpness_factor = sharpness_factor
self.prob = p
def forward(self, input, aug_index=None):
_device = input.device
_, _, H, W = input.size()
if aug_index == 0:
output = input
else:
output = transforms.RandomAdjustSharpness(sharpness_factor=self.sharpness_factor, p=self.prob)(input)
return output
class RandPers(nn.Module):
def __init__(self, distortion_scale=0.5, p=0.5):
super(RandPers, self).__init__()
self.distortion_scale = distortion_scale
self.prob = p
def forward(self, input, aug_index=None):
_device = input.device
_, _, H, W = input.size()
if aug_index == 0:
output = input
else:
output = transforms.RandomPerspective(distortion_scale=self.distortion_scale, p=self.prob)(input)
return output
class GaussBlur(nn.Module):
def __init__(self, max_range = 4, kernel_size=3, sigma=(0.1, 2.0)):
super(GaussBlur, self).__init__()
self.max_range = max_range
self.prob = 0.5
self.sigma = sigma
self.kernel_size = kernel_size
def forward(self, input, aug_index=None):
_device = input.device
_, _, H, W = input.size()
if aug_index is None:
aug_index = np.random.randint(4)
            # NOTE: the random path uses a fixed 13-pixel kernel with sigma tied to
            # aug_index, ignoring self.kernel_size / self.sigma
            output = transforms.GaussianBlur(kernel_size=13, sigma=abs(aug_index)+1)(input)
_prob = input.new_full((input.size(0),), self.prob)
_mask = torch.bernoulli(_prob).view(-1, 1, 1, 1)
output = _mask * input + (1-_mask) * output
else:
if aug_index == 0:
output = input
else:
output = transforms.GaussianBlur(kernel_size=self.kernel_size, sigma=self.sigma)(input)
return output
class GaussNoise(nn.Module):
def __init__(self, mean = 0, std = 1):
super(GaussNoise, self).__init__()
self.mean = mean
self.std = std
def forward(self, input, aug_index=None):
_device = input.device
_, _, H, W = input.size()
if aug_index == 0:
output = input
else:
output = input + (torch.randn(input.size()) * self.std + self.mean).to(_device)
return output
class BlurRandpers(nn.Module):
def __init__(self, max_range=2, kernel_size=3, sigma=(10, 20), distortion_scale=0.6, p=1):
super(BlurRandpers, self).__init__()
self.max_range = max_range
self.sigma = sigma
self.kernel_size = kernel_size
self.distortion_scale = distortion_scale
self.p = p
self.gauss = GaussBlur(kernel_size=self.kernel_size, sigma=self.sigma)
self.randpers = RandPers(distortion_scale=self.distortion_scale, p=self.p)
def forward(self, input, aug_index=None):
output = self.gauss.forward(input=input, aug_index=aug_index)
output = self.randpers.forward(input=output, aug_index=aug_index)
return output
class BlurSharpness(nn.Module):
def __init__(self, max_range=2, kernel_size=3, sigma=(10, 20), sharpness_factor=0.6, p=1):
super(BlurSharpness, self).__init__()
self.max_range = max_range
self.sigma = sigma
self.kernel_size = kernel_size
self.sharpness_factor = sharpness_factor
self.p = p
self.gauss = GaussBlur(kernel_size=self.kernel_size, sigma=self.sigma)
self.sharp = RandomAdjustSharpness(sharpness_factor=self.sharpness_factor, p=self.p)
def forward(self, input, aug_index=None):
output = self.gauss.forward(input=input, aug_index=aug_index)
output = self.sharp.forward(input=output, aug_index=aug_index)
return output
class RandpersSharpness(nn.Module):
def __init__(self, max_range=2, distortion_scale=0.6, p=1, sharpness_factor=0.6):
super(RandpersSharpness, self).__init__()
self.max_range = max_range
self.distortion_scale = distortion_scale
self.p = p
self.sharpness_factor = sharpness_factor
self.randpers = RandPers(distortion_scale=self.distortion_scale, p=self.p)
self.sharp = RandomAdjustSharpness(sharpness_factor=self.sharpness_factor, p=self.p)
def forward(self, input, aug_index=None):
output = self.randpers.forward(input=input, aug_index=aug_index)
output = self.sharp.forward(input=output, aug_index=aug_index)
return output
class BlurRandpersSharpness(nn.Module):
def __init__(self, max_range=2, kernel_size=3, sigma=(10, 20), distortion_scale=0.6, p=1, sharpness_factor=0.6):
super(BlurRandpersSharpness, self).__init__()
self.max_range = max_range
self.sigma = sigma
self.kernel_size = kernel_size
self.distortion_scale = distortion_scale
self.p = p
self.sharpness_factor = sharpness_factor
self.gauss = GaussBlur(kernel_size=self.kernel_size, sigma=self.sigma)
self.randpers = RandPers(distortion_scale=self.distortion_scale, p=self.p)
self.sharp = RandomAdjustSharpness(sharpness_factor=self.sharpness_factor, p=self.p)
def forward(self, input, aug_index=None):
output = self.gauss.forward(input=input, aug_index=aug_index)
output = self.randpers.forward(input=output, aug_index=aug_index)
output = self.sharp.forward(input=output, aug_index=aug_index)
return output
class FourCrop(nn.Module):
def __init__(self, max_range = 4):
super(FourCrop, self).__init__()
self.max_range = max_range
self.prob = 0.5
def forward(self, inputs):
outputs = inputs
for i in range(8):
outputs[i] = self._crop(inputs.size(), inputs[i], i)
return outputs
def _crop(self, size, input, i):
_, _, H, W = size
h_mid = int(H / 2)
w_mid = int(W / 2)
if i == 0 or i == 4:
corner = input[:, 0:h_mid, 0:w_mid]
elif i == 1 or i == 5:
corner = input[:, 0:h_mid, w_mid:]
elif i == 2 or i == 6:
corner = input[:, h_mid:, 0:w_mid]
elif i == 3 or i == 7:
corner = input[:, h_mid:, w_mid:]
else:
corner = input
corner = transforms.Resize(size=2*h_mid)(corner)
return corner
class CutPerm(nn.Module):
def __init__(self, max_range = 4):
super(CutPerm, self).__init__()
self.max_range = max_range
self.prob = 0.5
def forward(self, input, aug_index=None):
_device = input.device
_, _, H, W = input.size()
if aug_index is None:
aug_index = np.random.randint(4)
output = self._cutperm(input, aug_index)
_prob = input.new_full((input.size(0),), self.prob)
_mask = torch.bernoulli(_prob).view(-1, 1, 1, 1)
output = _mask * input + (1 - _mask) * output
else:
aug_index = aug_index % self.max_range
output = self._cutperm(input, aug_index)
return output
def _cutperm(self, inputs, aug_index):
_, _, H, W = inputs.size()
h_mid = int(H / 2)
w_mid = int(W / 2)
jigsaw_h = aug_index // 2
jigsaw_v = aug_index % 2
if jigsaw_h == 1:
inputs = torch.cat((inputs[:, :, h_mid:, :], inputs[:, :, 0:h_mid, :]), dim=2)
if jigsaw_v == 1:
inputs = torch.cat((inputs[:, :, :, w_mid:], inputs[:, :, :, 0:w_mid]), dim=3)
return inputs
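
# --- Usage sketch (not part of the original file) ----------------------------
# CutPerm indices decompose as (aug_index // 2, aug_index % 2) = (swap
# top/bottom halves, swap left/right halves): 0 = identity, 1 = left/right,
# 2 = top/bottom, 3 = both.
def _demo_cutperm():
    x = torch.arange(16.).view(1, 1, 4, 4)
    y = CutPerm()(x, aug_index=3)  # both halves swapped
    assert y[0, 0, 0, 0] == x[0, 0, 2, 2]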
def assemble(a, b, c, d):
ab = torch.cat((a, b), dim=2)
cd = torch.cat((c, d), dim=2)
output = torch.cat((ab, cd), dim=3)
return output
def quarter(inputs):
_, _, H, W = inputs.size()
h_mid = int(H / 2)
w_mid = int(W / 2)
quarters = []
quarters.append(inputs[:, :, 0:h_mid, 0:w_mid])
quarters.append(inputs[:, :, 0:h_mid, w_mid:])
quarters.append(inputs[:, :, h_mid:, 0:w_mid])
quarters.append(inputs[:, :, h_mid:, w_mid:])
return quarters
class HorizontalFlipLayer(nn.Module):
    def __init__(self):
        """
        Randomly flips each image in the batch horizontally, implemented as an
        affine grid-sample so the operation stays batched and differentiable.
        """
super(HorizontalFlipLayer, self).__init__()
_eye = torch.eye(2, 3)
self.register_buffer('_eye', _eye)
def forward(self, inputs):
_device = inputs.device
N = inputs.size(0)
_theta = self._eye.repeat(N, 1, 1)
r_sign = torch.bernoulli(torch.ones(N, device=_device) * 0.5) * 2 - 1
_theta[:, 0, 0] = r_sign
grid = F.affine_grid(_theta, inputs.size(), **kwargs).to(_device)
inputs = F.grid_sample(inputs, grid, padding_mode='reflection', **kwargs)
return inputs
class RandomColorGrayLayer(nn.Module):
def __init__(self, p):
super(RandomColorGrayLayer, self).__init__()
self.prob = p
_weight = torch.tensor([[0.299, 0.587, 0.114]])
self.register_buffer('_weight', _weight.view(1, 3, 1, 1))
def forward(self, inputs, aug_index=None):
if aug_index == 0:
return inputs
l = F.conv2d(inputs, self._weight)
gray = torch.cat([l, l, l], dim=1)
if aug_index is None:
_prob = inputs.new_full((inputs.size(0),), self.prob)
_mask = torch.bernoulli(_prob).view(-1, 1, 1, 1)
gray = inputs * (1 - _mask) + gray * _mask
return gray
class ColorJitterLayer(nn.Module):
def __init__(self, p, brightness, contrast, saturation, hue):
super(ColorJitterLayer, self).__init__()
self.prob = p
self.brightness = self._check_input(brightness, 'brightness')
self.contrast = self._check_input(contrast, 'contrast')
self.saturation = self._check_input(saturation, 'saturation')
self.hue = self._check_input(hue, 'hue', center=0, bound=(-0.5, 0.5),
clip_first_on_zero=False)
def _check_input(self, value, name, center=1, bound=(0, float('inf')), clip_first_on_zero=True):
if isinstance(value, numbers.Number):
if value < 0:
raise ValueError("If {} is a single number, it must be non negative.".format(name))
value = [center - value, center + value]
if clip_first_on_zero:
value[0] = max(value[0], 0)
elif isinstance(value, (tuple, list)) and len(value) == 2:
if not bound[0] <= value[0] <= value[1] <= bound[1]:
raise ValueError("{} values should be between {}".format(name, bound))
else:
            raise TypeError("{} should be a single number or a list/tuple with length 2.".format(name))
# if value is 0 or (1., 1.) for brightness/contrast/saturation
# or (0., 0.) for hue, do nothing
if value[0] == value[1] == center:
value = None
return value
def adjust_contrast(self, x):
if self.contrast:
factor = x.new_empty(x.size(0), 1, 1, 1).uniform_(*self.contrast)
means = torch.mean(x, dim=[2, 3], keepdim=True)
x = (x - means) * factor + means
return torch.clamp(x, 0, 1)
def adjust_hsv(self, x):
f_h = x.new_zeros(x.size(0), 1, 1)
f_s = x.new_ones(x.size(0), 1, 1)
f_v = x.new_ones(x.size(0), 1, 1)
if self.hue:
f_h.uniform_(*self.hue)
if self.saturation:
f_s = f_s.uniform_(*self.saturation)
if self.brightness:
f_v = f_v.uniform_(*self.brightness)
return RandomHSVFunction.apply(x, f_h, f_s, f_v)
def transform(self, inputs):
        # Apply contrast and HSV adjustments in a random order
        if np.random.rand() > 0.5:
transforms = [self.adjust_contrast, self.adjust_hsv]
else:
transforms = [self.adjust_hsv, self.adjust_contrast]
for t in transforms:
inputs = t(inputs)
return inputs
def forward(self, inputs):
_prob = inputs.new_full((inputs.size(0),), self.prob)
_mask = torch.bernoulli(_prob).view(-1, 1, 1, 1)
return inputs * (1 - _mask) + self.transform(inputs) * _mask
class RandomHSVFunction(Function):
@staticmethod
def forward(ctx, x, f_h, f_s, f_v):
# ctx is a context object that can be used to stash information
# for backward computation
x = rgb2hsv(x)
h = x[:, 0, :, :]
h += (f_h * 255. / 360.)
h = (h % 1)
x[:, 0, :, :] = h
x[:, 1, :, :] = x[:, 1, :, :] * f_s
x[:, 2, :, :] = x[:, 2, :, :] * f_v
x = torch.clamp(x, 0, 1)
x = hsv2rgb(x)
return x
@staticmethod
def backward(ctx, grad_output):
# We return as many input gradients as there were arguments.
# Gradients of non-Tensor arguments to forward must be None.
grad_input = None
if ctx.needs_input_grad[0]:
grad_input = grad_output.clone()
return grad_input, None, None, None
class NormalizeLayer(nn.Module):
    """
    Rescales inputs from [0, 1] to [-1, 1]. Normalization is kept as the first
    layer of the classifier (rather than as dataset preprocessing), which keeps
    the augmentation layers above operating on raw [0, 1] pixels.
    """
def __init__(self):
super(NormalizeLayer, self).__init__()
def forward(self, inputs):
return (inputs - 0.5) / 0.5

1799  train.ipynb  Normal file
File diff suppressed because it is too large

57  train.py  Normal file
@ -0,0 +1,57 @@
from utils.utils import Logger
from utils.utils import save_checkpoint
from utils.utils import save_linear_checkpoint
from common.train import *
from evals import test_classifier
if 'sup' in P.mode:
from training.sup import setup
else:
from training.unsup import setup
train, fname = setup(P.mode, P)
logger = Logger(fname, ask=not resume, local_rank=P.local_rank)
logger.log(P)
logger.log(model)
if P.multi_gpu:
linear = model.module.linear
else:
linear = model.linear
linear_optim = torch.optim.Adam(linear.parameters(), lr=1e-3, betas=(.9, .999), weight_decay=P.weight_decay)
# Run experiments
for epoch in range(start_epoch, P.epochs + 1):
logger.log_dirname(f"Epoch {epoch}")
model.train()
if P.multi_gpu:
train_sampler.set_epoch(epoch)
kwargs = {}
kwargs['linear'] = linear
kwargs['linear_optim'] = linear_optim
kwargs['simclr_aug'] = simclr_aug
train(P, epoch, model, criterion, optimizer, scheduler_warmup, train_loader, logger=logger, **kwargs)
model.eval()
if epoch % P.save_step == 0 and P.local_rank == 0:
if P.multi_gpu:
save_states = model.module.state_dict()
else:
save_states = model.state_dict()
save_checkpoint(epoch, save_states, optimizer.state_dict(), logger.logdir)
save_linear_checkpoint(linear_optim.state_dict(), logger.logdir)
if epoch % P.error_step == 0 and ('sup' in P.mode):
error = test_classifier(P, model, test_loader, epoch, logger=logger)
is_best = (best > error)
if is_best:
best = error
logger.scalar_summary('eval/best_error', best, epoch)
logger.log('[Epoch %3d] [Test %5.2f] [Best %5.2f]' % (epoch, error, best))

97  training/__init__.py  Normal file
@ -0,0 +1,97 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
def update_learning_rate(P, optimizer, cur_epoch, n, n_total):
cur_epoch = cur_epoch - 1
lr = P.lr_init
    if P.optimizer == 'sgd' or P.optimizer == 'lars':
DECAY_RATIO = 0.1
elif P.optimizer == 'adam':
DECAY_RATIO = 0.3
else:
raise NotImplementedError()
if P.warmup > 0:
cur_iter = cur_epoch * n_total + n
if cur_iter <= P.warmup:
lr *= cur_iter / float(P.warmup)
if cur_epoch >= 0.5 * P.epochs:
lr *= DECAY_RATIO
if cur_epoch >= 0.75 * P.epochs:
lr *= DECAY_RATIO
for param_group in optimizer.param_groups:
param_group['lr'] = lr
return lr
def _cross_entropy(input, targets, reduction='mean'):
targets_prob = F.softmax(targets, dim=1)
xent = (-targets_prob * F.log_softmax(input, dim=1)).sum(1)
if reduction == 'sum':
return xent.sum()
elif reduction == 'mean':
return xent.mean()
elif reduction == 'none':
return xent
else:
raise NotImplementedError()
def _entropy(input, reduction='mean'):
return _cross_entropy(input, input, reduction)
def cross_entropy_soft(input, targets, reduction='mean'):
targets_prob = F.softmax(targets, dim=1)
xent = (-targets_prob * F.log_softmax(input, dim=1)).sum(1)
if reduction == 'sum':
return xent.sum()
elif reduction == 'mean':
return xent.mean()
elif reduction == 'none':
return xent
else:
raise NotImplementedError()
def kl_div(input, targets, reduction='batchmean'):
return F.kl_div(F.log_softmax(input, dim=1), F.softmax(targets, dim=1),
reduction=reduction)
def target_nll_loss(inputs, targets, reduction='none'):
inputs_t = -F.nll_loss(inputs, targets, reduction='none')
logit_diff = inputs - inputs_t.view(-1, 1)
logit_diff = logit_diff.scatter(1, targets.view(-1, 1), -1e8)
diff_max = logit_diff.max(1)[0]
if reduction == 'sum':
return diff_max.sum()
elif reduction == 'mean':
return diff_max.mean()
elif reduction == 'none':
return diff_max
else:
raise NotImplementedError()
def target_nll_c(inputs, targets, reduction='none'):
conf = torch.softmax(inputs, dim=1)
conf_t = -F.nll_loss(conf, targets, reduction='none')
conf_diff = conf - conf_t.view(-1, 1)
conf_diff = conf_diff.scatter(1, targets.view(-1, 1), -1)
diff_max = conf_diff.max(1)[0]
if reduction == 'sum':
return diff_max.sum()
elif reduction == 'mean':
return diff_max.mean()
elif reduction == 'none':
return diff_max
else:
raise NotImplementedError()
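
# --- Usage sketch (not part of the original file) ----------------------------
# cross_entropy_soft reduces to standard cross-entropy when the target logits
# are strongly peaked on the true class; the toy tensors are assumptions.
def _demo_cross_entropy_soft():
    logits = torch.randn(8, 10)
    labels = torch.arange(8) % 10
    targets = torch.randn(8, 10) * 0.1
    targets[torch.arange(8), labels] += 100.  # ~one-hot after softmax
    hard = F.cross_entropy(logits, labels)
    soft = cross_entropy_soft(logits, targets)
    assert torch.allclose(hard, soft, atol=1e-4)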

BIN  (six binary files, not shown)

79  training/contrastive_loss.py  Normal file
@ -0,0 +1,79 @@
import torch
import torch.distributed as dist
import diffdist.functional as distops
def get_similarity_matrix(outputs, chunk=2, multi_gpu=False):
'''
Compute similarity matrix
- outputs: (B', d) tensor for B' = B * chunk
- sim_matrix: (B', B') tensor
'''
if multi_gpu:
outputs_gathered = []
for out in outputs.chunk(chunk):
gather_t = [torch.empty_like(out) for _ in range(dist.get_world_size())]
gather_t = torch.cat(distops.all_gather(gather_t, out))
outputs_gathered.append(gather_t)
outputs = torch.cat(outputs_gathered)
sim_matrix = torch.mm(outputs, outputs.t()) # (B', d), (d, B') -> (B', B')
return sim_matrix
def NT_xent(sim_matrix, temperature=0.5, chunk=2, eps=1e-8):
'''
Compute NT_xent loss
- sim_matrix: (B', B') tensor for B' = B * chunk (first 2B are pos samples)
'''
device = sim_matrix.device
B = sim_matrix.size(0) // chunk # B = B' / chunk
eye = torch.eye(B * chunk).to(device) # (B', B')
sim_matrix = torch.exp(sim_matrix / temperature) * (1 - eye) # remove diagonal
denom = torch.sum(sim_matrix, dim=1, keepdim=True)
sim_matrix = -torch.log(sim_matrix / (denom + eps) + eps) # loss matrix
loss = torch.sum(sim_matrix[:B, B:].diag() + sim_matrix[B:, :B].diag()) / (2 * B)
return loss
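
# --- Usage sketch (not part of the original file) ----------------------------
# NT_xent expects the two augmented views stacked as [view1 | view2], so that
# rows i and i + B form a positive pair; batch size and dim are assumptions.
def _demo_nt_xent(B=4, d=128):
    import torch.nn.functional as F
    z = F.normalize(torch.randn(2 * B, d), dim=1)  # stacked, normalized embeddings
    sim_matrix = torch.mm(z, z.t())                # as in get_similarity_matrix
    return NT_xent(sim_matrix, temperature=0.5)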
def Supervised_NT_xent(sim_matrix, labels, temperature=0.5, chunk=2, eps=1e-8, multi_gpu=False):
'''
Compute NT_xent loss
- sim_matrix: (B', B') tensor for B' = B * chunk (first 2B are pos samples)
'''
device = sim_matrix.device
if multi_gpu:
gather_t = [torch.empty_like(labels) for _ in range(dist.get_world_size())]
labels = torch.cat(distops.all_gather(gather_t, labels))
labels = labels.repeat(2)
logits_max, _ = torch.max(sim_matrix, dim=1, keepdim=True)
sim_matrix = sim_matrix - logits_max.detach()
B = sim_matrix.size(0) // chunk # B = B' / chunk
eye = torch.eye(B * chunk).to(device) # (B', B')
sim_matrix = torch.exp(sim_matrix / temperature) * (1 - eye) # remove diagonal
denom = torch.sum(sim_matrix, dim=1, keepdim=True)
sim_matrix = -torch.log(sim_matrix / (denom + eps) + eps) # loss matrix
labels = labels.contiguous().view(-1, 1)
Mask = torch.eq(labels, labels.t()).float().to(device)
#Mask = eye * torch.stack([labels == labels[i] for i in range(labels.size(0))]).float().to(device)
Mask = Mask / (Mask.sum(dim=1, keepdim=True) + eps)
loss = torch.sum(Mask * sim_matrix) / (2 * B)
return loss

63  training/scheduler.py  Normal file
@ -0,0 +1,63 @@
from torch.optim.lr_scheduler import _LRScheduler
from torch.optim.lr_scheduler import ReduceLROnPlateau
class GradualWarmupScheduler(_LRScheduler):
""" Gradually warm-up(increasing) learning rate in optimizer.
Proposed in 'Accurate, Large Minibatch SGD: Training ImageNet in 1 Hour'.
Args:
optimizer (Optimizer): Wrapped optimizer.
multiplier: target learning rate = base lr * multiplier if multiplier > 1.0. if multiplier = 1.0, lr starts from 0 and ends up with the base_lr.
total_epoch: target learning rate is reached at total_epoch, gradually
after_scheduler: after target_epoch, use this scheduler(eg. ReduceLROnPlateau)
"""
def __init__(self, optimizer, multiplier, total_epoch, after_scheduler=None):
self.multiplier = multiplier
if self.multiplier < 1.:
            raise ValueError('multiplier should be greater than or equal to 1.')
self.total_epoch = total_epoch
self.after_scheduler = after_scheduler
self.finished = False
super(GradualWarmupScheduler, self).__init__(optimizer)
def get_lr(self):
if self.last_epoch > self.total_epoch:
if self.after_scheduler:
if not self.finished:
self.after_scheduler.base_lrs = [base_lr * self.multiplier for base_lr in self.base_lrs]
self.finished = True
return self.after_scheduler.get_lr()
return [base_lr * self.multiplier for base_lr in self.base_lrs]
if self.multiplier == 1.0:
return [base_lr * (float(self.last_epoch) / self.total_epoch) for base_lr in self.base_lrs]
else:
return [base_lr * ((self.multiplier - 1.) * self.last_epoch / self.total_epoch + 1.) for base_lr in self.base_lrs]
def step_ReduceLROnPlateau(self, metrics, epoch=None):
if epoch is None:
epoch = self.last_epoch + 1
self.last_epoch = epoch if epoch != 0 else 1 # ReduceLROnPlateau is called at the end of epoch, whereas others are called at beginning
if self.last_epoch <= self.total_epoch:
warmup_lr = [base_lr * ((self.multiplier - 1.) * self.last_epoch / self.total_epoch + 1.) for base_lr in self.base_lrs]
for param_group, lr in zip(self.optimizer.param_groups, warmup_lr):
param_group['lr'] = lr
else:
if epoch is None:
self.after_scheduler.step(metrics, None)
else:
self.after_scheduler.step(metrics, epoch - self.total_epoch)
def step(self, epoch=None, metrics=None):
if type(self.after_scheduler) != ReduceLROnPlateau:
if self.finished and self.after_scheduler:
if epoch is None:
self.after_scheduler.step(None)
else:
self.after_scheduler.step(epoch - self.total_epoch)
else:
return super(GradualWarmupScheduler, self).step(epoch)
else:
self.step_ReduceLROnPlateau(metrics, epoch)
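
# --- Usage sketch (not part of the original file) ----------------------------
# Typical composition: linear warmup for the first 10 epochs, then cosine
# decay; the optimizer, multiplier and epoch counts are assumptions.
def _demo_warmup():
    import torch
    from torch.optim.lr_scheduler import CosineAnnealingLR
    opt = torch.optim.SGD([torch.nn.Parameter(torch.zeros(1))], lr=0.1)
    scheduler = GradualWarmupScheduler(opt, multiplier=10.0, total_epoch=10,
                                       after_scheduler=CosineAnnealingLR(opt, T_max=90))
    for epoch in range(1, 101):
        scheduler.step(epoch)  # lr ramps 0.1 -> 1.0 over 10 epochs, then cosine-decays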

33  training/sup/__init__.py  Normal file
@ -0,0 +1,33 @@
def setup(mode, P):
fname = f'{P.dataset}_{P.model}_{mode}_{P.res}'
if mode == 'sup_linear':
from .sup_linear import train
elif mode == 'sup_CSI_linear':
from .sup_CSI_linear import train
elif mode == 'sup_simclr':
from .sup_simclr import train
elif mode == 'sup_simclr_CSI':
assert P.batch_size == 32
# currently only support rotation
from .sup_simclr_CSI import train
else:
raise NotImplementedError()
if P.suffix is not None:
fname += f'_{P.suffix}'
return train, fname
def update_comp_loss(loss_dict, loss_in, loss_out, loss_diff, batch_size):
loss_dict['pos'].update(loss_in, batch_size)
loss_dict['neg'].update(loss_out, batch_size)
loss_dict['diff'].update(loss_diff, batch_size)
def summary_comp_loss(logger, tag, loss_dict, epoch):
logger.scalar_summary(f'{tag}/pos', loss_dict['pos'].average, epoch)
logger.scalar_summary(f'{tag}/neg', loss_dict['neg'].average, epoch)
logger.scalar_summary(f'{tag}', loss_dict['diff'].average, epoch)

BIN  (three binary files, not shown)

130  training/sup/sup_CSI_linear.py  Normal file
@ -0,0 +1,130 @@
import time
import torch.optim
import torch.optim.lr_scheduler as lr_scheduler
import models.transform_layers as TL
from utils.utils import AverageMeter, normalize
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
hflip = TL.HorizontalFlipLayer().to(device)
def train(P, epoch, model, criterion, optimizer, scheduler, loader, logger=None,
simclr_aug=None, linear=None, linear_optim=None):
if P.multi_gpu:
rotation_linear = model.module.shift_cls_layer
joint_linear = model.module.joint_distribution_layer
else:
rotation_linear = model.shift_cls_layer
joint_linear = model.joint_distribution_layer
if epoch == 1:
# define optimizer and save in P (argument)
milestones = [int(0.6 * P.epochs), int(0.75 * P.epochs), int(0.9 * P.epochs)]
linear_optim = torch.optim.SGD(linear.parameters(),
lr=1e-1, weight_decay=P.weight_decay)
P.linear_optim = linear_optim
P.linear_scheduler = lr_scheduler.MultiStepLR(P.linear_optim, gamma=0.1, milestones=milestones)
rotation_linear_optim = torch.optim.SGD(rotation_linear.parameters(),
lr=1e-1, weight_decay=P.weight_decay)
P.rotation_linear_optim = rotation_linear_optim
P.rot_scheduler = lr_scheduler.MultiStepLR(P.rotation_linear_optim, gamma=0.1, milestones=milestones)
joint_linear_optim = torch.optim.SGD(joint_linear.parameters(),
lr=1e-1, weight_decay=P.weight_decay)
P.joint_linear_optim = joint_linear_optim
P.joint_scheduler = lr_scheduler.MultiStepLR(P.joint_linear_optim, gamma=0.1, milestones=milestones)
if logger is None:
log_ = print
else:
log_ = logger.log
batch_time = AverageMeter()
data_time = AverageMeter()
losses = dict()
losses['cls'] = AverageMeter()
losses['rot'] = AverageMeter()
check = time.time()
for n, (images, labels) in enumerate(loader):
        model.eval()  # backbone stays frozen; only the attached linear heads are trained
count = n * P.n_gpus # number of trained samples
data_time.update(time.time() - check)
check = time.time()
        ### Prepare inputs (rotations + SimCLR augmentation) ###
if P.dataset != 'imagenet':
batch_size = images.size(0)
images = images.to(device)
images = hflip(images) # 2B with hflip
else:
batch_size = images[0].size(0)
images = images[0].to(device)
labels = labels.to(device)
images = torch.cat([torch.rot90(images, rot, (2, 3)) for rot in range(4)]) # 4B
rot_labels = torch.cat([torch.ones_like(labels) * k for k in range(4)], 0) # B -> 4B
joint_labels = torch.cat([labels + P.n_classes * i for i in range(4)], dim=0)
images = simclr_aug(images) # simclr augmentation
_, outputs_aux = model(images, penultimate=True)
penultimate = outputs_aux['penultimate'].detach()
outputs = linear(penultimate[0:batch_size]) # only use 0 degree samples for linear eval
outputs_rot = rotation_linear(penultimate)
outputs_joint = joint_linear(penultimate)
loss_ce = criterion(outputs, labels)
loss_rot = criterion(outputs_rot, rot_labels)
loss_joint = criterion(outputs_joint, joint_labels)
### CE loss ###
P.linear_optim.zero_grad()
loss_ce.backward()
P.linear_optim.step()
### Rot loss ###
P.rotation_linear_optim.zero_grad()
loss_rot.backward()
P.rotation_linear_optim.step()
### Joint loss ###
P.joint_linear_optim.zero_grad()
loss_joint.backward()
P.joint_linear_optim.step()
### optimizer learning rate ###
lr = P.linear_optim.param_groups[0]['lr']
batch_time.update(time.time() - check)
### Log losses ###
losses['cls'].update(loss_ce.item(), batch_size)
losses['rot'].update(loss_rot.item(), batch_size)
if count % 50 == 0:
log_('[Epoch %3d; %3d] [Time %.3f] [Data %.3f] [LR %.5f]\n'
'[LossC %f] [LossR %f]' %
(epoch, count, batch_time.value, data_time.value, lr,
losses['cls'].value, losses['rot'].value))
check = time.time()
P.linear_scheduler.step()
P.rot_scheduler.step()
P.joint_scheduler.step()
log_('[DONE] [Time %.3f] [Data %.3f] [LossC %f] [LossR %f]' %
(batch_time.average, data_time.average,
losses['cls'].average, losses['rot'].average))
if logger is not None:
logger.scalar_summary('train/loss_cls', losses['cls'].average, epoch)
logger.scalar_summary('train/loss_rot', losses['rot'].average, epoch)
logger.scalar_summary('train/batch_time', batch_time.average, epoch)

91  training/sup/sup_linear.py  Normal file
@ -0,0 +1,91 @@
import time
import torch.optim
import torch.optim.lr_scheduler as lr_scheduler
import models.transform_layers as TL
from utils.utils import AverageMeter, normalize
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
hflip = TL.HorizontalFlipLayer().to(device)
def train(P, epoch, model, criterion, optimizer, scheduler, loader, logger=None,
simclr_aug=None, linear=None, linear_optim=None):
if epoch == 1:
# define optimizer and save in P (argument)
milestones = [int(0.6 * P.epochs), int(0.75 * P.epochs), int(0.9 * P.epochs)]
linear_optim = torch.optim.SGD(linear.parameters(),
lr=1e-1, weight_decay=P.weight_decay)
P.linear_optim = linear_optim
P.linear_scheduler = lr_scheduler.MultiStepLR(P.linear_optim, gamma=0.1, milestones=milestones)
if logger is None:
log_ = print
else:
log_ = logger.log
batch_time = AverageMeter()
data_time = AverageMeter()
losses = dict()
losses['cls'] = AverageMeter()
check = time.time()
for n, (images, labels) in enumerate(loader):
        model.eval()  # backbone stays frozen; only the attached linear head is trained
count = n * P.n_gpus # number of trained samples
data_time.update(time.time() - check)
check = time.time()
        ### Prepare inputs (SimCLR augmentation) ###
if P.dataset != 'imagenet':
batch_size = images.size(0)
images = images.to(device)
images = hflip(images) # 2B with hflip
else:
batch_size = images[0].size(0)
images = images[0].to(device)
labels = labels.to(device)
images = simclr_aug(images) # simclr augmentation
_, outputs_aux = model(images, penultimate=True)
penultimate = outputs_aux['penultimate'].detach()
outputs = linear(penultimate[0:batch_size]) # no rotated copies here, so this slice keeps the full batch
loss_ce = criterion(outputs, labels)
### CE loss ###
P.linear_optim.zero_grad()
loss_ce.backward()
P.linear_optim.step()
### optimizer learning rate ###
lr = P.linear_optim.param_groups[0]['lr']
batch_time.update(time.time() - check)
### Log losses ###
losses['cls'].update(loss_ce.item(), batch_size)
if count % 50 == 0:
log_('[Epoch %3d; %3d] [Time %.3f] [Data %.3f] [LR %.5f]\n'
'[LossC %f]' %
(epoch, count, batch_time.value, data_time.value, lr,
losses['cls'].value))
check = time.time()
P.linear_scheduler.step()
log_('[DONE] [Time %.3f] [Data %.3f] [LossC %f]' %
(batch_time.average, data_time.average,
losses['cls'].average))
if logger is not None:
logger.scalar_summary('train/loss_cls', losses['cls'].average, epoch)
logger.scalar_summary('train/batch_time', batch_time.average, epoch)

104
training/sup/sup_simclr.py Normal file
View File

@ -0,0 +1,104 @@
import time
import torch.optim
import models.transform_layers as TL
from training.contrastive_loss import get_similarity_matrix, Supervised_NT_xent
from utils.utils import AverageMeter, normalize
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
hflip = TL.HorizontalFlipLayer().to(device)
def train(P, epoch, model, criterion, optimizer, scheduler, loader, logger=None,
simclr_aug=None, linear=None, linear_optim=None):
assert simclr_aug is not None
assert P.sim_lambda == 1.0
if logger is None:
log_ = print
else:
log_ = logger.log
batch_time = AverageMeter()
data_time = AverageMeter()
losses = dict()
losses['cls'] = AverageMeter()
losses['sim'] = AverageMeter()
losses['simnorm'] = AverageMeter()
check = time.time()
for n, (images, labels) in enumerate(loader):
model.train()
count = n * P.n_gpus # number of trained samples
data_time.update(time.time() - check)
check = time.time()
### SimCLR loss ###
if P.dataset != 'imagenet' and P.dataset != 'CNMC' and P.dataset != 'CNMC_grayscale':
batch_size = images.size(0)
images = images.to(device)
images_pair = hflip(images.repeat(2, 1, 1, 1)) # 2B with hflip
else:
batch_size = images[0].size(0)
images1, images2 = images[0].to(device), images[1].to(device)
images_pair = torch.cat([images1, images2], dim=0) # 2B
labels = labels.to(device)
images_pair = simclr_aug(images_pair) # simclr augmentation
_, outputs_aux = model(images_pair, simclr=True, penultimate=True)
simclr = normalize(outputs_aux['simclr']) # normalize
sim_matrix = get_similarity_matrix(simclr, multi_gpu=P.multi_gpu)
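# Supervised NT-xent treats all samples sharing a label as positives (SupCon-style objective, temperature 0.07)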
loss_sim = Supervised_NT_xent(sim_matrix, labels=labels, temperature=0.07, multi_gpu=P.multi_gpu) * P.sim_lambda
### total loss ###
loss = loss_sim
optimizer.zero_grad()
loss.backward()
optimizer.step()
scheduler.step(epoch - 1 + n / len(loader))
lr = optimizer.param_groups[0]['lr']
batch_time.update(time.time() - check)
### Post-processing ###
simclr_norm = outputs_aux['simclr'].norm(dim=1).mean()
### Linear evaluation ###
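# labels.repeat(2): each image contributes two augmented views, stacked as [view1; view2]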
outputs_linear_eval = linear(outputs_aux['penultimate'].detach())
loss_linear = criterion(outputs_linear_eval, labels.repeat(2))
linear_optim.zero_grad()
loss_linear.backward()
linear_optim.step()
### Log losses ###
losses['cls'].update(0, batch_size)
losses['sim'].update(loss_sim.item(), batch_size)
losses['simnorm'].update(simclr_norm.item(), batch_size)
if count % 50 == 0:
log_('[Epoch %3d; %3d] [Time %.3f] [Data %.3f] [LR %.5f]\n'
'[LossC %f] [LossSim %f] [SimNorm %f]' %
(epoch, count, batch_time.value, data_time.value, lr,
losses['cls'].value, losses['sim'].value, losses['simnorm'].value))
check = time.time()
log_('[DONE] [Time %.3f] [Data %.3f] [LossC %f] [LossSim %f] [SimNorm %f]' %
(batch_time.average, data_time.average,
losses['cls'].average, losses['sim'].average, losses['simnorm'].average))
if logger is not None:
logger.scalar_summary('train/loss_cls', losses['cls'].average, epoch)
logger.scalar_summary('train/loss_sim', losses['sim'].average, epoch)
logger.scalar_summary('train/batch_time', batch_time.average, epoch)
logger.scalar_summary('train/simclr_norm', losses['simnorm'].average, epoch)

View File

@ -0,0 +1,111 @@
import time
import torch.optim
import models.transform_layers as TL
from training.contrastive_loss import get_similarity_matrix, Supervised_NT_xent
from utils.utils import AverageMeter, normalize
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
hflip = TL.HorizontalFlipLayer().to(device)
def train(P, epoch, model, criterion, optimizer, scheduler, loader, logger=None,
simclr_aug=None, linear=None, linear_optim=None):
# currently only support rotation shifting augmentation
assert simclr_aug is not None
assert P.sim_lambda == 1.0
if logger is None:
log_ = print
else:
log_ = logger.log
batch_time = AverageMeter()
data_time = AverageMeter()
losses = dict()
losses['cls'] = AverageMeter()
losses['sim'] = AverageMeter()
check = time.time()
for n, (images, labels) in enumerate(loader):
model.train()
count = n * P.n_gpus # number of trained samples
data_time.update(time.time() - check)
check = time.time()
### SimCLR loss ###
if P.dataset != 'imagenet' and P.dataset != 'CNMC' and P.dataset != 'CNMC_grayscale':
batch_size = images.size(0)
images = images.to(device)
images1, images2 = hflip(images.repeat(2, 1, 1, 1)).chunk(2) # hflip
else:
batch_size = images[0].size(0)
images1, images2 = images[0].to(device), images[1].to(device)
images1 = torch.cat([torch.rot90(images1, rot, (2, 3)) for rot in range(4)]) # 4B
images2 = torch.cat([torch.rot90(images2, rot, (2, 3)) for rot in range(4)]) # 4B
images_pair = torch.cat([images1, images2], dim=0) # 8B
labels = labels.to(device)
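# every (class, rotation) pair becomes its own positive group for the supervised contrastive loss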
rot_sim_labels = torch.cat([labels + P.n_classes * i for i in range(4)], dim=0)
rot_sim_labels = rot_sim_labels.to(device)
images_pair = simclr_aug(images_pair) # simclr augment
_, outputs_aux = model(images_pair, simclr=True, penultimate=True)
simclr = normalize(outputs_aux['simclr']) # normalize
sim_matrix = get_similarity_matrix(simclr, multi_gpu=P.multi_gpu)
loss_sim = Supervised_NT_xent(sim_matrix, labels=rot_sim_labels,
temperature=0.07, multi_gpu=P.multi_gpu) * P.sim_lambda
### total loss ###
loss = loss_sim
optimizer.zero_grad()
loss.backward()
optimizer.step()
scheduler.step(epoch - 1 + n / len(loader))
lr = optimizer.param_groups[0]['lr']
batch_time.update(time.time() - check)
### Post-processing ###
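# rows [0, B) are view 1's un-rotated samples; view 2's un-rotated samples start at 4B after the 4-way rotation expansion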
penul_1 = outputs_aux['penultimate'][:batch_size]
penul_2 = outputs_aux['penultimate'][4 * batch_size: 5 * batch_size]
outputs_aux['penultimate'] = torch.cat([penul_1, penul_2]) # only use original rotation
### Linear evaluation ###
outputs_linear_eval = linear(outputs_aux['penultimate'].detach())
loss_linear = criterion(outputs_linear_eval, labels.repeat(2))
linear_optim.zero_grad()
loss_linear.backward()
linear_optim.step()
### Log losses ###
losses['cls'].update(0, batch_size)
losses['sim'].update(loss_sim.item(), batch_size)
if count % 50 == 0:
log_('[Epoch %3d; %3d] [Time %.3f] [Data %.3f] [LR %.5f]\n'
'[LossC %f] [LossSim %f]' %
(epoch, count, batch_time.value, data_time.value, lr,
losses['cls'].value, losses['sim'].value))
check = time.time()
log_('[DONE] [Time %.3f] [Data %.3f] [LossC %f] [LossSim %f]' %
(batch_time.average, data_time.average,
losses['cls'].average, losses['sim'].average))
if logger is not None:
logger.scalar_summary('train/loss_cls', losses['cls'].average, epoch)
logger.scalar_summary('train/loss_sim', losses['sim'].average, epoch)
logger.scalar_summary('train/batch_time', batch_time.average, epoch)

View File

@ -0,0 +1,39 @@
def setup(mode, P):
fname = f'{P.dataset}_{P.model}_unsup_{mode}_{P.res}'
if mode == 'simclr':
from .simclr import train
elif mode == 'simclr_CSI':
from .simclr_CSI import train
fname += f'_shift_{P.shift_trans_type}_resize_factor{P.resize_factor}_color_dist{P.color_distort}'
if P.shift_trans_type == 'gauss':
fname += f'_gauss_sigma{P.gauss_sigma}'
elif P.shift_trans_type == 'randpers':
fname += f'_distortion_scale{P.distortion_scale}'
elif P.shift_trans_type == 'sharp':
fname += f'_sharpness_factor{P.sharpness_factor}'
elif P.shift_trans_type == 'noise': # assumed fix: the original repeated 'sharp' here, leaving this noise branch unreachable
fname += f'_nmean_{P.noise_mean}_nstd_{P.noise_std}'
else:
raise NotImplementedError()
if P.one_class_idx is not None:
fname += f'_one_class_{P.one_class_idx}'
if P.suffix is not None:
fname += f'_{P.suffix}'
return train, fname
def update_comp_loss(loss_dict, loss_in, loss_out, loss_diff, batch_size):
loss_dict['pos'].update(loss_in, batch_size)
loss_dict['neg'].update(loss_out, batch_size)
loss_dict['diff'].update(loss_diff, batch_size)
def summary_comp_loss(logger, tag, loss_dict, epoch):
logger.scalar_summary(f'{tag}/pos', loss_dict['pos'].average, epoch)
logger.scalar_summary(f'{tag}/neg', loss_dict['neg'].average, epoch)
logger.scalar_summary(f'{tag}', loss_dict['diff'].average, epoch)
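A minimal usage sketch of `setup()` (assumptions: this file is `training/unsup/__init__.py`, and the driver below — the argument object `P`, the model/optimizer/loader objects, and the epoch loop — is hypothetical, not part of the repository):

```python
# hypothetical driver: resolve the unsupervised trainer and its log name, then run it
from training.unsup import setup

train, fname = setup(P.mode, P)  # e.g. P.mode == 'simclr_CSI'
print(f'logging under {fname}')
for epoch in range(1, P.epochs + 1):
    train(P, epoch, model, criterion, optimizer, scheduler, train_loader,
          logger=logger, simclr_aug=simclr_aug, linear=linear, linear_optim=linear_optim)
```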

Binary files not shown.
101
training/unsup/simclr.py Normal file
View File

@ -0,0 +1,101 @@
import time
import torch.optim
import models.transform_layers as TL
from training.contrastive_loss import get_similarity_matrix, NT_xent
from utils.utils import AverageMeter, normalize
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
hflip = TL.HorizontalFlipLayer().to(device)
def train(P, epoch, model, criterion, optimizer, scheduler, loader, logger=None,
simclr_aug=None, linear=None, linear_optim=None):
assert simclr_aug is not None
assert P.sim_lambda == 1.0
if logger is None:
log_ = print
else:
log_ = logger.log
batch_time = AverageMeter()
data_time = AverageMeter()
losses = dict()
losses['cls'] = AverageMeter()
losses['sim'] = AverageMeter()
check = time.time()
for n, (images, labels) in enumerate(loader):
model.train()
count = n * P.n_gpus # number of trained samples
data_time.update(time.time() - check)
check = time.time()
### SimCLR loss ###
if P.dataset != 'imagenet':
batch_size = images.size(0)
images = images.to(device)
images_pair = hflip(images.repeat(2, 1, 1, 1)) # 2B with hflip
else:
batch_size = images[0].size(0)
images1, images2 = images[0].to(device), images[1].to(device)
images_pair = torch.cat([images1, images2], dim=0) # 2B
labels = labels.to(device)
images_pair = simclr_aug(images_pair) # transform
_, outputs_aux = model(images_pair, simclr=True, penultimate=True)
simclr = normalize(outputs_aux['simclr']) # normalize
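# sim_matrix holds the pairwise cosine similarities of all 2B normalized embeddings; NT-xent pulls the two views of each image together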
sim_matrix = get_similarity_matrix(simclr, multi_gpu=P.multi_gpu)
loss_sim = NT_xent(sim_matrix, temperature=0.5) * P.sim_lambda
### total loss ###
loss = loss_sim
optimizer.zero_grad()
loss.backward()
optimizer.step()
scheduler.step(epoch - 1 + n / len(loader))
lr = optimizer.param_groups[0]['lr']
batch_time.update(time.time() - check)
### Post-processing ###
simclr_norm = outputs_aux['simclr'].norm(dim=1).mean()
### Linear evaluation ###
outputs_linear_eval = linear(outputs_aux['penultimate'].detach())
loss_linear = criterion(outputs_linear_eval, labels.repeat(2))
linear_optim.zero_grad()
loss_linear.backward()
linear_optim.step()
### Log losses ###
losses['cls'].update(0, batch_size)
losses['sim'].update(loss_sim.item(), batch_size)
if count % 50 == 0:
log_('[Epoch %3d; %3d] [Time %.3f] [Data %.3f] [LR %.5f]\n'
'[LossC %f] [LossSim %f]' %
(epoch, count, batch_time.value, data_time.value, lr,
losses['cls'].value, losses['sim'].value))
check = time.time()
log_('[DONE] [Time %.3f] [Data %.3f] [LossC %f] [LossSim %f]' %
(batch_time.average, data_time.average,
losses['cls'].average, losses['sim'].average))
if logger is not None:
logger.scalar_summary('train/loss_cls', losses['cls'].average, epoch)
logger.scalar_summary('train/loss_sim', losses['sim'].average, epoch)
logger.scalar_summary('train/batch_time', batch_time.average, epoch)

View File

@ -0,0 +1,114 @@
import time
import torch.optim
import models.transform_layers as TL
from training.contrastive_loss import get_similarity_matrix, NT_xent
from utils.utils import AverageMeter, normalize
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
hflip = TL.HorizontalFlipLayer().to(device)
def train(P, epoch, model, criterion, optimizer, scheduler, loader, logger=None,
simclr_aug=None, linear=None, linear_optim=None):
assert simclr_aug is not None
assert P.sim_lambda == 1.0 # to avoid mistake
assert P.K_shift > 1
if logger is None:
log_ = print
else:
log_ = logger.log
batch_time = AverageMeter()
data_time = AverageMeter()
losses = dict()
losses['cls'] = AverageMeter()
losses['sim'] = AverageMeter()
losses['shift'] = AverageMeter()
check = time.time()
for n, (images, labels) in enumerate(loader):
model.train()
count = n * P.n_gpus # number of trained samples
data_time.update(time.time() - check)
check = time.time()
### SimCLR loss ###
if P.dataset != 'imagenet' and P.dataset != 'CNMC' and P.dataset != 'CNMC_grayscale':
batch_size = images.size(0)
images = images.to(device)
images1, images2 = hflip(images.repeat(2, 1, 1, 1)).chunk(2) # hflip
else:
batch_size = images[0].size(0)
images1, images2 = images[0].to(device), images[1].to(device)
labels = labels.to(device)
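# expand each view into K_shift transformed copies; copy k receives pseudo-label k for the shifting classifier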
images1 = torch.cat([P.shift_trans(images1, k) for k in range(P.K_shift)])
images2 = torch.cat([P.shift_trans(images2, k) for k in range(P.K_shift)])
shift_labels = torch.cat([torch.ones_like(labels) * k for k in range(P.K_shift)], 0) # B -> K_shift * B
shift_labels = shift_labels.repeat(2)
images_pair = torch.cat([images1, images2], dim=0) # 2 * K_shift * B
images_pair = simclr_aug(images_pair) # transform
_, outputs_aux = model(images_pair, simclr=True, penultimate=True, shift=True)
simclr = normalize(outputs_aux['simclr']) # normalize
sim_matrix = get_similarity_matrix(simclr, multi_gpu=P.multi_gpu)
loss_sim = NT_xent(sim_matrix, temperature=0.5) * P.sim_lambda
loss_shift = criterion(outputs_aux['shift'], shift_labels)
### total loss ###
loss = loss_sim + loss_shift
optimizer.zero_grad()
loss.backward()
optimizer.step()
scheduler.step(epoch - 1 + n / len(loader))
lr = optimizer.param_groups[0]['lr']
batch_time.update(time.time() - check)
### Post-processing ###
simclr_norm = outputs_aux['simclr'].norm(dim=1).mean()
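# rows [0, B) are view 1's un-shifted samples; view 2's un-shifted samples start at K_shift * B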
penul_1 = outputs_aux['penultimate'][:batch_size]
penul_2 = outputs_aux['penultimate'][P.K_shift * batch_size: (P.K_shift + 1) * batch_size]
outputs_aux['penultimate'] = torch.cat([penul_1, penul_2]) # only use the un-shifted samples
### Linear evaluation ###
outputs_linear_eval = linear(outputs_aux['penultimate'].detach())
loss_linear = criterion(outputs_linear_eval, labels.repeat(2))
linear_optim.zero_grad()
loss_linear.backward()
linear_optim.step()
losses['cls'].update(0, batch_size)
losses['sim'].update(loss_sim.item(), batch_size)
losses['shift'].update(loss_shift.item(), batch_size)
if count % 50 == 0:
log_('[Epoch %3d; %3d] [Time %.3f] [Data %.3f] [LR %.5f]\n'
'[LossC %f] [LossSim %f] [LossShift %f]' %
(epoch, count, batch_time.value, data_time.value, lr,
losses['cls'].value, losses['sim'].value, losses['shift'].value))
log_('[DONE] [Time %.3f] [Data %.3f] [LossC %f] [LossSim %f] [LossShift %f]' %
(batch_time.average, data_time.average,
losses['cls'].average, losses['sim'].average, losses['shift'].average))
if logger is not None:
logger.scalar_summary('train/loss_cls', losses['cls'].average, epoch)
logger.scalar_summary('train/loss_sim', losses['sim'].average, epoch)
logger.scalar_summary('train/loss_shift', losses['shift'].average, epoch)
logger.scalar_summary('train/batch_time', batch_time.average, epoch)

0
utils/__init__.py Normal file
View File

Some files were not shown because too many files have changed in this diff.