# Dataset construction and image-transform utilities (CNMC / CIFAR / ImageNet).
import os
|
|
|
|
import numpy as np
|
|
import torch
|
|
from torch.utils.data.dataset import Subset
|
|
from torchvision import datasets, transforms
|
|
|
|
from utils.utils import set_random_seed
|
|
|
|
DATA_PATH = '~/data/'
|
|
IMAGENET_PATH = '~/data/ImageNet'
|
|
CNMC_PATH = r'~/data/CSI/CNMC_orig'
|
|
CNMC_GRAY_PATH = r'~/data/CSI/CNMC_orig_gray'
|
|
CNMC_ROT4_PATH = r'~/data/CSI/CNMC_rotated_4'
|
|
|
|
CIFAR10_SUPERCLASS = list(range(10)) # one class
|
|
IMAGENET_SUPERCLASS = list(range(30)) # one class
|
|
CNMC_SUPERCLASS = list(range(2)) # one class
|
|
|
|
STD_RES = 450
|
|
STD_CENTER_CROP = 300
|
|
|
|
CIFAR100_SUPERCLASS = [
|
|
[4, 31, 55, 72, 95],
|
|
[1, 33, 67, 73, 91],
|
|
[54, 62, 70, 82, 92],
|
|
[9, 10, 16, 29, 61],
|
|
[0, 51, 53, 57, 83],
|
|
[22, 25, 40, 86, 87],
|
|
[5, 20, 26, 84, 94],
|
|
[6, 7, 14, 18, 24],
|
|
[3, 42, 43, 88, 97],
|
|
[12, 17, 38, 68, 76],
|
|
[23, 34, 49, 60, 71],
|
|
[15, 19, 21, 32, 39],
|
|
[35, 63, 64, 66, 75],
|
|
[27, 45, 77, 79, 99],
|
|
[2, 11, 36, 46, 98],
|
|
[28, 30, 44, 78, 93],
|
|
[37, 50, 65, 74, 80],
|
|
[47, 52, 56, 59, 96],
|
|
[8, 13, 48, 58, 90],
|
|
[41, 69, 81, 85, 89],
|
|
]
|
|
|
|
|
|
class MultiDataTransform(object):
    """Apply one transform twice to a sample, producing two views.

    Both views use the same underlying transform object; any randomness
    inside it makes the two outputs differ per call.
    """

    def __init__(self, transform):
        self.transform1 = transform
        self.transform2 = transform

    def __call__(self, sample):
        return self.transform1(sample), self.transform2(sample)
class MultiDataTransformList(object):
    """Produce ``sample_num`` augmented views of a sample plus a clean view.

    The global RNG is reseeded to 0 on every call, so the augmented views
    are deterministic across samples and epochs.
    """

    def __init__(self, transform, clean_trasform, sample_num):
        # NOTE(review): 'clean_trasform' keeps the original (misspelled)
        # parameter name so keyword callers stay compatible.
        self.transform = transform
        self.clean_transform = clean_trasform
        self.sample_num = sample_num

    def __call__(self, sample):
        set_random_seed(0)  # make the random augmentations reproducible
        views = [self.transform(sample) for _ in range(self.sample_num)]
        return views, self.clean_transform(sample)
def get_transform(image_size=None):
    """Return a ``(train_transform, test_transform)`` pair.

    Data augmentation is implemented in the model layers, so only resizing
    (when ``image_size`` is given), a horizontal flip for training, and
    tensor conversion happen here.
    """
    if not image_size:
        # Default image size: tensor conversion only.
        return transforms.Compose([transforms.ToTensor()]), transforms.ToTensor()

    # Pre-specified image size: resize both splits to (H, W).
    resize = transforms.Resize((image_size[0], image_size[1]))
    train_transform = transforms.Compose([
        resize,
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
    ])
    test_transform = transforms.Compose([resize, transforms.ToTensor()])
    return train_transform, test_transform
def get_subset_with_len(dataset, length, shuffle=False):
    """Return a ``Subset`` of ``dataset`` holding exactly ``length`` items.

    The global RNG is reseeded to 0 first, so a shuffled subset is
    reproducible across calls.
    """
    set_random_seed(0)

    indices = np.arange(len(dataset))
    if shuffle:
        np.random.shuffle(indices)

    subset = Subset(dataset, torch.from_numpy(indices[:length]))
    assert len(subset) == length  # fails if dataset is shorter than length
    return subset
def get_transform_imagenet():
    """ImageNet transforms: a two-view augmented train transform and a
    deterministic resize/center-crop test transform."""
    augment_ops = [
        transforms.Resize(256),
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
    ]
    eval_ops = [
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
    ]

    # Each training sample yields two independently augmented views.
    train_transform = MultiDataTransform(transforms.Compose(augment_ops))
    test_transform = transforms.Compose(eval_ops)
    return train_transform, test_transform
def get_transform_cnmc(res, center_crop_size):
    """CNMC transforms: resize to ``res`` then center-crop to
    ``center_crop_size`` (removing the black border), with random flips on
    the training side.

    Returns a two-view train transform and a deterministic test transform.
    """
    shared_ops = [transforms.Resize(res), transforms.CenterCrop(center_crop_size)]

    train_transform = transforms.Compose(shared_ops + [
        transforms.RandomVerticalFlip(),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
    ])
    test_transform = transforms.Compose(shared_ops + [transforms.ToTensor()])

    # Each training sample yields two independently augmented views.
    return MultiDataTransform(train_transform), test_transform
def get_dataset(P, dataset, test_only=False, image_size=None, download=False, eval=False):
    """Build datasets and matching transforms for ``dataset``.

    Args:
        P: options namespace; reads ``P.res`` ('' or e.g. '300px') and, when
            ``eval`` is True, ``P.ood_samples`` / ``P.resize_factor`` /
            ``P.resize_fix``.
        dataset: dataset name (e.g. 'CNMC', 'cifar10', 'imagenet', or one of
            the test-only OOD sets below).
        test_only: if True, return only the test set; required for the
            OOD-only datasets, which define no train split.
        image_size: pre-specified (H, W, C); several branches override it.
        download: forwarded to the torchvision dataset constructors.
        eval: if True, use the multi-view SimCLR evaluation transforms.

    Returns:
        ``test_set`` if ``test_only`` else
        ``(train_set, test_set, image_size, n_classes)``.

    Raises:
        NotImplementedError: for unknown dataset names.
    """
    if P.res != '':
        # Scale the standard 450px/300px resize+crop pair down to P.res while
        # keeping the cropped-away border proportion constant.
        res = int(P.res.replace('px', ''))
        size_factor = int(STD_RES/res) # always remove same portion
        center_crop_size = int(STD_CENTER_CROP/size_factor) # remove black border
    # NOTE(review): when P.res == '', `res` and `center_crop_size` stay
    # undefined and the CNMC branches below would raise NameError — confirm
    # callers always set P.res for CNMC datasets.
    # NOTE(review): 'CNMC_ROT4_PATH' looks like the *constant* name used as a
    # dataset label; no branch below constructs that dataset, so passing it
    # falls through to NotImplementedError — verify intent.
    if dataset in ['CNMC', 'CNMC_grayscale', 'CNMC_ROT4_PATH']:
        if eval:
            train_transform, test_transform = get_simclr_eval_transform_cnmc(P.ood_samples,
                P.resize_factor, P.resize_fix, res, center_crop_size)
        else:
            train_transform, test_transform = get_transform_cnmc(res, center_crop_size)
    elif dataset in ['imagenet', 'cub', 'stanford_dogs', 'flowers102',
                     'places365', 'food_101', 'caltech_256', 'dtd', 'pets']:
        if eval:
            train_transform, test_transform = get_simclr_eval_transform_imagenet(P.ood_samples,
                P.resize_factor, P.resize_fix)
        else:
            train_transform, test_transform = get_transform_imagenet()
    else:
        train_transform, test_transform = get_transform(image_size=image_size)

    if dataset == 'CNMC':
        image_size = (center_crop_size, center_crop_size, 3) #original 450,450,3
        n_classes = 2
        train_dir = os.path.join(CNMC_PATH, '0_training')
        test_dir = os.path.join(CNMC_PATH, '1_validation')
        train_set = datasets.ImageFolder(train_dir, transform=train_transform)
        test_set = datasets.ImageFolder(test_dir, transform=test_transform)

    elif dataset == 'CNMC_grayscale':
        image_size = (center_crop_size, center_crop_size, 3) #original 450,450,3
        n_classes = 2
        train_dir = os.path.join(CNMC_GRAY_PATH, '0_training')
        test_dir = os.path.join(CNMC_GRAY_PATH, '1_validation')
        train_set = datasets.ImageFolder(train_dir, transform=train_transform)
        test_set = datasets.ImageFolder(test_dir, transform=test_transform)

    elif dataset == 'cifar10':
        image_size = (32, 32, 3)
        n_classes = 10
        train_set = datasets.CIFAR10(DATA_PATH, train=True, download=download, transform=train_transform)
        test_set = datasets.CIFAR10(DATA_PATH, train=False, download=download, transform=test_transform)

    elif dataset == 'cifar100':
        image_size = (32, 32, 3)
        n_classes = 100
        train_set = datasets.CIFAR100(DATA_PATH, train=True, download=download, transform=train_transform)
        test_set = datasets.CIFAR100(DATA_PATH, train=False, download=download, transform=test_transform)

    # The branches below are OOD/test-only datasets: they require
    # test_only=True and an explicit image_size, and leave train_set and
    # n_classes undefined (safe only because of the final test_only return).
    elif dataset == 'svhn':
        assert test_only and image_size is not None
        test_set = datasets.SVHN(DATA_PATH, split='test', download=download, transform=test_transform)

    elif dataset == 'lsun_resize':
        assert test_only and image_size is not None
        test_dir = os.path.join(DATA_PATH, 'LSUN_resize')
        test_set = datasets.ImageFolder(test_dir, transform=test_transform)

    elif dataset == 'lsun_fix':
        assert test_only and image_size is not None
        test_dir = os.path.join(DATA_PATH, 'LSUN_fix')
        test_set = datasets.ImageFolder(test_dir, transform=test_transform)

    elif dataset == 'imagenet_resize':
        assert test_only and image_size is not None
        test_dir = os.path.join(DATA_PATH, 'Imagenet_resize')
        test_set = datasets.ImageFolder(test_dir, transform=test_transform)

    elif dataset == 'imagenet_fix':
        assert test_only and image_size is not None
        test_dir = os.path.join(DATA_PATH, 'Imagenet_fix')
        test_set = datasets.ImageFolder(test_dir, transform=test_transform)

    elif dataset == 'imagenet':
        image_size = (224, 224, 3)
        n_classes = 30
        train_dir = os.path.join(IMAGENET_PATH, 'one_class_train')
        test_dir = os.path.join(IMAGENET_PATH, 'one_class_test')
        train_set = datasets.ImageFolder(train_dir, transform=train_transform)
        test_set = datasets.ImageFolder(test_dir, transform=test_transform)

    # Each OOD set below is truncated to a reproducible 3000-sample subset.
    elif dataset == 'stanford_dogs':
        assert test_only and image_size is not None
        test_dir = os.path.join(DATA_PATH, 'stanford_dogs')
        test_set = datasets.ImageFolder(test_dir, transform=test_transform)
        test_set = get_subset_with_len(test_set, length=3000, shuffle=True)

    elif dataset == 'cub':
        assert test_only and image_size is not None
        test_dir = os.path.join(DATA_PATH, 'cub200')
        test_set = datasets.ImageFolder(test_dir, transform=test_transform)
        test_set = get_subset_with_len(test_set, length=3000, shuffle=True)

    elif dataset == 'flowers102':
        assert test_only and image_size is not None
        test_dir = os.path.join(DATA_PATH, 'flowers102')
        test_set = datasets.ImageFolder(test_dir, transform=test_transform)
        test_set = get_subset_with_len(test_set, length=3000, shuffle=True)

    elif dataset == 'places365':
        assert test_only and image_size is not None
        test_dir = os.path.join(DATA_PATH, 'places365')
        test_set = datasets.ImageFolder(test_dir, transform=test_transform)
        test_set = get_subset_with_len(test_set, length=3000, shuffle=True)

    elif dataset == 'food_101':
        assert test_only and image_size is not None
        test_dir = os.path.join(DATA_PATH, 'food-101', 'images')
        test_set = datasets.ImageFolder(test_dir, transform=test_transform)
        test_set = get_subset_with_len(test_set, length=3000, shuffle=True)

    elif dataset == 'caltech_256':
        assert test_only and image_size is not None
        test_dir = os.path.join(DATA_PATH, 'caltech-256')
        test_set = datasets.ImageFolder(test_dir, transform=test_transform)
        test_set = get_subset_with_len(test_set, length=3000, shuffle=True)

    elif dataset == 'dtd':
        assert test_only and image_size is not None
        test_dir = os.path.join(DATA_PATH, 'dtd', 'images')
        test_set = datasets.ImageFolder(test_dir, transform=test_transform)
        test_set = get_subset_with_len(test_set, length=3000, shuffle=True)

    elif dataset == 'pets':
        assert test_only and image_size is not None
        test_dir = os.path.join(DATA_PATH, 'pets')
        test_set = datasets.ImageFolder(test_dir, transform=test_transform)
        test_set = get_subset_with_len(test_set, length=3000, shuffle=True)

    else:
        raise NotImplementedError()

    if test_only:
        return test_set
    else:
        return train_set, test_set, image_size, n_classes
def get_superclass_list(dataset):
    """Return the superclass grouping for ``dataset``.

    Raises NotImplementedError for unrecognized dataset names.
    """
    if dataset in ('CNMC', 'CNMC_grayscale'):
        return CNMC_SUPERCLASS
    if dataset == 'cifar10':
        return CIFAR10_SUPERCLASS
    if dataset == 'cifar100':
        return CIFAR100_SUPERCLASS
    if dataset == 'imagenet':
        return IMAGENET_SUPERCLASS
    raise NotImplementedError()
def get_subclass_dataset(dataset, classes):
    """Return a ``Subset`` of ``dataset`` keeping only samples whose target
    label is in ``classes``.

    ``classes`` may be a single label or a list of labels; sample order is
    preserved.
    """
    if not isinstance(classes, list):
        classes = [classes]

    keep = [idx for idx, target in enumerate(dataset.targets) if target in classes]
    return Subset(dataset, keep)
def get_simclr_eval_transform_imagenet(sample_num, resize_factor, resize_fix):
    """SimCLR OOD-evaluation transform for ImageNet-scale images.

    Returns the same MultiDataTransformList for both train and test: each
    sample yields ``sample_num`` randomly resized-crop views plus one clean
    center-crop view.
    """
    # resize_fix pins the RandomResizedCrop scale to exactly resize_factor;
    # otherwise the scale is sampled from [resize_factor, 1.0].
    if resize_fix:
        resize_scale = (resize_factor, resize_factor)
    else:
        resize_scale = (resize_factor, 1.0)

    augmented = transforms.Compose([
        transforms.Resize(256),
        transforms.RandomResizedCrop(224, scale=resize_scale),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
    ])
    clean_trasform = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
    ])

    multi = MultiDataTransformList(augmented, clean_trasform, sample_num)
    return multi, multi
def get_simclr_eval_transform_cnmc(sample_num, resize_factor, resize_fix, res, center_crop_size):
    """SimCLR OOD-evaluation transform for CNMC images.

    Returns the same MultiDataTransformList for both train and test: each
    sample yields ``sample_num`` flipped views plus one clean
    resize/center-crop view at (``res``, ``center_crop_size``).
    """
    # resize_fix pins the scale range to exactly resize_factor; otherwise the
    # scale is sampled from [resize_factor, 1.0].
    if resize_fix:
        resize_scale = (resize_factor, resize_factor)
    else:
        resize_scale = (resize_factor, 1.0)

    augmented = transforms.Compose([
        transforms.Resize(res),
        transforms.CenterCrop(center_crop_size),
        transforms.RandomVerticalFlip(),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
    ])
    clean_trasform = transforms.Compose([
        transforms.Resize(res),
        transforms.CenterCrop(center_crop_size),
        transforms.ToTensor(),
    ])

    multi = MultiDataTransformList(augmented, clean_trasform, sample_num)
    return multi, multi