commit d1ce7b933f ("init") to branch master
Artur Feoktistov, 2 years ago
100 changed files with 17144 additions and 0 deletions
1. .ipynb_checkpoints/eval-checkpoint.ipynb (+4691, -0)
2. .ipynb_checkpoints/train-checkpoint.ipynb (+1799, -0)
3. README.md (+176, -0)
4. common/LARS.py (+119, -0)
5. common/__init__.py (+0, -0)
6. common/__init__.pyc (BIN)
7. common/__pycache__/LARS.cpython-37.pyc (BIN)
8. common/__pycache__/__init__.cpython-36.pyc (BIN)
9. common/__pycache__/__init__.cpython-37.pyc (BIN)
10. common/__pycache__/common.cpython-36.pyc (BIN)
11. common/__pycache__/common.cpython-37.pyc (BIN)
12. common/__pycache__/eval.cpython-36.pyc (BIN)
13. common/__pycache__/eval.cpython-37.pyc (BIN)
14. common/__pycache__/eval.cpython-37.pyc.2498080381488 (+0, -0)
15. common/__pycache__/eval.cpython-37.pyc.2731703741232 (+0, -0)
16. common/__pycache__/train.cpython-36.pyc (BIN)
17. common/__pycache__/train.cpython-37.pyc (BIN)
18. common/common.py (+114, -0)
19. common/eval.py (+81, -0)
20. common/train.py (+148, -0)
21. data/ImageNet_FIX.tar.gz (BIN)
22. data/Imagenet_resize.tar.gz (BIN)
23. data/LSUN_FIX.tar.gz (BIN)
24. data/LSUN_resize.tar.gz (BIN)
25. datasets/__init__.py (+2, -0)
26. datasets/__pycache__/__init__.cpython-36.pyc (BIN)
27. datasets/__pycache__/__init__.cpython-37.pyc (BIN)
28. datasets/__pycache__/datasets.cpython-36.pyc (BIN)
29. datasets/__pycache__/datasets.cpython-37.pyc (BIN)
30. datasets/__pycache__/datasets.cpython-37.pyc.2427217203392 (+0, -0)
31. datasets/__pycache__/postprocess_data.cpython-36.pyc (BIN)
32. datasets/__pycache__/postprocess_data.cpython-37.pyc (BIN)
33. datasets/__pycache__/prepare_data.cpython-36.pyc (BIN)
34. datasets/__pycache__/prepare_data.cpython-37.pyc (BIN)
35. datasets/datasets.py (+361, -0)
36. datasets/imagenet_fix_preprocess.py (+66, -0)
37. datasets/lsun_fix_preprocess.py (+61, -0)
38. datasets/postprocess_data.py (+37, -0)
39. datasets/prepare_data.py (+196, -0)
40. eval.ipynb (+4691, -0)
41. eval.py (+57, -0)
42. evals/__init__.py (+1, -0)
43. evals/__pycache__/__init__.cpython-36.pyc (BIN)
44. evals/__pycache__/__init__.cpython-37.pyc (BIN)
45. evals/__pycache__/evals.cpython-36.pyc (BIN)
46. evals/__pycache__/evals.cpython-37.pyc (BIN)
47. evals/__pycache__/ood_pre.cpython-36.pyc (BIN)
48. evals/__pycache__/ood_pre.cpython-37.pyc (BIN)
49. evals/evals.py (+201, -0)
50. evals/ood_pre.py (+242, -0)
51. figures/CSI_teaser.png (BIN)
52. figures/fixed_ood_benchmarks.png (BIN)
53. figures/shifting_transformations.png (BIN)
54. main.py (+37, -0)
55. models/__init__.py (+0, -0)
56. models/__pycache__/__init__.cpython-36.pyc (BIN)
57. models/__pycache__/__init__.cpython-37.pyc (BIN)
58. models/__pycache__/base_model.cpython-36.pyc (BIN)
59. models/__pycache__/base_model.cpython-37.pyc (BIN)
60. models/__pycache__/classifier.cpython-36.pyc (BIN)
61. models/__pycache__/classifier.cpython-37.pyc (BIN)
62. models/__pycache__/resnet.cpython-36.pyc (BIN)
63. models/__pycache__/resnet.cpython-37.pyc (BIN)
64. models/__pycache__/resnet_imagenet.cpython-36.pyc (BIN)
65. models/__pycache__/resnet_imagenet.cpython-37.pyc (BIN)
66. models/__pycache__/transform_layers.cpython-36.pyc (BIN)
67. models/__pycache__/transform_layers.cpython-37.pyc (BIN)
68. models/base_model.py (+48, -0)
69. models/classifier.py (+135, -0)
70. models/resnet.py (+189, -0)
71. models/resnet_imagenet.py (+231, -0)
72. models/transform_layers.py (+643, -0)
73. train.ipynb (+1799, -0)
74. train.py (+57, -0)
75. training/__init__.py (+97, -0)
76. training/__pycache__/__init__.cpython-36.pyc (BIN)
77. training/__pycache__/__init__.cpython-37.pyc (BIN)
78. training/__pycache__/contrastive_loss.cpython-36.pyc (BIN)
79. training/__pycache__/contrastive_loss.cpython-37.pyc (BIN)
80. training/__pycache__/scheduler.cpython-36.pyc (BIN)
81. training/__pycache__/scheduler.cpython-37.pyc (BIN)
82. training/contrastive_loss.py (+79, -0)
83. training/scheduler.py (+63, -0)
84. training/sup/__init__.py (+33, -0)
85. training/sup/__pycache__/__init__.cpython-36.pyc (BIN)
86. training/sup/__pycache__/sup_simclr.cpython-36.pyc (BIN)
87. training/sup/__pycache__/sup_simclr_CSI.cpython-36.pyc (BIN)
88. training/sup/sup_CSI_linear.py (+130, -0)
89. training/sup/sup_linear.py (+91, -0)
90. training/sup/sup_simclr.py (+104, -0)
91. training/sup/sup_simclr_CSI.py (+111, -0)
92. training/unsup/__init__.py (+39, -0)
93. training/unsup/__pycache__/__init__.cpython-36.pyc (BIN)
94. training/unsup/__pycache__/__init__.cpython-37.pyc (BIN)
95. training/unsup/__pycache__/simclr_CSI.cpython-36.pyc (BIN)
96. training/unsup/__pycache__/simclr_CSI.cpython-37.pyc (BIN)
97. training/unsup/__pycache__/simclr_CSI.cpython-37.pyc.2078473038560 (+0, -0)
98. training/unsup/simclr.py (+101, -0)
99. training/unsup/simclr_CSI.py (+114, -0)
100. utils/__init__.py (+0, -0)

+ 4691 - 0  .ipynb_checkpoints/eval-checkpoint.ipynb (file diff suppressed because it is too large)


+ 1799 - 0  .ipynb_checkpoints/train-checkpoint.ipynb (file diff suppressed because it is too large)


+ 176 - 0  README.md

@@ -0,0 +1,176 @@
# CSI: Novelty Detection via Contrastive Learning on Distributionally Shifted Instances

Official PyTorch implementation of
["**CSI: Novelty Detection via Contrastive Learning on Distributionally Shifted Instances**"](
https://arxiv.org/abs/2007.08176) (NeurIPS 2020) by
[Jihoon Tack*](https://jihoontack.github.io),
[Sangwoo Mo*](https://sites.google.com/view/sangwoomo),
[Jongheon Jeong](https://sites.google.com/view/jongheonj),
and [Jinwoo Shin](http://alinlab.kaist.ac.kr/shin.html).

<p align="center">
<img src=figures/shifting_transformations.png width="900">
</p>

## 1. Requirements
### Environments
Currently, the following packages are required:
- python 3.6+
- torch 1.4+
- torchvision 0.5+
- CUDA 10.1+
- scikit-learn 0.22+
- tensorboard 2.0+
- [torchlars](https://github.com/kakaobrain/torchlars) == 0.1.2
- [pytorch-gradual-warmup-lr](https://github.com/ildoonet/pytorch-gradual-warmup-lr)
- [apex](https://github.com/NVIDIA/apex) == 0.1
- [diffdist](https://github.com/ag14774/diffdist) == 0.1

### Datasets
For CIFAR, please download the following datasets to `~/data`.
* [LSUN_resize](https://www.dropbox.com/s/moqh2wh8696c3yl/LSUN_resize.tar.gz),
[ImageNet_resize](https://www.dropbox.com/s/kp3my3412u5k9rl/Imagenet_resize.tar.gz)
* [LSUN_fix](https://drive.google.com/file/d/1KVWj9xpHfVwGcErH5huVujk9snhEGOxE/view?usp=sharing),
[ImageNet_fix](https://drive.google.com/file/d/1sO_-noq10mmziB1ECDyNhD5T4u5otyKA/view?usp=sharing)

For ImageNet-30, please download the following datasets to `~/data`.
* [ImageNet-30-train](https://drive.google.com/file/d/1B5c39Fc3haOPzlehzmpTLz6xLtGyKEy4/view),
[ImageNet-30-test](https://drive.google.com/file/d/13xzVuQMEhSnBRZr-YaaO08coLU2dxAUq/view)
* [CUB-200](http://www.vision.caltech.edu/visipedia/CUB-200-2011.html),
[Stanford Dogs](http://vision.stanford.edu/aditya86/ImageNetDogs/),
[Oxford Pets](https://www.robots.ox.ac.uk/~vgg/data/pets/),
[Oxford flowers](https://www.robots.ox.ac.uk/~vgg/data/flowers/),
[Food-101](https://www.kaggle.com/dansbecker/food-101),
[Places-365](http://data.csail.mit.edu/places/places365/val_256.tar),
[Caltech-256](https://www.kaggle.com/jessicali9530/caltech256),
[DTD](https://www.robots.ox.ac.uk/~vgg/data/dtd/)

For Food-101, remove the hotdog class to avoid overlap with ImageNet-30; a removal sketch is shown below.
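A minimal sketch of the removal, assuming the standard Food-101 layout with a `food-101/images/hot_dog` class directory (the exact directory name is an assumption; adjust it to your extracted layout):

```python
import shutil
from pathlib import Path

# hypothetical location of the overlapping class; verify against your download
hotdog_dir = Path.home() / "data" / "food-101" / "images" / "hot_dog"

if hotdog_dir.is_dir():
    shutil.rmtree(hotdog_dir)  # remove the class that overlaps with ImageNet-30
```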

## 2. Training
All code examples below assume distributed training with 4 GPUs.
To run the code with a single GPU, remove `-m torch.distributed.launch --nproc_per_node=4`.

### Unlabeled one-class & multi-class
To train unlabeled one-class & multi-class models in the paper, run this command:

```train
CUDA_VISIBLE_DEVICES=0,1,2,3 python -m torch.distributed.launch --nproc_per_node=4 train.py --dataset <DATASET> --model <NETWORK> --mode simclr_CSI --shift_trans_type rotation --batch_size 32 --one_class_idx <One-Class-Index>
```

> The --one_class_idx option sets the in-distribution class for one-class training.
> For multi-class training, omit --one_class_idx (it defaults to None).
> To run SimCLR, simply change --mode to simclr.
> The total batch size should be 512 = 4 (GPUs) x 32 (--batch_size) x 4 (cardinality of the shifted transformation set), as checked below.
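A quick sanity check of that arithmetic (a trivial sketch; the numbers come from the command above):

```python
n_gpus, per_gpu_batch, k_shift = 4, 32, 4  # 4 GPUs, --batch_size 32, 4 rotations
assert n_gpus * per_gpu_batch * k_shift == 512  # effective contrastive batch size
```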

### Labeled multi-class
To train the labeled multi-class model (confidence-calibrated classifier) from the paper, run these commands:

```train
# Representation train
CUDA_VISIBLE_DEVICES=0,1,2,3 python -m torch.distributed.launch --nproc_per_node=4 train.py --dataset <DATASET> --model <NETWORK> --mode sup_simclr_CSI --shift_trans_type rotation --batch_size 32 --epoch 700
# Linear layer train
python train.py --mode sup_CSI_linear --dataset <DATASET> --model <NETWORK> --batch_size 32 --epoch 100 --shift_trans_type rotation --load_path <MODEL_PATH>
```

> To run SupCLR, change --mode to sup_simclr for representation training and sup_linear for linear-layer training.
> The total batch size should be the same as above. Currently, only rotation is supported as the shifting transformation.

## 3. Evaluation

We provide checkpoints of CSI pre-trained models. Download them from the following links:
- One-class CIFAR-10: [ResNet-18](https://drive.google.com/drive/folders/1z02i0G_lzrZe0NwpH-tnjpO8pYHV7mE9?usp=sharing)
- Unlabeled (multi-class) CIFAR-10: [ResNet-18](https://drive.google.com/file/d/1yUq6Si6hWaMa1uYyLDHk0A4BrPIa8ECV/view?usp=sharing)
- Unlabeled (multi-class) ImageNet-30: [ResNet-18](https://drive.google.com/file/d/1KucQWSik8RyoJgU-fz8XLmCWhvMOP7fT/view?usp=sharing)
- Labeled (multi-class) CIFAR-10: [ResNet-18](https://drive.google.com/file/d/1rW2-0MJEzPHLb_PAW-LvCivHt-TkDpRO/view?usp=sharing)

### Unlabeled one-class & multi-class
To evaluate a model in the unlabeled one-class & multi-class out-of-distribution (OOD) detection setting, run this command:

```eval
python eval.py --mode ood_pre --dataset <DATASET> --model <NETWORK> --ood_score CSI --shift_trans_type rotation --print_score --ood_samples 10 --resize_factor 0.54 --resize_fix --one_class_idx <One-Class-Index> --load_path <MODEL_PATH>
```

> The --one_class_idx option sets the in-distribution class for one-class evaluation.
> For multi-class evaluation, omit --one_class_idx.
> The --resize_factor and --resize_fix options control the cropping scale of RandomResizedCrop(); see the sketch below.
> For SimCLR evaluation, change --ood_score to simclr.
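A minimal sketch of how these two options map onto the crop, mirroring `get_simclr_eval_transform_imagenet` in `datasets/datasets.py` (values taken from the command above):

```python
from torchvision import transforms

resize_factor, resize_fix = 0.54, True

# the crop scale is sampled from [resize_factor, 1.0], or pinned to resize_factor when --resize_fix is set
scale = (resize_factor, resize_factor) if resize_fix else (resize_factor, 1.0)
crop = transforms.RandomResizedCrop(224, scale=scale)
```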

### Labeled multi-class
To evaluate a model on labeled multi-class accuracy, ECE, and OOD detection, run these commands:

```eval
# OOD AUROC
python eval.py --mode ood --ood_score baseline_marginalized --print_score --dataset <DATASET> --model <NETWORK> --shift_trans_type rotation --load_path <MODEL_PATH>
# Accuracy & ECE
python eval.py --mode test_marginalized_acc --dataset <DATASET> --model <NETWORK> --shift_trans_type rotation --load_path <MODEL_PATH>
```

> These options are for marginalized inference.
> For single inference (also used for SupCLR), change --ood_score to baseline in the first command,
> and --mode to test_acc in the second command.

## 4. Results

Our model achieves the following performance on:

### One-Class Out-of-Distribution Detection

| Method | Dataset | AUROC (Mean) |
| --------------|------------------ | --------------|
| SimCLR | CIFAR-10-OC | 87.9% |
| Rot+Trans | CIFAR-10-OC | 90.0% |
| CSI (ours) | CIFAR-10-OC | 94.3% |

We only show the CIFAR-10 one-class result in this repo. For other settings, please see our paper.

### Unlabeled Multi-Class Out-of-Distribution Detection

| Method | Dataset | OOD Dataset | AUROC (Mean) |
| --------------|------------------ |---------------|--------------|
| Rot+Trans | CIFAR-10 | CIFAR-100 | 82.5% |
| CSI (ours) | CIFAR-10 | CIFAR-100 | 89.3% |

We only show the CIFAR-10 to CIFAR-100 OOD detection result in this repo. For results on other OOD datasets, see our paper.

### Labeled Multi-Class Result

| Method | Dataset | OOD Dataset | Acc | ECE | AUROC (Mean) |
| ---------------- |------------------ |---------------|-------|-------|--------------|
| SupCLR | CIFAR-10 | CIFAR-100 | 93.9% | 5.54% | 88.3% |
| CSI (ours) | CIFAR-10 | CIFAR-100 | 94.8% | 4.24% | 90.6% |
| CSI-ensem (ours) | CIFAR-10 | CIFAR-100 | 96.0% | 3.64% | 92.3% |

We only show CIFAR-10 with CIFAR-100 as the OOD dataset in this repo. For other dataset results, please see our paper.

## 5. New OOD dataset

<p align="center">
<img src=figures/fixed_ood_benchmarks.png width="600">
</p>

We find that the current benchmark datasets for OOD detection are visually far from the in-distribution datasets (e.g., CIFAR).

To address this issue, we provide new datasets for OOD detection evaluation:
[LSUN_fix](https://drive.google.com/file/d/1KVWj9xpHfVwGcErH5huVujk9snhEGOxE/view?usp=sharing),
[ImageNet_fix](https://drive.google.com/file/d/1sO_-noq10mmziB1ECDyNhD5T4u5otyKA/view?usp=sharing).
See the figure above for a visualization of the current benchmarks and our datasets.

To generate the OOD datasets, run the following commands inside the `./datasets` folder:

```OOD dataset generation
# ImageNet FIX generation code
python imagenet_fix_preprocess.py
# LSUN FIX generation code
python lsun_fix_preprocess.py
```

## Citation
```
@inproceedings{tack2020csi,
title={CSI: Novelty Detection via Contrastive Learning on Distributionally Shifted Instances},
author={Jihoon Tack and Sangwoo Mo and Jongheon Jeong and Jinwoo Shin},
booktitle={Advances in Neural Information Processing Systems},
year={2020}
}
```

+ 119 - 0  common/LARS.py

@@ -0,0 +1,119 @@
"""
References:
- https://github.com/PyTorchLightning/PyTorch-Lightning-Bolts/blob/master/pl_bolts/optimizers/lars_scheduling.py
- https://github.com/NVIDIA/apex/blob/master/apex/parallel/LARC.py
- https://arxiv.org/pdf/1708.03888.pdf
- https://github.com/noahgolmant/pytorch-lars/blob/master/lars.py
"""

import torch
from .wrapper import OptimWrapper

# from torchlars._adaptive_lr import compute_adaptive_lr # Impossible to build extensions


__all__ = ["LARS"]


class LARS(OptimWrapper):
"""Implements 'LARS (Layer-wise Adaptive Rate Scaling)'__ as Optimizer a
:class:`~torch.optim.Optimizer` wrapper.
__ : https://arxiv.org/abs/1708.03888
Wraps an arbitrary optimizer like :class:`torch.optim.SGD` to use LARS. If
you want to the same performance obtained with small-batch training when
you use large-batch training, LARS will be helpful::
Args:
optimizer (Optimizer):
optimizer to wrap
eps (float, optional):
epsilon to help with numerical stability while calculating the
adaptive learning rate
trust_coef (float, optional):
trust coefficient for calculating the adaptive learning rate
Example::
base_optimizer = optim.SGD(model.parameters(), lr=0.1)
optimizer = LARS(optimizer=base_optimizer)
output = model(input)
loss = loss_fn(output, target)
loss.backward()
optimizer.step()
"""

def __init__(self, optimizer, trust_coef=0.02, clip=True, eps=1e-8):
if eps < 0.0:
raise ValueError("invalid epsilon value: , %f" % eps)
if trust_coef < 0.0:
raise ValueError("invalid trust coefficient: %f" % trust_coef)

self.optim = optimizer
self.eps = eps
self.trust_coef = trust_coef
self.clip = clip

def __getstate__(self):
lars_dict = {}
lars_dict["trust_coef"] = self.trust_coef
lars_dict["clip"] = self.clip
lars_dict["eps"] = self.eps
return (self.optim, lars_dict)

def __setstate__(self, state):
self.optim, lars_dict = state
self.trust_coef = lars_dict["trust_coef"]
self.clip = lars_dict["clip"]
self.eps = lars_dict["eps"]

@torch.no_grad()
def step(self, closure=None):
weight_decays = []

for group in self.optim.param_groups:
weight_decay = group.get("weight_decay", 0)
weight_decays.append(weight_decay)

# reset weight decay
group["weight_decay"] = 0

# update the parameters
for p in group["params"]:
if p.grad is not None:
self.update_p(p, group, weight_decay)

# update the optimizer
self.optim.step(closure=closure)

# return weight decay control to optimizer
for group_idx, group in enumerate(self.optim.param_groups):
group["weight_decay"] = weight_decays[group_idx]

def update_p(self, p, group, weight_decay):
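# LARS local learning rate: trust_coef * ||p|| / (||g|| + weight_decay * ||p|| + eps);
# with clip=True it is divided by the group's lr and capped at 1, so the effective
# step never exceeds the base learning rate (LARC-style clipping)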
# calculate new norms
p_norm = torch.norm(p.data)
g_norm = torch.norm(p.grad.data)

if p_norm != 0 and g_norm != 0:
# calculate new lr
divisor = g_norm + p_norm * weight_decay + self.eps
adaptive_lr = (self.trust_coef * p_norm) / divisor

# clip lr
if self.clip:
adaptive_lr = min(adaptive_lr / group["lr"], 1)

# update params with clipped lr
p.grad.data += weight_decay * p.data
p.grad.data *= adaptive_lr


from torch.optim import SGD
from pylot.util import delegates, separate_kwargs


class SGDLARS(LARS):
@delegates(to=LARS.__init__)
@delegates(to=SGD.__init__, keep=True, but=["eps", "trust_coef"])
def __init__(self, params, **kwargs):
sgd_kwargs, lars_kwargs = separate_kwargs(kwargs, SGD.__init__)
optim = SGD(params, **sgd_kwargs)
super().__init__(optim, **lars_kwargs)

+ 0 - 0  common/__init__.py
BIN  common/__init__.pyc
BIN  common/__pycache__/LARS.cpython-37.pyc
BIN  common/__pycache__/__init__.cpython-36.pyc
BIN  common/__pycache__/__init__.cpython-37.pyc
BIN  common/__pycache__/common.cpython-36.pyc
BIN  common/__pycache__/common.cpython-37.pyc
BIN  common/__pycache__/eval.cpython-36.pyc
BIN  common/__pycache__/eval.cpython-37.pyc
+ 0 - 0  common/__pycache__/eval.cpython-37.pyc.2498080381488
+ 0 - 0  common/__pycache__/eval.cpython-37.pyc.2731703741232
BIN  common/__pycache__/train.cpython-36.pyc
BIN  common/__pycache__/train.cpython-37.pyc

+ 114 - 0  common/common.py

@@ -0,0 +1,114 @@
from argparse import ArgumentParser


def parse_args(default=False):
"""Command-line argument parser for training."""

parser = ArgumentParser(description='PyTorch implementation of CSI')

parser.add_argument('--dataset', help='Dataset',
choices=['cifar10', 'cifar100', 'imagenet', 'CNMC', 'CNMC_grayscale'], type=str)
parser.add_argument('--one_class_idx', help='None: multi-class, Not None: one-class',
default=None, type=int)
parser.add_argument('--model', help='Model',
choices=['resnet18', 'resnet18_imagenet'], type=str)
parser.add_argument('--mode', help='Training mode',
default='simclr', type=str)
parser.add_argument('--simclr_dim', help='Dimension of simclr layer',
default=128, type=int)

parser.add_argument('--shift_trans_type', help='shifting transformation type', default='none',
choices=['rotation', 'cutperm', 'blur', 'randpers', 'sharp', 'blur_randpers',
'blur_sharp', 'randpers_sharp', 'blur_randpers_sharp', 'noise', 'none'], type=str)

parser.add_argument("--local_rank", type=int,
default=0, help='Local rank for distributed learning')
parser.add_argument('--resume_path', help='Path to the resume checkpoint',
default=None, type=str)
parser.add_argument('--load_path', help='Path to the loading checkpoint',
default=None, type=str)
parser.add_argument("--no_strict", help='Do not strictly load state_dicts',
action='store_true')
parser.add_argument('--suffix', help='Suffix for the log dir',
default=None, type=str)
parser.add_argument('--error_step', help='Epoch steps to compute errors',
default=5, type=int)
parser.add_argument('--save_step', help='Epoch steps to save models',
default=10, type=int)

##### Training Configurations #####
parser.add_argument('--epochs', help='Epochs',
default=1000, type=int)
parser.add_argument('--optimizer', help='Optimizer',
choices=['sgd', 'lars'],
default='lars', type=str)
parser.add_argument('--lr_scheduler', help='Learning rate scheduler',
choices=['step_decay', 'cosine'],
default='cosine', type=str)
parser.add_argument('--warmup', help='Warm-up epochs',
default=10, type=int)
parser.add_argument('--lr_init', help='Initial learning rate',
default=1e-1, type=float)
parser.add_argument('--weight_decay', help='Weight decay',
default=1e-6, type=float)
parser.add_argument('--batch_size', help='Batch size',
default=128, type=int)
parser.add_argument('--test_batch_size', help='Batch size for test loader',
default=100, type=int)
parser.add_argument('--blur_sigma', help='Distortion grade',
default=2.0, type=float)
parser.add_argument('--color_distort', help='Color distortion grade',
default=0.5, type=float)
parser.add_argument('--distortion_scale', help='Perspective distortion grade',
default=0.6, type=float)
parser.add_argument('--sharpness_factor', help='Sharpening or blurring factor of the image. '
'Can be any non-negative number: 0 gives a blurred image, '
'1 gives the original image, and 2 increases the sharpness '
'by a factor of 2.',
default=2, type=float)
parser.add_argument('--noise_mean', help='Mean of the added noise',
default=0, type=float)
parser.add_argument('--noise_std', help='Standard deviation of the added noise',
default=0.3, type=float)

##### Objective Configurations #####
parser.add_argument('--sim_lambda', help='Weight for SimCLR loss',
default=1.0, type=float)
parser.add_argument('--temperature', help='Temperature for similarity',
default=0.5, type=float)

##### Evaluation Configurations #####
parser.add_argument("--ood_dataset", help='Datasets for OOD detection',
default=None, nargs="*", type=str)
parser.add_argument("--ood_score", help='score function for OOD detection',
default=['norm_mean'], nargs="+", type=str)
parser.add_argument("--ood_layer", help='layer for OOD scores',
choices=['penultimate', 'simclr', 'shift'],
default=['simclr', 'shift'], nargs="+", type=str)
parser.add_argument("--ood_samples", help='number of samples to compute OOD score',
default=1, type=int)
parser.add_argument("--ood_batch_size", help='batch size to compute OOD score',
default=100, type=int)
parser.add_argument("--resize_factor", help='resize scale is sampled from [resize_factor, 1.0]',
default=0.08, type=float)
parser.add_argument("--resize_fix", help='resize scale is fixed to resize_factor (not (resize_factor, 1.0])',
action='store_true')

parser.add_argument("--print_score", help='print quantiles of ood score',
action='store_true')
parser.add_argument("--save_score", help='save ood score for plotting histogram',
action='store_true')

##### Process configuration option #####
parser.add_argument("--proc_step", help='choose process to initiate.',
choices=['eval', 'train'],
default=None, type=str)
parser.add_argument("--res", help='resolution of dataset',
default="32px", type=str)

if default:
return parser.parse_args('') # empty string
else:
return parser.parse_args()

+ 81 - 0  common/eval.py

@@ -0,0 +1,81 @@
from copy import deepcopy

import torch
import torch.nn as nn
from torch.utils.data import DataLoader

from common.common import parse_args
import models.classifier as C
from datasets import get_dataset, get_superclass_list, get_subclass_dataset

P = parse_args()

### Set torch device ###

P.n_gpus = torch.cuda.device_count()
assert P.n_gpus <= 1 # no multi GPU
P.multi_gpu = False

if torch.cuda.is_available():
torch.cuda.set_device(P.local_rank)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

### Initialize dataset ###
ood_eval = P.mode == 'ood_pre'
if ood_eval and P.dataset in ['imagenet', 'CNMC', 'CNMC_grayscale']:
P.batch_size = 1
P.test_batch_size = 1
train_set, test_set, image_size, n_classes = get_dataset(P, dataset=P.dataset, eval=ood_eval)

P.image_size = image_size
P.n_classes = n_classes

if P.one_class_idx is not None:
cls_list = get_superclass_list(P.dataset)
P.n_superclasses = len(cls_list)

full_test_set = deepcopy(test_set) # test set of full classes
train_set = get_subclass_dataset(train_set, classes=cls_list[P.one_class_idx])
test_set = get_subclass_dataset(test_set, classes=cls_list[P.one_class_idx])

kwargs = {'pin_memory': False, 'num_workers': 2}

train_loader = DataLoader(train_set, shuffle=True, batch_size=P.batch_size, **kwargs)
test_loader = DataLoader(test_set, shuffle=False, batch_size=P.test_batch_size, **kwargs)

if P.ood_dataset is None:
if P.one_class_idx is not None:
P.ood_dataset = list(range(P.n_superclasses))
P.ood_dataset.pop(P.one_class_idx)
elif P.dataset == 'cifar10':
P.ood_dataset = ['svhn', 'lsun_resize', 'imagenet_resize', 'lsun_fix', 'imagenet_fix', 'cifar100', 'interp']
elif P.dataset == 'imagenet':
P.ood_dataset = ['cub', 'stanford_dogs', 'flowers102', 'places365', 'food_101', 'caltech_256', 'dtd', 'pets']

ood_test_loader = dict()
for ood in P.ood_dataset:
if ood == 'interp':
ood_test_loader[ood] = None # dummy loader
continue

if P.one_class_idx is not None:
ood_test_set = get_subclass_dataset(full_test_set, classes=cls_list[ood])
ood = f'one_class_{ood}' # change save name
else:
ood_test_set = get_dataset(P, dataset=ood, test_only=True, image_size=P.image_size, eval=ood_eval)

ood_test_loader[ood] = DataLoader(ood_test_set, shuffle=False, batch_size=P.test_batch_size, **kwargs)

### Initialize model ###

simclr_aug = C.get_simclr_augmentation(P, image_size=P.image_size).to(device)
P.shift_trans, P.K_shift = C.get_shift_module(P, eval=True)
P.shift_trans = P.shift_trans.to(device)

model = C.get_classifier(P.model, n_classes=P.n_classes).to(device)
model = C.get_shift_classifer(model, P.K_shift).to(device)
criterion = nn.CrossEntropyLoss().to(device)

if P.load_path is not None:
checkpoint = torch.load(P.load_path)
model.load_state_dict(checkpoint, strict=not P.no_strict)

+ 148 - 0  common/train.py

@@ -0,0 +1,148 @@
from copy import deepcopy

import torch
import torch.nn as nn
import torch.optim as optim
import torch.optim.lr_scheduler as lr_scheduler
from torch.utils.data import DataLoader

from common.common import parse_args
import models.classifier as C
from datasets import get_dataset, get_superclass_list, get_subclass_dataset
from utils.utils import load_checkpoint

P = parse_args()

### Set torch device ###

if torch.cuda.is_available():
torch.cuda.set_device(P.local_rank)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

P.n_gpus = torch.cuda.device_count()

if P.n_gpus > 1:
import apex
import torch.distributed as dist
from torch.utils.data.distributed import DistributedSampler

P.multi_gpu = True
torch.distributed.init_process_group(
'nccl',
init_method='env://',
world_size=P.n_gpus,
rank=P.local_rank,
)
else:
P.multi_gpu = False

### only use one ood_layer while training
P.ood_layer = P.ood_layer[0]

### Initialize dataset ###
train_set, test_set, image_size, n_classes = get_dataset(P, dataset=P.dataset)
P.image_size = image_size
P.n_classes = n_classes

if P.one_class_idx is not None:
cls_list = get_superclass_list(P.dataset)
P.n_superclasses = len(cls_list)

full_test_set = deepcopy(test_set) # test set of full classes
train_set = get_subclass_dataset(train_set, classes=cls_list[P.one_class_idx])
test_set = get_subclass_dataset(test_set, classes=cls_list[P.one_class_idx])

kwargs = {'pin_memory': False, 'num_workers': 2}

if P.multi_gpu:
train_sampler = DistributedSampler(train_set, num_replicas=P.n_gpus, rank=P.local_rank)
test_sampler = DistributedSampler(test_set, num_replicas=P.n_gpus, rank=P.local_rank)
train_loader = DataLoader(train_set, sampler=train_sampler, batch_size=P.batch_size, **kwargs)
test_loader = DataLoader(test_set, sampler=test_sampler, batch_size=P.test_batch_size, **kwargs)
else:
train_loader = DataLoader(train_set, shuffle=True, batch_size=P.batch_size, **kwargs)
test_loader = DataLoader(test_set, shuffle=False, batch_size=P.test_batch_size, **kwargs)

if P.ood_dataset is None:
if P.one_class_idx is not None:
P.ood_dataset = list(range(P.n_superclasses))
P.ood_dataset.pop(P.one_class_idx)
elif P.dataset == 'cifar10':
P.ood_dataset = ['svhn', 'lsun_resize', 'imagenet_resize', 'lsun_fix', 'imagenet_fix', 'cifar100', 'interp']
elif P.dataset == 'imagenet':
P.ood_dataset = ['cub', 'stanford_dogs', 'flowers102']

ood_test_loader = dict()
for ood in P.ood_dataset:
if ood == 'interp':
ood_test_loader[ood] = None # dummy loader
continue

if P.one_class_idx is not None:
ood_test_set = get_subclass_dataset(full_test_set, classes=cls_list[ood])
ood = f'one_class_{ood}' # change save name
else:
ood_test_set = get_dataset(P, dataset=ood, test_only=True, image_size=P.image_size)

if P.multi_gpu:
ood_sampler = DistributedSampler(ood_test_set, num_replicas=P.n_gpus, rank=P.local_rank)
ood_test_loader[ood] = DataLoader(ood_test_set, sampler=ood_sampler, batch_size=P.test_batch_size, **kwargs)
else:
ood_test_loader[ood] = DataLoader(ood_test_set, shuffle=False, batch_size=P.test_batch_size, **kwargs)

### Initialize model ###

simclr_aug = C.get_simclr_augmentation(P, image_size=P.image_size).to(device)
P.shift_trans, P.K_shift = C.get_shift_module(P, eval=True)
P.shift_trans = P.shift_trans.to(device)

model = C.get_classifier(P.model, n_classes=P.n_classes).to(device)
model = C.get_shift_classifer(model, P.K_shift).to(device)

criterion = nn.CrossEntropyLoss().to(device)

if P.optimizer == 'sgd':
optimizer = optim.SGD(model.parameters(), lr=P.lr_init, momentum=0.9, weight_decay=P.weight_decay)
lr_decay_gamma = 0.1
elif P.optimizer == 'lars':
from torchlars import LARS
base_optimizer = optim.SGD(model.parameters(), lr=P.lr_init, momentum=0.9, weight_decay=P.weight_decay)
optimizer = LARS(base_optimizer, eps=1e-8, trust_coef=0.001)
lr_decay_gamma = 0.1
else:
raise NotImplementedError()

if P.lr_scheduler == 'cosine':
scheduler = lr_scheduler.CosineAnnealingLR(optimizer, P.epochs)
elif P.lr_scheduler == 'step_decay':
milestones = [int(0.5 * P.epochs), int(0.75 * P.epochs)]
scheduler = lr_scheduler.MultiStepLR(optimizer, gamma=lr_decay_gamma, milestones=milestones)
else:
raise NotImplementedError()

from training.scheduler import GradualWarmupScheduler
scheduler_warmup = GradualWarmupScheduler(optimizer, multiplier=10.0, total_epoch=P.warmup, after_scheduler=scheduler)

if P.resume_path is not None:
resume = True
model_state, optim_state, config = load_checkpoint(P.resume_path, mode='last')
model.load_state_dict(model_state, strict=not P.no_strict)
optimizer.load_state_dict(optim_state)
start_epoch = config['epoch']
best = config['best']
error = 100.0
else:
resume = False
start_epoch = 1
best = 100.0
error = 100.0

if P.mode == 'sup_linear' or P.mode == 'sup_CSI_linear':
assert P.load_path is not None
checkpoint = torch.load(P.load_path)
model.load_state_dict(checkpoint, strict=not P.no_strict)

if P.multi_gpu:
simclr_aug = apex.parallel.DistributedDataParallel(simclr_aug, delay_allreduce=True)
model = apex.parallel.convert_syncbn_model(model)
model = apex.parallel.DistributedDataParallel(model, delay_allreduce=True)

BIN  data/ImageNet_FIX.tar.gz
BIN  data/Imagenet_resize.tar.gz
BIN  data/LSUN_FIX.tar.gz
BIN  data/LSUN_resize.tar.gz

+ 2 - 0  datasets/__init__.py

@@ -0,0 +1,2 @@
from datasets.datasets import get_dataset, get_superclass_list, get_subclass_dataset


BIN  datasets/__pycache__/__init__.cpython-36.pyc
BIN  datasets/__pycache__/__init__.cpython-37.pyc
BIN  datasets/__pycache__/datasets.cpython-36.pyc
BIN  datasets/__pycache__/datasets.cpython-37.pyc
+ 0 - 0  datasets/__pycache__/datasets.cpython-37.pyc.2427217203392
BIN  datasets/__pycache__/postprocess_data.cpython-36.pyc
BIN  datasets/__pycache__/postprocess_data.cpython-37.pyc
BIN  datasets/__pycache__/prepare_data.cpython-36.pyc
BIN  datasets/__pycache__/prepare_data.cpython-37.pyc

+ 361 - 0  datasets/datasets.py

@@ -0,0 +1,361 @@
import os

import numpy as np
import torch
from torch.utils.data.dataset import Subset
from torchvision import datasets, transforms

from utils.utils import set_random_seed

DATA_PATH = '~/data/'
IMAGENET_PATH = '~/data/ImageNet'
CNMC_PATH = r'~/data/CSI/CNMC_orig'
CNMC_GRAY_PATH = r'~/data/CSI/CNMC_orig_gray'
CNMC_ROT4_PATH = r'~/data/CSI/CNMC_rotated_4'

CIFAR10_SUPERCLASS = list(range(10)) # one class
IMAGENET_SUPERCLASS = list(range(30)) # one class
CNMC_SUPERCLASS = list(range(2)) # one class

STD_RES = 450
STD_CENTER_CROP = 300

CIFAR100_SUPERCLASS = [
[4, 31, 55, 72, 95],
[1, 33, 67, 73, 91],
[54, 62, 70, 82, 92],
[9, 10, 16, 29, 61],
[0, 51, 53, 57, 83],
[22, 25, 40, 86, 87],
[5, 20, 26, 84, 94],
[6, 7, 14, 18, 24],
[3, 42, 43, 88, 97],
[12, 17, 38, 68, 76],
[23, 34, 49, 60, 71],
[15, 19, 21, 32, 39],
[35, 63, 64, 66, 75],
[27, 45, 77, 79, 99],
[2, 11, 36, 46, 98],
[28, 30, 44, 78, 93],
[37, 50, 65, 74, 80],
[47, 52, 56, 59, 96],
[8, 13, 48, 58, 90],
[41, 69, 81, 85, 89],
]


class MultiDataTransform(object):
def __init__(self, transform):
self.transform1 = transform
self.transform2 = transform

def __call__(self, sample):
x1 = self.transform1(sample)
x2 = self.transform2(sample)
return x1, x2


class MultiDataTransformList(object):
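# Produces `sample_num` augmented views of each sample under a fixed random seed,
# plus one clean view; used to average OOD scores over multiple crops at evaluation time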
def __init__(self, transform, clean_transform, sample_num):
self.transform = transform
self.clean_transform = clean_transform
self.sample_num = sample_num

def __call__(self, sample):
set_random_seed(0)

sample_list = []
for i in range(self.sample_num):
sample_list.append(self.transform(sample))

return sample_list, self.clean_transform(sample)


def get_transform(image_size=None):
# Note: data augmentation is implemented in the layers
# Hence, we only define the identity transformation here
if image_size: # use pre-specified image size
train_transform = transforms.Compose([
transforms.Resize((image_size[0], image_size[1])),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
])
test_transform = transforms.Compose([
transforms.Resize((image_size[0], image_size[1])),
transforms.ToTensor(),
])
else: # use default image size
train_transform = transforms.Compose([
transforms.ToTensor(),
])
test_transform = transforms.ToTensor()

return train_transform, test_transform


def get_subset_with_len(dataset, length, shuffle=False):
set_random_seed(0)
dataset_size = len(dataset)

index = np.arange(dataset_size)
if shuffle:
np.random.shuffle(index)

index = torch.from_numpy(index[0:length])
subset = Subset(dataset, index)

assert len(subset) == length

return subset


def get_transform_imagenet():

train_transform = transforms.Compose([
transforms.Resize(256),
transforms.RandomResizedCrop(224),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
])
test_transform = transforms.Compose([
transforms.Resize(256),
transforms.CenterCrop(224),
transforms.ToTensor(),
])

train_transform = MultiDataTransform(train_transform)

return train_transform, test_transform

def get_transform_cnmc(res, center_crop_size):
train_transform = transforms.Compose([
transforms.Resize(res),
transforms.CenterCrop(center_crop_size),
transforms.RandomVerticalFlip(),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
])
test_transform = transforms.Compose([
transforms.Resize(res),
transforms.CenterCrop(center_crop_size),
transforms.ToTensor(),
])
train_transform = MultiDataTransform(train_transform)

return train_transform, test_transform


def get_dataset(P, dataset, test_only=False, image_size=None, download=False, eval=False):
if P.res != '':
res = int(P.res.replace('px', ''))
size_factor = int(STD_RES/res) # always remove same portion
center_crop_size = int(STD_CENTER_CROP/size_factor) # remove black border
if dataset in ['CNMC', 'CNMC_grayscale', 'CNMC_ROT4_PATH']:
if eval:
train_transform, test_transform = get_simclr_eval_transform_cnmc(P.ood_samples,
P.resize_factor, P.resize_fix, res, center_crop_size)
else:
train_transform, test_transform = get_transform_cnmc(res, center_crop_size)
elif dataset in ['imagenet', 'cub', 'stanford_dogs', 'flowers102',
'places365', 'food_101', 'caltech_256', 'dtd', 'pets']:
if eval:
train_transform, test_transform = get_simclr_eval_transform_imagenet(P.ood_samples,
P.resize_factor, P.resize_fix)
else:
train_transform, test_transform = get_transform_imagenet()
else:
train_transform, test_transform = get_transform(image_size=image_size)

if dataset == 'CNMC':
image_size = (center_crop_size, center_crop_size, 3) #original 450,450,3
n_classes = 2
train_dir = os.path.join(CNMC_PATH, '0_training')
test_dir = os.path.join(CNMC_PATH, '1_validation')
train_set = datasets.ImageFolder(train_dir, transform=train_transform)
test_set = datasets.ImageFolder(test_dir, transform=test_transform)
elif dataset == 'CNMC_grayscale':
image_size = (center_crop_size, center_crop_size, 3) #original 450,450,3
n_classes = 2
train_dir = os.path.join(CNMC_GRAY_PATH, '0_training')
test_dir = os.path.join(CNMC_GRAY_PATH, '1_validation')
train_set = datasets.ImageFolder(train_dir, transform=train_transform)
test_set = datasets.ImageFolder(test_dir, transform=test_transform)
elif dataset == 'cifar10':
image_size = (32, 32, 3)
n_classes = 10
train_set = datasets.CIFAR10(DATA_PATH, train=True, download=download, transform=train_transform)
test_set = datasets.CIFAR10(DATA_PATH, train=False, download=download, transform=test_transform)

elif dataset == 'cifar100':
image_size = (32, 32, 3)
n_classes = 100
train_set = datasets.CIFAR100(DATA_PATH, train=True, download=download, transform=train_transform)
test_set = datasets.CIFAR100(DATA_PATH, train=False, download=download, transform=test_transform)

elif dataset == 'svhn':
assert test_only and image_size is not None
test_set = datasets.SVHN(DATA_PATH, split='test', download=download, transform=test_transform)

elif dataset == 'lsun_resize':
assert test_only and image_size is not None
test_dir = os.path.join(DATA_PATH, 'LSUN_resize')
test_set = datasets.ImageFolder(test_dir, transform=test_transform)

elif dataset == 'lsun_fix':
assert test_only and image_size is not None
test_dir = os.path.join(DATA_PATH, 'LSUN_fix')
test_set = datasets.ImageFolder(test_dir, transform=test_transform)

elif dataset == 'imagenet_resize':
assert test_only and image_size is not None
test_dir = os.path.join(DATA_PATH, 'Imagenet_resize')
test_set = datasets.ImageFolder(test_dir, transform=test_transform)

elif dataset == 'imagenet_fix':
assert test_only and image_size is not None
test_dir = os.path.join(DATA_PATH, 'Imagenet_fix')
test_set = datasets.ImageFolder(test_dir, transform=test_transform)

elif dataset == 'imagenet':
image_size = (224, 224, 3)
n_classes = 30
train_dir = os.path.join(IMAGENET_PATH, 'one_class_train')
test_dir = os.path.join(IMAGENET_PATH, 'one_class_test')
train_set = datasets.ImageFolder(train_dir, transform=train_transform)
test_set = datasets.ImageFolder(test_dir, transform=test_transform)

elif dataset == 'stanford_dogs':
assert test_only and image_size is not None
test_dir = os.path.join(DATA_PATH, 'stanford_dogs')
test_set = datasets.ImageFolder(test_dir, transform=test_transform)
test_set = get_subset_with_len(test_set, length=3000, shuffle=True)

elif dataset == 'cub':
assert test_only and image_size is not None
test_dir = os.path.join(DATA_PATH, 'cub200')
test_set = datasets.ImageFolder(test_dir, transform=test_transform)
test_set = get_subset_with_len(test_set, length=3000, shuffle=True)

elif dataset == 'flowers102':
assert test_only and image_size is not None
test_dir = os.path.join(DATA_PATH, 'flowers102')
test_set = datasets.ImageFolder(test_dir, transform=test_transform)
test_set = get_subset_with_len(test_set, length=3000, shuffle=True)

elif dataset == 'places365':
assert test_only and image_size is not None
test_dir = os.path.join(DATA_PATH, 'places365')
test_set = datasets.ImageFolder(test_dir, transform=test_transform)
test_set = get_subset_with_len(test_set, length=3000, shuffle=True)

elif dataset == 'food_101':
assert test_only and image_size is not None
test_dir = os.path.join(DATA_PATH, 'food-101', 'images')
test_set = datasets.ImageFolder(test_dir, transform=test_transform)
test_set = get_subset_with_len(test_set, length=3000, shuffle=True)

elif dataset == 'caltech_256':
assert test_only and image_size is not None
test_dir = os.path.join(DATA_PATH, 'caltech-256')
test_set = datasets.ImageFolder(test_dir, transform=test_transform)
test_set = get_subset_with_len(test_set, length=3000, shuffle=True)

elif dataset == 'dtd':
assert test_only and image_size is not None
test_dir = os.path.join(DATA_PATH, 'dtd', 'images')
test_set = datasets.ImageFolder(test_dir, transform=test_transform)
test_set = get_subset_with_len(test_set, length=3000, shuffle=True)

elif dataset == 'pets':
assert test_only and image_size is not None
test_dir = os.path.join(DATA_PATH, 'pets')
test_set = datasets.ImageFolder(test_dir, transform=test_transform)
test_set = get_subset_with_len(test_set, length=3000, shuffle=True)

else:
raise NotImplementedError()

if test_only:
return test_set
else:
return train_set, test_set, image_size, n_classes


def get_superclass_list(dataset):
if dataset == 'CNMC':
return CNMC_SUPERCLASS
if dataset == 'CNMC_grayscale':
return CNMC_SUPERCLASS
elif dataset == 'cifar10':
return CIFAR10_SUPERCLASS
elif dataset == 'cifar100':
return CIFAR100_SUPERCLASS
elif dataset == 'imagenet':
return IMAGENET_SUPERCLASS
else:
raise NotImplementedError()


def get_subclass_dataset(dataset, classes):
if not isinstance(classes, list):
classes = [classes]

indices = []
for idx, tgt in enumerate(dataset.targets):
if tgt in classes:
indices.append(idx)

dataset = Subset(dataset, indices)
return dataset
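# e.g. get_subclass_dataset(train_set, classes=cls_list[P.one_class_idx]) keeps only
# the in-distribution class for one-class training/evaluation (see common/eval.py)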


def get_simclr_eval_transform_imagenet(sample_num, resize_factor, resize_fix):

resize_scale = (resize_factor, 1.0) # resize scaling factor
if resize_fix: # if resize_fix is True, use same scale
resize_scale = (resize_factor, resize_factor)

transform = transforms.Compose([
transforms.Resize(256),
transforms.RandomResizedCrop(224, scale=resize_scale),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
])

clean_transform = transforms.Compose([
transforms.Resize(256),
transforms.CenterCrop(224),
transforms.ToTensor(),
])

transform = MultiDataTransformList(transform, clean_transform, sample_num)

return transform, transform

def get_simclr_eval_transform_cnmc(sample_num, resize_factor, resize_fix, res, center_crop_size):

resize_scale = (resize_factor, 1.0) # resize scaling factor
if resize_fix: # if resize_fix is True, use same scale
resize_scale = (resize_factor, resize_factor)
transform = transforms.Compose([
transforms.Resize(res),
transforms.CenterCrop(center_crop_size),
transforms.RandomVerticalFlip(),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
])

clean_transform = transforms.Compose([
transforms.Resize(res),
transforms.CenterCrop(center_crop_size),
transforms.ToTensor(),
])
transform = MultiDataTransformList(transform, clean_transform, sample_num)

return transform, transform



+ 66 - 0  datasets/imagenet_fix_preprocess.py

@@ -0,0 +1,66 @@
import os
import time
import random

import cv2
import numpy as np
import torch

import torch.nn.functional as F
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from torchvision.utils import save_image

from datasets import get_subclass_dataset

def set_random_seed(seed):
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)

IMAGENET_PATH = '~/data/ImageNet'


check = time.time()

transform = transforms.Compose([
transforms.Resize(256),
transforms.CenterCrop(256),
transforms.Resize(32),
transforms.ToTensor(),
])

# remove airliner(1), ambulance(2), parking_meter(18), schooner(22) since similar classes exist in CIFAR-10
class_idx_list = list(range(30))
remove_idx_list = [1, 2, 18, 22]
for remove_idx in remove_idx_list:
class_idx_list.remove(remove_idx)

set_random_seed(0)
train_dir = os.path.join(IMAGENET_PATH, 'one_class_train')
Imagenet_set = datasets.ImageFolder(train_dir, transform=transform)
Imagenet_set = get_subclass_dataset(Imagenet_set, class_idx_list)
Imagenet_dataloader = DataLoader(Imagenet_set, batch_size=100, shuffle=True, pin_memory=False)

total_test_image = None
for n, (test_image, target) in enumerate(Imagenet_dataloader):

if n == 0:
total_test_image = test_image
else:
total_test_image = torch.cat((total_test_image, test_image), dim=0).cpu()

if total_test_image.size(0) >= 10000:
break

print(f'Preprocessing time {time.time()-check}')

if not os.path.exists('./Imagenet_fix'):
os.mkdir('./Imagenet_fix')

check = time.time()
for i in range(10000):
save_image(total_test_image[i], f'Imagenet_fix/correct_resize_{i}.png')
print(f'Saving time {time.time()-check}')


+ 61 - 0  datasets/lsun_fix_preprocess.py

@@ -0,0 +1,61 @@
import os
import time
import random

import numpy as np
import torch

from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from torchvision.utils import save_image

def set_random_seed(seed):
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)

check = time.time()

transform = transforms.Compose([
transforms.Resize(256),
transforms.CenterCrop(256),
transforms.Resize(32),
transforms.ToTensor(),
])

set_random_seed(0)

LSUN_class_list = ['bedroom', 'bridge', 'church_outdoor', 'classroom',
'conference_room', 'dining_room', 'kitchen', 'living_room', 'restaurant', 'tower']

total_test_image_all_class = []
for LSUN_class in LSUN_class_list:
LSUN_set = datasets.LSUN('~/data/lsun/', classes=LSUN_class + '_train', transform=transform)
LSUN_loader = DataLoader(LSUN_set, batch_size=100, shuffle=True, pin_memory=False)

total_test_image = None
for n, (test_image, _) in enumerate(LSUN_loader):

if n == 0:
total_test_image = test_image
else:
total_test_image = torch.cat((total_test_image, test_image), dim=0).cpu()

if total_test_image.size(0) >= 1000:
break

total_test_image_all_class.append(total_test_image)

total_test_image_all_class = torch.cat(total_test_image_all_class, dim=0)

print(f'Preprocessing time {time.time()-check}')

if not os.path.exists('./LSUN_fix'):
os.mkdir('./LSUN_fix')

check = time.time()
for i in range(10000):
save_image(total_test_image_all_class[i], f'LSUN_fix/correct_resize_{i}.png')
print(f'Saving time {time.time()-check}')


+ 37 - 0  datasets/postprocess_data.py

@@ -0,0 +1,37 @@
import re
import matplotlib.pyplot as plt

PATH = r'C:\Users\feokt\PycharmProjects\CSI\CSI\logs'


def postprocess_data(log: list):
for pth in log:
loss_sim = []
loss_shift = []
with open(PATH + pth) as f:
lines = f.readlines()
for line in lines:
# line = '[2022-01-31 20:40:23.947855] [DONE] [Time 0.179] [Data 0.583] [LossC 0.000000] [LossSim 4.024234] [LossShift 0.065126]'
part = re.search(r'\[DONE\]', line)
if part is not None:
l_sim = re.search(r'(\[LossSim.[0-9]*.[0-9]*\])', line)
if l_sim is not None:
loss_sim.append(float(re.search(r'(\s[0-9].*[0-9])', l_sim.group()).group()))
l_shift = re.search(r'(\[LossShift.[0-9]*.[0-9]*\])', line)
if l_shift is not None:
loss_shift.append(float(re.search(r'(\s[0-9].*[0-9])', l_shift.group()).group()))
loss = [loss_sim[i] + loss_shift[i] for i in range(len(loss_sim))]

plt.ylabel("loss")
plt.xlabel("epoch")
plt.title("Loss over epochs")
plt.plot(list(range(1, len(loss) + 1)), loss)
for idx in range(len(log)):
log[idx] = log[idx][38:]  # strip the fixed-length log-directory prefix so the legend stays readable
plt.legend(log)
plt.grid()
#plt.plot(list(range(1, 101)), loss_sim)
#plt.plot(list(range(1, 101)), loss_shift)
plt.show()



+ 196 - 0  datasets/prepare_data.py

@@ -0,0 +1,196 @@
import csv
import os
from PIL import Image
from torchvision import transforms
from torchvision.utils import save_image
import torch


def transform_image(img_in, target_dir, transformation, suffix):
"""
Transforms an image according to provided transformation.

Parameters:
img_in (path): Image to transform
target_dir (path): Destination path
transformation (callable): Transformation to be applied
suffix (str): Suffix of resulting image.

Returns:
binary_sum (str): Binary string of the sum of a and b
"""
if suffix == 'rot':
im = Image.open(img_in)
im = im.rotate(270)
tensor = transforms.ToTensor()(im)
save_image(tensor, target_dir + os.sep + suffix + '.jpg')
elif suffix == 'sobel':
im = Image.open(img_in)
tensor = transforms.ToTensor()(im)
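# 3x3 Sobel kernel, broadcast over the 3 RGB channels; the convolution collapses
# them into a single-channel edge map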
sobel_filter = torch.tensor([[1., 2., 1.], [0., 0., 0.], [-1., -2., -1.]])
f = sobel_filter.expand(1, 3, 3, 3)
tensor = torch.conv2d(tensor, f, stride=1, padding=1 )
save_image(tensor, target_dir + os.sep + suffix + '.jpg')
elif suffix == 'noise':
im = Image.open(img_in)
tensor = transforms.ToTensor()(im)
tensor = tensor + (torch.randn(tensor.size()) * 0.2 + 0)
save_image(tensor, target_dir + os.sep + suffix + '.jpg')
elif suffix == 'cutout':
pass  # cutout transformation not implemented yet
else:
im = Image.open(img_in)
im_trans = transformation(im)
im_trans.save(target_dir + os.sep + suffix + '.jpg')


def sort_and_rename_images(excel_path: str):
"""Renames images and sorts them according to csv."""
base_dir = excel_path.rsplit(os.sep, 1)[0]
dir_all = base_dir + os.sep + 'all'
if not os.path.isdir(dir_all):
os.mkdir(dir_all)
dir_hem = base_dir + os.sep + 'hem'
if not os.path.isdir(dir_hem):
os.mkdir(dir_hem)

with open(excel_path, mode='r') as file:
csv_file = csv.reader(file)
for lines in csv_file:
print(lines)
if lines[2] == '1':
os.rename(base_dir + os.sep + lines[1], dir_all + os.sep + lines[0])
elif lines[2] == '0':
os.rename(base_dir + os.sep + lines[1], dir_hem + os.sep + lines[0])


def drop_color_channels(source_dir, target_dir, rgb):
"""Rotates all images in in source dir."""
if rgb == 0:
suffix = "red_only"
drop_1 = 1
drop_2 = 2
elif rgb == 1:
suffix = "green_only"
drop_1 = 0
drop_2 = 2
elif rgb == 2:
suffix = "blue_only"
drop_1 = 0
drop_2 = 1
elif rgb == 3:
suffix = "no_red"
drop_1 = 0
elif rgb == 4:
suffix = "no_green"
drop_1 = 1
elif rgb == 5:
suffix = "no_blue"
drop_1 = 2
else:
suffix = ""
print("Invalid RGB-channel")
if suffix != "":
dirs = os.listdir(source_dir)
for item in dirs:
if os.path.isfile(source_dir + os.sep + item):
im = Image.open(source_dir + os.sep + item)
tensor = transforms.ToTensor()(im)
tensor[drop_1, :, :] = 0
if rgb < 3:
tensor[drop_2, :, :] = 0
save_image(tensor, target_dir + os.sep + item, 'bmp')


def rotate_images(target_dir, source_dir, rotate, theta):
"""Rotates all images in in source dir."""
dirs = os.listdir(source_dir)
for item in dirs:
if os.path.isfile(source_dir + os.sep + item):
for i in range(0, rotate):
im = Image.open(source_dir + os.sep + item)
im = im.rotate(i*theta)
tensor = transforms.ToTensor()(im)
save_image(tensor, target_dir + os.sep + str(i) + '_' + item, 'bmp')


def grayscale_image(source_dir, target_dir):
"""Grayscale transforms all images in path."""
t = transforms.Grayscale()
dirs = os.listdir(source_dir)
if not os.path.isdir(target_dir):
os.mkdir(target_dir)
for item in dirs:
if os.path.isfile(source_dir + os.sep + item):
im = Image.open(source_dir + os.sep + item).convert('RGB')
im_resize = t(im)
# note: the padded tensor below is computed but never used; the grayscale PIL image is saved directly
tensor = transforms.ToTensor()(im_resize)
padding = torch.zeros(1, tensor.shape[1], tensor.shape[2])
tensor = torch.cat((tensor, padding), 0)
im_resize.save(target_dir + os.sep + item, 'bmp')


def resize(source_dir):
"""Rotates all images in in source dir."""
t = transforms.Compose([transforms.Resize((128, 128))])
dirs = os.listdir(source_dir)
target_dir = source_dir + os.sep + 'resized'
if not os.path.isdir(target_dir):
os.mkdir(target_dir)
for item in dirs:
if os.path.isfile(source_dir + os.sep + item):
im = Image.open(source_dir + os.sep + item)
im_resize = t(im)
im_resize.save(source_dir + os.sep + 'resized' + os.sep + item, 'bmp')


def crop_image(source_dir):
"""Center Crops all images in path."""
t = transforms.CenterCrop((224, 224))
dirs = os.listdir(source_dir)
target_dir = source_dir + os.sep + 'cropped'
if not os.path.isdir(target_dir):
os.mkdir(target_dir)
for item in dirs:
if os.path.isfile(source_dir + os.sep + item):
im = Image.open(source_dir + os.sep + item)
im_resize = t(im)
im_resize.save(source_dir + os.sep + 'cropped' + os.sep + item, 'bmp')


def mk_dirs(target_dir):
dir_0 = target_dir + r"\fold_0"
dir_1 = target_dir + r"\fold_1"
dir_2 = target_dir + r"\fold_2"
dir_3 = target_dir + r"\phase2"
dir_4 = target_dir + r"\phase3"
dir_0_all = dir_0 + r"\all"
dir_0_hem = dir_0 + r"\hem"
dir_1_all = dir_1 + r"\all"
dir_1_hem = dir_1 + r"\hem"
dir_2_all = dir_2 + r"\all"
dir_2_hem = dir_2 + r"\hem"
if not os.path.isdir(dir_0):
os.mkdir(dir_0)
if not os.path.isdir(dir_1):
os.mkdir(dir_1)
if not os.path.isdir(dir_2):
os.mkdir(dir_2)
if not os.path.isdir(dir_3):
os.mkdir(dir_3)
if not os.path.isdir(dir_4):
os.mkdir(dir_4)

if not os.path.isdir(dir_0_all):
os.mkdir(dir_0_all)
if not os.path.isdir(dir_0_hem):
os.mkdir(dir_0_hem)
if not os.path.isdir(dir_1_all):
os.mkdir(dir_1_all)
if not os.path.isdir(dir_1_hem):
os.mkdir(dir_1_hem)
if not os.path.isdir(dir_2_all):
os.mkdir(dir_2_all)
if not os.path.isdir(dir_2_hem):
os.mkdir(dir_2_hem)
return dir_0_all, dir_0_hem, dir_1_all, dir_1_hem, dir_2_all, dir_2_hem, dir_3, dir_4

+ 4691 - 0  eval.ipynb (file diff suppressed because it is too large)

+ 57 - 0  eval.py

@@ -0,0 +1,57 @@
from common.eval import *


def main():
model.eval()

if P.mode == 'test_acc':
from evals import test_classifier
with torch.no_grad():
error = test_classifier(P, model, test_loader, 0, logger=None)

elif P.mode == 'test_marginalized_acc':
from evals import test_classifier
with torch.no_grad():
error = test_classifier(P, model, test_loader, 0, marginal=True, logger=None)

elif P.mode in ['ood', 'ood_pre']:
if P.mode == 'ood':
from evals import eval_ood_detection
else:
from evals.ood_pre import eval_ood_detection

with torch.no_grad():
auroc_dict = eval_ood_detection(P, model, test_loader, ood_test_loader, P.ood_score,
train_loader=train_loader, simclr_aug=simclr_aug)

if P.one_class_idx is not None:
mean_dict = dict()
for ood_score in P.ood_score:
mean = 0
for ood in auroc_dict.keys():
mean += auroc_dict[ood][ood_score]
mean_dict[ood_score] = mean / len(auroc_dict.keys())
auroc_dict['one_class_mean'] = mean_dict

bests = []
for ood in auroc_dict.keys():
message = ''
best_auroc = 0
for ood_score, auroc in auroc_dict[ood].items():
message += '[%s %s %.4f] ' % (ood, ood_score, auroc)
if auroc > best_auroc:
best_auroc = auroc
message += '[%s %s %.4f] ' % (ood, 'best', best_auroc)
if P.print_score:
print(message)
bests.append(best_auroc)

bests = map('{:.4f}'.format, bests)
print('\t'.join(bests))

else:
raise NotImplementedError()


if __name__ == '__main__':
main()

+ 1 - 0  evals/__init__.py

@@ -0,0 +1 @@
from evals.evals import test_classifier, eval_ood_detection

BIN  evals/__pycache__/__init__.cpython-36.pyc
BIN  evals/__pycache__/__init__.cpython-37.pyc
BIN  evals/__pycache__/evals.cpython-36.pyc
BIN  evals/__pycache__/evals.cpython-37.pyc
BIN  evals/__pycache__/ood_pre.cpython-36.pyc
BIN  evals/__pycache__/ood_pre.cpython-37.pyc

+ 201 - 0  evals/evals.py

@@ -0,0 +1,201 @@
import time
import itertools

import diffdist.functional as distops
import numpy as np
import torch
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F
from sklearn.metrics import roc_auc_score

import models.transform_layers as TL
from utils.temperature_scaling import _ECELoss
from utils.utils import AverageMeter, set_random_seed, normalize

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
ece_criterion = _ECELoss().to(device)


def error_k(output, target, ks=(1,)):
"""Computes the precision@k for the specified values of k"""
max_k = max(ks)
batch_size = target.size(0)

_, pred = output.topk(max_k, 1, True, True)
pred = pred.t()
correct = pred.eq(target.view(1, -1).expand_as(pred))

results = []
for k in ks:
correct_k = correct[:k].reshape(-1).float().sum(0)  # reshape: the slice of a transposed tensor may be non-contiguous
results.append(100.0 - correct_k.mul_(100.0 / batch_size))
return results


def test_classifier(P, model, loader, steps, marginal=False, logger=None):
error_top1 = AverageMeter()
error_calibration = AverageMeter()

if logger is None:
log_ = print
else:
log_ = logger.log

# Switch to evaluate mode
mode = model.training
model.eval()

for n, (images, labels) in enumerate(loader):
batch_size = images.size(0)

images, labels = images.to(device), labels.to(device)

if marginal:
outputs = 0
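# marginal inference: average the joint-head class logits over the 4 rotated copies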
for i in range(4):
rot_images = torch.rot90(images, i, (2, 3))
_, outputs_aux = model(rot_images, joint=True)
outputs += outputs_aux['joint'][:, P.n_classes * i: P.n_classes * (i + 1)] / 4.
else:
outputs = model(images)

top1, = error_k(outputs.data, labels, ks=(1,))
error_top1.update(top1.item(), batch_size)

ece = ece_criterion(outputs, labels) * 100
error_calibration.update(ece.item(), batch_size)

if n % 100 == 0:
log_('[Test %3d] [Test@1 %.3f] [ECE %.3f]' %
(n, error_top1.value, error_calibration.value))

log_(' * [Error@1 %.3f] [ECE %.3f]' %
(error_top1.average, error_calibration.average))

if logger is not None:
logger.scalar_summary('eval/clean_error', error_top1.average, steps)
logger.scalar_summary('eval/ece', error_calibration.average, steps)

model.train(mode)

return error_top1.average
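
The marginal branch above relies on the joint head emitting K_shift * n_classes logits, where slice i holds the class logits conditioned on rotation i (see BaseModel below). A minimal sketch of the slicing, with hypothetical sizes (batch 8, 10 classes, 4 rotations):

joint = torch.randn(8, 4 * 10)  # stand-in for outputs_aux['joint']
marginal = sum(joint[:, 10 * i: 10 * (i + 1)] for i in range(4)) / 4.  # (8, 10)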


def eval_ood_detection(P, model, id_loader, ood_loaders, ood_scores, train_loader=None, simclr_aug=None):
    auroc_dict = dict()
    for ood in ood_loaders.keys():
        auroc_dict[ood] = dict()

    for ood_score in ood_scores:
        # compute scores for ID and OOD samples
        score_func = get_ood_score_func(P, model, ood_score, simclr_aug=simclr_aug)

        save_path = f'plot/score_in_{P.dataset}_{ood_score}'
        if P.one_class_idx is not None:
            save_path += f'_{P.one_class_idx}'

        scores_id = get_scores(id_loader, score_func)

        if P.save_score:
            np.save(f'{save_path}.npy', scores_id)

        for ood, ood_loader in ood_loaders.items():
            if ood == 'interp':
                scores_ood = get_scores_interp(id_loader, score_func)
                auroc_dict['interp'][ood_score] = get_auroc(scores_id, scores_ood)
            else:
                scores_ood = get_scores(ood_loader, score_func)
                auroc_dict[ood][ood_score] = get_auroc(scores_id, scores_ood)

            if P.save_score:
                np.save(f'{save_path}_out_{ood}.npy', scores_ood)

    return auroc_dict
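
The returned dictionary is keyed first by OOD dataset name and then by score name; a hypothetical result (keys and values invented purely for illustration):

auroc_dict = {'interp': {'baseline': 0.85}, 'svhn': {'baseline': 0.91}}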


def get_ood_score_func(P, model, ood_score, simclr_aug=None):
    def score_func(x):
        return compute_ood_score(P, model, ood_score, x, simclr_aug=simclr_aug)
    return score_func


def get_scores(loader, score_func):
    scores = []
    for i, (x, _) in enumerate(loader):
        s = score_func(x.to(device))
        assert s.dim() == 1 and s.size(0) == x.size(0)

        scores.append(s.detach().cpu().numpy())
    return np.concatenate(scores)


def get_scores_interp(loader, score_func):
    scores = []
    for i, (x, _) in enumerate(loader):
        x_interp = (x + last) / 2 if i > 0 else x  # omit the first batch, assume batch sizes are equal
        last = x  # save the last batch
        s = score_func(x_interp.to(device))
        assert s.dim() == 1 and s.size(0) == x.size(0)

        scores.append(s.detach().cpu().numpy())
    return np.concatenate(scores)


def get_auroc(scores_id, scores_ood):
    scores = np.concatenate([scores_id, scores_ood])
    labels = np.concatenate([np.ones_like(scores_id), np.zeros_like(scores_ood)])
    return roc_auc_score(labels, scores)
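
A quick sanity check of the labeling convention (toy numbers): in-distribution samples are labeled 1, so any score that is higher on ID data yields an AUROC above 0.5.

scores_id = np.array([0.9, 0.8, 0.7])
scores_ood = np.array([0.4, 0.3])
get_auroc(scores_id, scores_ood)  # 1.0: every ID score exceeds every OOD score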


def compute_ood_score(P, model, ood_score, x, simclr_aug=None):
    model.eval()

    if ood_score == 'clean_norm':
        _, output_aux = model(x, penultimate=True, simclr=True)
        score = output_aux[P.ood_layer].norm(dim=1)
        return score

    elif ood_score == 'similar':
        assert simclr_aug is not None  # require custom simclr augmentation
        sample_num = 2  # fast evaluation
        feats = get_features(model, simclr_aug, x, layer=P.ood_layer, sample_num=sample_num)
        feats_avg = sum(feats) / len(feats)

        scores = []
        for seed in range(sample_num):
            sim = torch.cosine_similarity(feats[seed], feats_avg)
            scores.append(sim)
        return sum(scores) / len(scores)

    elif ood_score == 'baseline':
        outputs, outputs_aux = model(x, penultimate=True)
        scores = F.softmax(outputs, dim=1).max(dim=1)[0]
        return scores

    elif ood_score == 'baseline_marginalized':
        total_outputs = 0
        for i in range(4):
            x_rot = torch.rot90(x, i, (2, 3))
            outputs, outputs_aux = model(x_rot, penultimate=True, joint=True)
            total_outputs += outputs_aux['joint'][:, P.n_classes * i:P.n_classes * (i + 1)]

        scores = F.softmax(total_outputs / 4., dim=1).max(dim=1)[0]
        return scores

    else:
        raise NotImplementedError()


def get_features(model, simclr_aug, x, layer='simclr', sample_num=1):
    model.eval()

    feats = []
    for seed in range(sample_num):
        set_random_seed(seed)
        x_t = simclr_aug(x)
        with torch.no_grad():
            _, output_aux = model(x_t, penultimate=True, simclr=True, shift=True)
        feats.append(output_aux[layer])
    return feats

+ 242
- 0
evals/ood_pre.py

@@ -0,0 +1,242 @@
import os

import torch
import numpy as np

import models.transform_layers as TL
from utils.utils import set_random_seed, normalize
from evals.evals import get_auroc

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
hflip = TL.HorizontalFlipLayer().to(device)


def eval_ood_detection(P, model, id_loader, ood_loaders, ood_scores, train_loader=None, simclr_aug=None):
    auroc_dict = dict()
    for ood in ood_loaders.keys():
        auroc_dict[ood] = dict()

    assert len(ood_scores) == 1  # assume single ood_score for simplicity
    ood_score = ood_scores[0]

    base_path = os.path.split(P.load_path)[0]  # checkpoint directory

    prefix = f'{P.ood_samples}'
    if P.resize_fix:
        prefix += f'_resize_fix_{P.resize_factor}'
    else:
        prefix += f'_resize_range_{P.resize_factor}'

    prefix = os.path.join(base_path, f'feats_{prefix}')

    kwargs = {
        'simclr_aug': simclr_aug,
        'sample_num': P.ood_samples,
        'layers': P.ood_layer,
    }

    print('Pre-compute global statistics...')
    feats_train = get_features(P, f'{P.dataset}_train', model, train_loader, prefix=prefix, **kwargs)  # (M, T, d)

    P.axis = []
    for f in feats_train['simclr'].chunk(P.K_shift, dim=1):
        axis = f.mean(dim=1)  # (M, d)
        P.axis.append(normalize(axis, dim=1).to(device))
    print('axis size: ' + ' '.join(map(lambda x: str(len(x)), P.axis)))

    f_sim = [f.mean(dim=1) for f in feats_train['simclr'].chunk(P.K_shift, dim=1)]  # list of (M, d)
    f_shi = [f.mean(dim=1) for f in feats_train['shift'].chunk(P.K_shift, dim=1)]  # list of (M, 4)

    weight_sim = []
    weight_shi = []
    for shi in range(P.K_shift):
        sim_norm = f_sim[shi].norm(dim=1)  # (M)
        shi_mean = f_shi[shi][:, shi]  # (M)
        weight_sim.append(1 / sim_norm.mean().item())
        weight_shi.append(1 / shi_mean.mean().item())

    if ood_score == 'simclr':
        P.weight_sim = [1]
        P.weight_shi = [0]
    elif ood_score == 'CSI':
        P.weight_sim = weight_sim
        P.weight_shi = weight_shi
    else:
        raise ValueError()

    print('weight_sim:\t' + '\t'.join(map('{:.4f}'.format, P.weight_sim)))
    print('weight_shi:\t' + '\t'.join(map('{:.4f}'.format, P.weight_shi)))

    print('Pre-compute features...')
    feats_id = get_features(P, P.dataset, model, id_loader, prefix=prefix, **kwargs)  # (N, T, d)
    feats_ood = dict()
    for ood, ood_loader in ood_loaders.items():
        if ood == 'interp':
            feats_ood[ood] = get_features(P, ood, model, id_loader, interp=True, prefix=prefix, **kwargs)
        else:
            feats_ood[ood] = get_features(P, ood, model, ood_loader, prefix=prefix, **kwargs)

    print(f'Compute OOD scores... (score: {ood_score})')
    scores_id = get_scores(P, feats_id, ood_score).numpy()
    scores_ood = dict()
    if P.one_class_idx is not None:
        one_class_score = []

    for ood, feats in feats_ood.items():
        scores_ood[ood] = get_scores(P, feats, ood_score).numpy()
        auroc_dict[ood][ood_score] = get_auroc(scores_id, scores_ood[ood])
        if P.one_class_idx is not None:
            one_class_score.append(scores_ood[ood])

    if P.one_class_idx is not None:
        one_class_score = np.concatenate(one_class_score)
        one_class_total = get_auroc(scores_id, one_class_score)
        print(f'One_class_real_mean: {one_class_total}')

    if P.print_score:
        print_score(P.dataset, scores_id)
        for ood, scores in scores_ood.items():
            print_score(ood, scores)

    return auroc_dict
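
For reference, the weights computed above are simple inverse means over the training set: weight_sim[s] = 1 / mean_m ||f_sim^s(x_m)|| and weight_shi[s] = 1 / mean_m p_shift(s | x_m), which puts the two terms of the CSI score on a comparable scale. A minimal numeric sketch (toy tensors, not the precomputed features):

f_sim_toy = torch.tensor([[3.0, 4.0], [6.0, 8.0], [0.0, 5.0]])  # norms: 5, 10, 5
w_sim_toy = 1 / f_sim_toy.norm(dim=1).mean().item()             # 1 / 6.6667 ~ 0.15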


def get_scores(P, feats_dict, ood_score):
    # convert to gpu tensor
    feats_sim = feats_dict['simclr'].to(device)
    feats_shi = feats_dict['shift'].to(device)
    N = feats_sim.size(0)

    # compute scores
    scores = []
    for f_sim, f_shi in zip(feats_sim, feats_shi):
        f_sim = [f.mean(dim=0, keepdim=True) for f in f_sim.chunk(P.K_shift)]  # list of (1, d)
        f_shi = [f.mean(dim=0, keepdim=True) for f in f_shi.chunk(P.K_shift)]  # list of (1, 4)
        score = 0
        for shi in range(P.K_shift):
            score += (f_sim[shi] * P.axis[shi]).sum(dim=1).max().item() * P.weight_sim[shi]
            score += f_shi[shi][:, shi].item() * P.weight_shi[shi]
        score = score / P.K_shift
        scores.append(score)
    scores = torch.tensor(scores)

    assert scores.dim() == 1 and scores.size(0) == N  # (N)
    return scores.cpu()
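
In words: for each shift s, the score takes the maximum inner product between the test feature and the L2-normalized training axes (i.e., feature norm times cosine similarity), weighted by weight_sim[s], adds the weighted shift-classifier confidence weight_shi[s] * p(s | x), and averages over the K_shift shifts. A minimal sketch of the similarity term for one sample and one shift (toy values):

f = torch.tensor([[1.0, 0.0]])                 # un-normalized test feature (1, d)
axis = torch.tensor([[0.6, 0.8], [1.0, 0.0]])  # L2-normalized training axes (M, d)
sim_term = (f * axis).sum(dim=1).max().item()  # nearest-axis inner product = 1.0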


def get_features(P, data_name, model, loader, interp=False, prefix='',
                 simclr_aug=None, sample_num=1, layers=('simclr', 'shift')):

    if not isinstance(layers, (list, tuple)):
        layers = [layers]

    feats_dict = dict()
    # NOTE: loading of previously saved features is disabled; features are
    # always recomputed here and saved to the `prefix`-based paths below.
    # for layer in layers:
    #     path = prefix + f'_{data_name}_{layer}.pth'
    #     if os.path.exists(path):
    #         feats_dict[layer] = torch.load(path)

    # pre-compute features and save to the path
    left = [layer for layer in layers if layer not in feats_dict.keys()]
    if len(left) > 0:
        _feats_dict = _get_features(P, model, loader, interp,
                                    P.dataset in ('imagenet', 'CNMC', 'CNMC_grayscale'),
                                    simclr_aug, sample_num, layers=left)

        for layer, feats in _feats_dict.items():
            path = prefix + f'_{data_name}_{layer}.pth'
            torch.save(_feats_dict[layer], path)
            feats_dict[layer] = feats  # update value

    return feats_dict


def _get_features(P, model, loader, interp=False, imagenet=False, simclr_aug=None,
                  sample_num=1, layers=('simclr', 'shift')):

    if not isinstance(layers, (list, tuple)):
        layers = [layers]

    # check if arguments are valid
    assert simclr_aug is not None

    if imagenet:  # assume batch_size = 1 for ImageNet
        sample_num = 1

    # compute features in full dataset
    model.eval()
    feats_all = {layer: [] for layer in layers}  # initialize: empty list
    for i, (x, _) in enumerate(loader):
        if interp:
            x_interp = (x + last) / 2 if i > 0 else x  # omit the first batch, assume batch sizes are equal
            last = x  # save the last batch
            x = x_interp  # use interp as current batch

        if imagenet:
            x = torch.cat(x[0], dim=0)  # augmented list of x

        x = x.to(device)  # gpu tensor

        # compute features in one batch
        feats_batch = {layer: [] for layer in layers}  # initialize: empty list
        for seed in range(sample_num):
            set_random_seed(seed)

            if P.K_shift > 1:
                x_t = torch.cat([P.shift_trans(hflip(x), k) for k in range(P.K_shift)])
            else:
                x_t = x  # No shifting: SimCLR
            x_t = simclr_aug(x_t)

            # compute augmented features
            with torch.no_grad():
                kwargs = {layer: True for layer in layers}  # only forward selected layers
                _, output_aux = model(x_t, **kwargs)

            # add features in one batch
            for layer in layers:
                feats = output_aux[layer].cpu()
                if imagenet:
                    feats_batch[layer] += [feats]  # (B, d) cpu tensor
                else:
                    feats_batch[layer] += feats.chunk(P.K_shift)

        # concatenate features in one batch
        for key, val in feats_batch.items():
            if imagenet:
                feats_batch[key] = torch.stack(val, dim=0)  # (B, T, d)
            else:
                feats_batch[key] = torch.stack(val, dim=1)  # (B, T, d)

        # add features in full dataset
        for layer in layers:
            feats_all[layer] += [feats_batch[layer]]

    # concatenate features in full dataset
    for key, val in feats_all.items():
        feats_all[key] = torch.cat(val, dim=0)  # (N, T, d)

    # reshape order
    if not imagenet:
        # Convert [k1, k2, ..., kK, k1, k2, ...] -> [k1, k1, ..., k2, k2, ...],
        # i.e., group the T = K_shift * sample_num feature copies by shift
        for key, val in feats_all.items():
            N, T, d = val.size()  # T = K * T'
            val = val.view(N, -1, P.K_shift, d)  # (N, T', K, d)
            val = val.transpose(2, 1)  # (N, K, T', d)
            val = val.reshape(N, T, d)  # (N, T, d)
            feats_all[key] = val

    return feats_all
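
The final reorder can be checked on a tiny tensor: with K_shift = 2 and T' = 2 augmentation samples (d = 1), features stored shift-major per sample are regrouped so that all copies of one shift are contiguous:

v = torch.tensor([[[10.], [20.], [11.], [21.]]])  # (1, 4, 1): [k0s0, k1s0, k0s1, k1s1]
v = v.view(1, 2, 2, 1).transpose(2, 1).reshape(1, 4, 1)
# -> [[10.], [11.], [20.], [21.]]: [k0s0, k0s1, k1s0, k1s1]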


def print_score(data_name, scores):
    quantile = np.quantile(scores, np.arange(0, 1.1, 0.1))
    print('{:18s} '.format(data_name) +
          '{:.4f} +- {:.4f}    '.format(np.mean(scores), np.std(scores)) +
          ' '.join(['q{:d}: {:.4f}'.format(i * 10, quantile[i]) for i in range(11)]))


BIN
figures/CSI_teaser.png


BIN
figures/fixed_ood_benchmarks.png


BIN
figures/shifting_transformations.png


+ 37
- 0
main.py

@@ -0,0 +1,37 @@
from sys import argv
from os import system
from datasets.prepare_data import prep, resize

import torch
import os
from datasets.postprocess_data import postprocess_data

DATA_BASE_DIR = r'/home/feoktistovar67431/CSI/CSI_local/main.py'
BASE_DIR = '/home/feoktistovar67431/CSI/CSI_local/'


def main():
    for argument in argv:
        if argument == '--proc_step':
            proc_step = argv[argv.index(argument) + 1]
            if proc_step == 'eval':
                system("python eval.py " + ' '.join(argv[1:]))
            elif proc_step == 'train':
                # NOTE: this branch originally invoked eval.py as well;
                # train.py appears to be the intended target.
                system("python " + os.path.join(BASE_DIR, "train.py") + ' ' + ' '.join(argv[1:]))
            elif proc_step == 'plot':
                plot_data()  # NOTE: plot_data is not defined or imported in this file
            elif proc_step == 'post_proc':
                postprocess_data(
                    [
                        r'\CNMC_resnet18_unsup_simclr_CSI_shift_cutperm4_one_class_0\log.txt',
                        r'\CNMC_resnet18_unsup_simclr_CSI_shift_cutperm4_one_class_0_64px\log.txt',
                        r'\CNMC_resnet18_unsup_simclr_CSI_shift_cutperm16_one_class_0_32px\log.txt',
                        r'\CNMC_resnet18_unsup_simclr_CSI_shift_cutperm_one_class_0_64px_batch64\log.txt',
                        r'\CNMC_resnet18_unsup_simclr_CSI_shift_rotation_one_class_0\log.txt',
                        r'\CNMC_resnet18_unsup_simclr_CSI_shift_gauss_one_class_0_32px\log.txt'
                        # r'\cifar10_resnet18_unsup_simclr_CSI_shift_rotation_one_class_1\log.txt'
                    ]
                )


if __name__ == '__main__':
    main()
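
Usage note: the script only inspects --proc_step and forwards the remaining arguments to the dispatched script, e.g. a hypothetical invocation:

    python main.py --proc_step eval <arguments forwarded to eval.py>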

+ 0
- 0
models/__init__.py


BIN
models/__pycache__/__init__.cpython-36.pyc


BIN
models/__pycache__/__init__.cpython-37.pyc


BIN
models/__pycache__/base_model.cpython-36.pyc


BIN
models/__pycache__/base_model.cpython-37.pyc


BIN
models/__pycache__/classifier.cpython-36.pyc


BIN
models/__pycache__/classifier.cpython-37.pyc


BIN
models/__pycache__/resnet.cpython-36.pyc


BIN
models/__pycache__/resnet.cpython-37.pyc


BIN
models/__pycache__/resnet_imagenet.cpython-36.pyc


BIN
models/__pycache__/resnet_imagenet.cpython-37.pyc


BIN
models/__pycache__/transform_layers.cpython-36.pyc


BIN
models/__pycache__/transform_layers.cpython-37.pyc


+ 48
- 0
models/base_model.py

@@ -0,0 +1,48 @@
from abc import ABCMeta, abstractmethod
import torch.nn as nn


class BaseModel(nn.Module, metaclass=ABCMeta):
    def __init__(self, last_dim, num_classes=10, simclr_dim=128):
        super(BaseModel, self).__init__()
        self.linear = nn.Linear(last_dim, num_classes)
        self.simclr_layer = nn.Sequential(
            nn.Linear(last_dim, last_dim),
            nn.ReLU(),
            nn.Linear(last_dim, simclr_dim),
        )
        self.shift_cls_layer = nn.Linear(last_dim, 2)
        self.joint_distribution_layer = nn.Linear(last_dim, 4 * num_classes)

    @abstractmethod
    def penultimate(self, inputs, all_features=False):
        pass

    def forward(self, inputs, penultimate=False, simclr=False, shift=False, joint=False):
        _aux = {}
        _return_aux = False

        features = self.penultimate(inputs)

        output = self.linear(features)

        if penultimate:
            _return_aux = True
            _aux['penultimate'] = features

        if simclr:
            _return_aux = True
            _aux['simclr'] = self.simclr_layer(features)

        if shift:
            _return_aux = True
            _aux['shift'] = self.shift_cls_layer(features)

        if joint:
            _return_aux = True
            _aux['joint'] = self.joint_distribution_layer(features)

        if _return_aux:
            return output, _aux

        return output
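
A minimal usage sketch of the head selection in forward (hypothetical identity subclass, purely for illustration):

import torch

class _ToyModel(BaseModel):
    def penultimate(self, inputs, all_features=False):
        return inputs  # identity: inputs are already (B, last_dim) features

toy = _ToyModel(last_dim=16, num_classes=10)
logits, aux = toy(torch.randn(4, 16), simclr=True, shift=True)
# logits: (4, 10); aux['simclr']: (4, 128); aux['shift']: (4, 2)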

+ 135
- 0
models/classifier.py

@@ -0,0 +1,135 @@
import torch.nn as nn

from models.resnet import ResNet18, ResNet34, ResNet50
from models.resnet_imagenet import resnet18, resnet50
import models.transform_layers as TL
from torchvision import transforms


def get_simclr_augmentation(P, image_size):
    """
    Builds the SimCLR augmentation pipeline (random resized crop, color jitter,
    random grayscale) used to create positive pairs for training.

    :param P: parsed arguments
    :param image_size: size of the input images
    :return: transformation
    """

    # parameter for resizecrop
    resize_scale = (P.resize_factor, 1.0)  # resize scaling factor
    if P.resize_fix:  # if resize_fix is True, use a fixed scale
        resize_scale = (P.resize_factor, P.resize_factor)

    # Align augmentation
    s = P.color_distort
    color_jitter = TL.ColorJitterLayer(brightness=s*0.8, contrast=s*0.8, saturation=s*0.8, hue=s*0.2, p=0.8)
    color_gray = TL.RandomColorGrayLayer(p=0.2)
    resize_crop = TL.RandomResizedCropLayer(scale=resize_scale, size=(image_size[0], image_size[1]))
    #v_flip = transforms.RandomVerticalFlip()