123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361 |
- import os
-
- import numpy as np
- import torch
- from torch.utils.data.dataset import Subset
- from torchvision import datasets, transforms
-
- from utils.utils import set_random_seed
-
# --- Dataset locations -------------------------------------------------------
# All roots start with '~'; they are handed to torchvision dataset classes
# as-is (no expansion is performed in this module).
DATA_PATH = '~/data/'
IMAGENET_PATH = '~/data/ImageNet'
CNMC_PATH = '~/data/CSI/CNMC_orig'
CNMC_GRAY_PATH = '~/data/CSI/CNMC_orig_gray'
# NOTE(review): CNMC_ROT4_PATH is not referenced by any loader visible in this
# file -- confirm whether a rotated-CNMC dataset branch is still planned.
CNMC_ROT4_PATH = '~/data/CSI/CNMC_rotated_4'

# --- Superclass layouts where every superclass is a single class index -------
CIFAR10_SUPERCLASS = list(range(10))
IMAGENET_SUPERCLASS = list(range(30))
CNMC_SUPERCLASS = list(range(2))

# --- CNMC geometry -----------------------------------------------------------
# Reference resolution of the CNMC images (450 x 450) and the centre crop used
# at that resolution to remove the black border.
STD_RES = 450
STD_CENTER_CROP = 300
-
# CIFAR-100 superclass layout: 20 groups of 5 class indices each. Together the
# groups cover every index in 0..99 exactly once.
CIFAR100_SUPERCLASS = [
    [4, 31, 55, 72, 95],
    [1, 33, 67, 73, 91],
    [54, 62, 70, 82, 92],
    [9, 10, 16, 29, 61],
    [0, 51, 53, 57, 83],
    [22, 25, 40, 86, 87],
    [5, 20, 26, 84, 94],
    [6, 7, 14, 18, 24],
    [3, 42, 43, 88, 97],
    [12, 17, 38, 68, 76],
    [23, 34, 49, 60, 71],
    [15, 19, 21, 32, 39],
    [35, 63, 64, 66, 75],
    [27, 45, 77, 79, 99],
    [2, 11, 36, 46, 98],
    [28, 30, 44, 78, 93],
    [37, 50, 65, 74, 80],
    [47, 52, 56, 59, 96],
    [8, 13, 48, 58, 90],
    [41, 69, 81, 85, 89],
]
-
-
class MultiDataTransform:
    """Callable that runs one transform twice and returns both results.

    The same transform object is stored under ``transform1`` and
    ``transform2``; any difference between the two outputs therefore comes
    from randomness inside the transform itself.
    """

    def __init__(self, transform):
        self.transform1 = self.transform2 = transform

    def __call__(self, sample):
        """Return ``(transform1(sample), transform2(sample))``."""
        first_view = self.transform1(sample)
        second_view = self.transform2(sample)
        return first_view, second_view
-
-
class MultiDataTransformList:
    """Callable producing several augmented views plus one clean view.

    Args:
        transform: callable applied ``sample_num`` times to each sample.
        clean_trasform: callable applied once to each sample (the parameter
            name keeps its original spelling so keyword callers keep working).
        sample_num: number of augmented views to generate per sample.
    """

    def __init__(self, transform, clean_trasform, sample_num):
        self.sample_num = sample_num
        self.clean_transform = clean_trasform
        self.transform = transform

    def __call__(self, sample):
        """Return ``([view_1, ..., view_sample_num], clean_view)``."""
        # The seed is reset on every call -- presumably so that the sequence
        # of random augmentations is the same for every sample; confirm
        # against utils.utils.set_random_seed.
        set_random_seed(0)

        augmented_views = [self.transform(sample) for _ in range(self.sample_num)]
        return augmented_views, self.clean_transform(sample)
-
-
def get_transform(image_size=None):
    """Build the basic ``(train_transform, test_transform)`` pair.

    Original author's note: data augmentation is implemented in the model
    layers, so these pipelines stay minimal.

    Args:
        image_size: optional sequence whose first two entries are the target
            height and width. When falsy, images keep their native size.

    Returns:
        Tuple ``(train_transform, test_transform)``. With ``image_size`` both
        resize to the requested size and the train pipeline additionally
        applies a random horizontal flip; without it both only convert to a
        tensor.
    """
    if not image_size:
        # Default image size: tensor conversion only.
        return transforms.Compose([transforms.ToTensor()]), transforms.ToTensor()

    target_hw = (image_size[0], image_size[1])
    train_transform = transforms.Compose([
        transforms.Resize(target_hw),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
    ])
    test_transform = transforms.Compose([
        transforms.Resize(target_hw),
        transforms.ToTensor(),
    ])
    return train_transform, test_transform
-
-
def get_subset_with_len(dataset, length, shuffle=False):
    """Return a ``Subset`` holding ``length`` items of ``dataset``.

    Args:
        dataset: any object supporting ``len()`` that ``Subset`` can wrap.
        length: number of items to keep.
        shuffle: when true, the items are drawn from a NumPy-shuffled index
            order instead of taking the first ``length`` indices.

    Returns:
        ``torch.utils.data.Subset`` over the selected indices.

    Raises:
        AssertionError: if the resulting subset does not have ``length``
            items (e.g. ``length`` exceeds ``len(dataset)``).
    """
    # Seed is reset before shuffling -- presumably to make the selection
    # reproducible; confirm against utils.utils.set_random_seed.
    set_random_seed(0)

    order = np.arange(len(dataset))
    if shuffle:
        np.random.shuffle(order)

    chosen = torch.from_numpy(order[:length])
    subset = Subset(dataset, chosen)

    assert len(subset) == length

    return subset
-
-
def get_transform_imagenet():
    """Build the ImageNet-style ``(train_transform, test_transform)`` pair.

    Returns:
        ``train_transform`` is a ``MultiDataTransform`` wrapping
        resize(256) -> random-resized-crop(224) -> random horizontal flip ->
        tensor, so it yields two views per image. ``test_transform`` is
        resize(256) -> center-crop(224) -> tensor.
    """
    test_transform = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
    ])

    augmentation = transforms.Compose([
        transforms.Resize(256),
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
    ])

    return MultiDataTransform(augmentation), test_transform
-
def get_transform_cnmc(res, center_crop_size):
    """Build the CNMC ``(train_transform, test_transform)`` pair.

    Args:
        res: size passed to ``transforms.Resize``.
        center_crop_size: size passed to ``transforms.CenterCrop``.

    Returns:
        ``train_transform`` is a ``MultiDataTransform`` wrapping
        resize -> center-crop -> random vertical flip -> random horizontal
        flip -> tensor, so it yields two views per image. ``test_transform``
        is resize -> center-crop -> tensor.
    """
    test_transform = transforms.Compose([
        transforms.Resize(res),
        transforms.CenterCrop(center_crop_size),
        transforms.ToTensor(),
    ])

    augmentation = transforms.Compose([
        transforms.Resize(res),
        transforms.CenterCrop(center_crop_size),
        transforms.RandomVerticalFlip(),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
    ])

    return MultiDataTransform(augmentation), test_transform
-
-
# Datasets routed to the CNMC transform pipeline (these need P.res).
_CNMC_TRANSFORM_DATASETS = ('CNMC', 'CNMC_grayscale', 'CNMC_ROT4_PATH')

# Datasets routed to the ImageNet-style transform pipeline.
_IMAGENET_TRANSFORM_DATASETS = (
    'imagenet', 'cub', 'stanford_dogs', 'flowers102',
    'places365', 'food_101', 'caltech_256', 'dtd', 'pets',
)

# Test-only ImageFolder datasets: name -> path components below DATA_PATH.
_OOD_FOLDERS = {
    'lsun_resize': ('LSUN_resize',),
    'lsun_fix': ('LSUN_fix',),
    'imagenet_resize': ('Imagenet_resize',),
    'imagenet_fix': ('Imagenet_fix',),
}

# Test-only ImageFolder datasets that are additionally reduced to a shuffled
# subset of _OOD_SUBSET_LEN images.
_OOD_SUBSAMPLED_FOLDERS = {
    'stanford_dogs': ('stanford_dogs',),
    'cub': ('cub200',),
    'flowers102': ('flowers102',),
    'places365': ('places365',),
    'food_101': ('food-101', 'images'),
    'caltech_256': ('caltech-256',),
    'dtd': ('dtd', 'images'),
    'pets': ('pets',),
}
_OOD_SUBSET_LEN = 3000


def get_dataset(P, dataset, test_only=False, image_size=None, download=False, eval=False):
    """Create the dataset(s) identified by ``dataset``.

    Args:
        P: options object. ``P.res`` is always read (a string such as
            ``'450px'``, or ``''`` when unset). ``P.ood_samples``,
            ``P.resize_factor`` and ``P.resize_fix`` are read only when
            ``eval`` is true and the dataset uses the CNMC or ImageNet-style
            transforms.
        dataset: dataset name, e.g. ``'CNMC'``, ``'cifar10'``, ``'svhn'``.
        test_only: return only the test set. Required for the datasets that
            have no training split here (svhn, lsun_*, imagenet_*, cub, ...).
        image_size: target size forwarded to ``get_transform`` for datasets
            using the basic transforms; must not be ``None`` for the
            test-only datasets.
        download: forwarded to the CIFAR / SVHN dataset constructors.
        eval: use the multi-sample evaluation transforms instead of the
            training transforms (CNMC and ImageNet-style datasets only).

    Returns:
        ``test_set`` when ``test_only`` is true, otherwise the tuple
        ``(train_set, test_set, image_size, n_classes)``.

    Raises:
        ValueError: if ``P.res`` yields a zero size factor, or if a CNMC
            dataset is requested while ``P.res`` is empty.
        NotImplementedError: if ``dataset`` is not a known name.
        AssertionError: if a test-only dataset is requested without
            ``test_only=True`` and an ``image_size``.
    """
    res = None
    center_crop_size = None
    if P.res != '':
        res = int(P.res.replace('px', ''))
        # The crop shrinks by the same integer factor as the resolution.
        size_factor = int(STD_RES / res) if res else 0
        if size_factor == 0:
            # Previously surfaced as a bare ZeroDivisionError.
            raise ValueError(
                "P.res={!r} gives a zero size factor; the resolution must be "
                "non-zero and no larger than {}px".format(P.res, STD_RES))
        center_crop_size = int(STD_CENTER_CROP / size_factor)  # remove black border

    # ---- choose the transform pipeline ------------------------------------
    if dataset in _CNMC_TRANSFORM_DATASETS:
        if res is None:
            # Previously surfaced as an UnboundLocalError on `res`.
            raise ValueError(
                "P.res must be set (e.g. '450px') for dataset {!r}".format(dataset))
        if eval:
            train_transform, test_transform = get_simclr_eval_transform_cnmc(
                P.ood_samples, P.resize_factor, P.resize_fix, res, center_crop_size)
        else:
            train_transform, test_transform = get_transform_cnmc(res, center_crop_size)
    elif dataset in _IMAGENET_TRANSFORM_DATASETS:
        if eval:
            train_transform, test_transform = get_simclr_eval_transform_imagenet(
                P.ood_samples, P.resize_factor, P.resize_fix)
        else:
            train_transform, test_transform = get_transform_imagenet()
    else:
        train_transform, test_transform = get_transform(image_size=image_size)

    # ---- build the dataset(s) ---------------------------------------------
    cnmc_roots = {'CNMC': CNMC_PATH, 'CNMC_grayscale': CNMC_GRAY_PATH}

    if dataset in cnmc_roots:
        root = cnmc_roots[dataset]
        image_size = (center_crop_size, center_crop_size, 3)  # original 450,450,3
        n_classes = 2
        train_set = datasets.ImageFolder(os.path.join(root, '0_training'),
                                         transform=train_transform)
        test_set = datasets.ImageFolder(os.path.join(root, '1_validation'),
                                        transform=test_transform)

    elif dataset in ('cifar10', 'cifar100'):
        if dataset == 'cifar10':
            cifar_cls, n_classes = datasets.CIFAR10, 10
        else:
            cifar_cls, n_classes = datasets.CIFAR100, 100
        image_size = (32, 32, 3)
        train_set = cifar_cls(DATA_PATH, train=True, download=download,
                              transform=train_transform)
        test_set = cifar_cls(DATA_PATH, train=False, download=download,
                             transform=test_transform)

    elif dataset == 'svhn':
        assert test_only and image_size is not None
        test_set = datasets.SVHN(DATA_PATH, split='test', download=download,
                                 transform=test_transform)

    elif dataset == 'imagenet':
        image_size = (224, 224, 3)
        n_classes = 30
        train_set = datasets.ImageFolder(os.path.join(IMAGENET_PATH, 'one_class_train'),
                                         transform=train_transform)
        test_set = datasets.ImageFolder(os.path.join(IMAGENET_PATH, 'one_class_test'),
                                        transform=test_transform)

    elif dataset in _OOD_FOLDERS or dataset in _OOD_SUBSAMPLED_FOLDERS:
        assert test_only and image_size is not None
        subsample = dataset in _OOD_SUBSAMPLED_FOLDERS
        parts = _OOD_SUBSAMPLED_FOLDERS[dataset] if subsample else _OOD_FOLDERS[dataset]
        test_set = datasets.ImageFolder(os.path.join(DATA_PATH, *parts),
                                        transform=test_transform)
        if subsample:
            test_set = get_subset_with_len(test_set, length=_OOD_SUBSET_LEN, shuffle=True)

    else:
        raise NotImplementedError()

    if test_only:
        return test_set
    return train_set, test_set, image_size, n_classes
-
-
def get_superclass_list(dataset):
    """Return the superclass layout registered for ``dataset``.

    Args:
        dataset: one of ``'CNMC'``, ``'CNMC_grayscale'``, ``'cifar10'``,
            ``'cifar100'`` or ``'imagenet'``.

    Returns:
        The module-level superclass list for that dataset (the list object
        itself, not a copy).

    Raises:
        NotImplementedError: for any other dataset name.
    """
    known_layouts = (
        ('CNMC', CNMC_SUPERCLASS),
        ('CNMC_grayscale', CNMC_SUPERCLASS),
        ('cifar10', CIFAR10_SUPERCLASS),
        ('cifar100', CIFAR100_SUPERCLASS),
        ('imagenet', IMAGENET_SUPERCLASS),
    )
    for name, superclass_list in known_layouts:
        if dataset == name:
            return superclass_list
    raise NotImplementedError()
-
-
def get_subclass_dataset(dataset, classes):
    """Restrict ``dataset`` to the samples whose target is in ``classes``.

    Args:
        dataset: object exposing a ``targets`` iterable aligned with its
            sample indices.
        classes: a list of class labels, or a single label (any non-list
            value is treated as one label).

    Returns:
        ``torch.utils.data.Subset`` over the matching sample indices, in
        their original order.
    """
    wanted = classes if isinstance(classes, list) else [classes]
    keep = [position for position, label in enumerate(dataset.targets) if label in wanted]
    return Subset(dataset, keep)
-
-
def get_simclr_eval_transform_imagenet(sample_num, resize_factor, resize_fix):
    """Build the multi-sample evaluation transform for ImageNet-style data.

    Args:
        sample_num: number of augmented views produced per image.
        resize_factor: lower bound of the ``RandomResizedCrop`` scale range.
        resize_fix: when truthy the scale range collapses to
            ``(resize_factor, resize_factor)``; otherwise it is
            ``(resize_factor, 1.0)``.

    Returns:
        The same ``MultiDataTransformList`` instance twice, to be used as
        both train and test transform.
    """
    upper_scale = resize_factor if resize_fix else 1.0
    resize_scale = (resize_factor, upper_scale)

    augmentation = transforms.Compose([
        transforms.Resize(256),
        transforms.RandomResizedCrop(224, scale=resize_scale),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
    ])
    clean = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
    ])

    multi_view = MultiDataTransformList(augmentation, clean, sample_num)
    return multi_view, multi_view
-
def get_simclr_eval_transform_cnmc(sample_num, resize_factor, resize_fix, res, center_crop_size):
    """Build the multi-sample evaluation transform for CNMC data.

    Args:
        sample_num: number of augmented views produced per image.
        resize_factor: unused. This pipeline has no random-resized crop, so
            there is no scale range to configure; the parameter is kept so
            existing callers keep working.
        resize_fix: unused, kept for the same reason as ``resize_factor``.
        res: size passed to ``transforms.Resize``.
        center_crop_size: size passed to ``transforms.CenterCrop``.

    Returns:
        The same ``MultiDataTransformList`` instance twice, to be used as
        both train and test transform. Augmented views are
        resize -> center-crop -> random vertical flip -> random horizontal
        flip -> tensor; the clean view omits the flips.
    """
    # The former `resize_scale` local was computed from resize_factor and
    # resize_fix but never consumed, so it has been removed.
    augmentation = transforms.Compose([
        transforms.Resize(res),
        transforms.CenterCrop(center_crop_size),
        transforms.RandomVerticalFlip(),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
    ])

    clean_trasform = transforms.Compose([
        transforms.Resize(res),
        transforms.CenterCrop(center_crop_size),
        transforms.ToTensor(),
    ])

    transform = MultiDataTransformList(augmentation, clean_trasform, sample_num)

    return transform, transform
-
|