"""Dataset and transform helpers for CIFAR, ImageNet-style and CNMC image data."""
import os

import numpy as np
import torch
from torch.utils.data.dataset import Subset
from torchvision import datasets, transforms

from utils.utils import set_random_seed

DATA_PATH = '~/data/'
IMAGENET_PATH = '~/data/ImageNet'
CNMC_PATH = r'~/data/CSI/CNMC_orig'
CNMC_GRAY_PATH = r'~/data/CSI/CNMC_orig_gray'
CNMC_ROT4_PATH = r'~/data/CSI/CNMC_rotated_4'

CIFAR10_SUPERCLASS = list(range(10))  # one class
IMAGENET_SUPERCLASS = list(range(30))  # one class
CNMC_SUPERCLASS = list(range(2))  # one class

# Reference CNMC geometry: images resized to STD_RES are center-cropped to
# STD_CENTER_CROP; other resolutions scale the crop by an integer factor.
STD_RES = 450
STD_CENTER_CROP = 300

CIFAR100_SUPERCLASS = [
    [4, 31, 55, 72, 95],
    [1, 33, 67, 73, 91],
    [54, 62, 70, 82, 92],
    [9, 10, 16, 29, 61],
    [0, 51, 53, 57, 83],
    [22, 25, 40, 86, 87],
    [5, 20, 26, 84, 94],
    [6, 7, 14, 18, 24],
    [3, 42, 43, 88, 97],
    [12, 17, 38, 68, 76],
    [23, 34, 49, 60, 71],
    [15, 19, 21, 32, 39],
    [35, 63, 64, 66, 75],
    [27, 45, 77, 79, 99],
    [2, 11, 36, 46, 98],
    [28, 30, 44, 78, 93],
    [37, 50, 65, 74, 80],
    [47, 52, 56, 59, 96],
    [8, 13, 48, 58, 90],
    [41, 69, 81, 85, 89],
]

# Dataset names that are given the CNMC transforms in get_dataset.
# NOTE(review): 'CNMC_ROT4_PATH' looks like the name of the path constant
# rather than a dataset name, and get_dataset has no loader for it (it ends in
# NotImplementedError). Kept as-is; confirm the intended dataset name.
_CNMC_TRANSFORM_DATASETS = ('CNMC', 'CNMC_grayscale', 'CNMC_ROT4_PATH')

# CNMC variants get_dataset can load: dataset name -> root folder.
_CNMC_ROOTS = {
    'CNMC': CNMC_PATH,
    'CNMC_grayscale': CNMC_GRAY_PATH,
}

# Dataset names that are given the ImageNet transforms in get_dataset.
_IMAGENET_TRANSFORM_DATASETS = (
    'imagenet', 'cub', 'stanford_dogs', 'flowers102', 'places365',
    'food_101', 'caltech_256', 'dtd', 'pets',
)

# Test-only ImageFolder datasets:
# dataset name -> (path parts below DATA_PATH, subset length or None for all).
_TEST_ONLY_FOLDERS = {
    'lsun_resize': (('LSUN_resize',), None),
    'lsun_fix': (('LSUN_fix',), None),
    'imagenet_resize': (('Imagenet_resize',), None),
    'imagenet_fix': (('Imagenet_fix',), None),
    'stanford_dogs': (('stanford_dogs',), 3000),
    'cub': (('cub200',), 3000),
    'flowers102': (('flowers102',), 3000),
    'places365': (('places365',), 3000),
    'food_101': (('food-101', 'images'), 3000),
    'caltech_256': (('caltech-256',), 3000),
    'dtd': (('dtd', 'images'), 3000),
    'pets': (('pets',), 3000),
}


class MultiDataTransform(object):
    """Apply one transform twice to a sample and return both results."""

    def __init__(self, transform):
        # Both views share the same transform object.
        self.transform1 = transform
        self.transform2 = transform

    def __call__(self, sample):
        """Return the pair ``(transform1(sample), transform2(sample))``."""
        x1 = self.transform1(sample)
        x2 = self.transform2(sample)
        return x1, x2


class MultiDataTransformList(object):
    """Produce ``sample_num`` transformed copies of a sample plus a clean one."""

    def __init__(self, transform, clean_trasform, sample_num):
        # ``clean_trasform`` (sic) is kept as the parameter name so existing
        # keyword callers keep working.
        self.transform = transform
        self.clean_transform = clean_trasform
        self.sample_num = sample_num

    def __call__(self, sample):
        """Return ``(list of transformed samples, clean_transform(sample))``."""
        # Called on every sample; presumably resets the RNGs so the same
        # augmentations are drawn each time -- confirm against utils.utils.
        set_random_seed(0)
        sample_list = [self.transform(sample) for _ in range(self.sample_num)]
        return sample_list, self.clean_transform(sample)


def get_transform(image_size=None):
    """Build the default ``(train_transform, test_transform)`` pair.

    Args:
        image_size: optional sequence whose first two entries are used as the
            resize target. When falsy, images are only converted to tensors.
    """
    # Note: data augmentation is implemented in the layers
    # Hence, we only define the identity transformation here
    # (the resized train pipeline additionally applies a random horizontal flip)
    if image_size:  # use pre-specified image size
        train_transform = transforms.Compose([
            transforms.Resize((image_size[0], image_size[1])),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
        ])
        test_transform = transforms.Compose([
            transforms.Resize((image_size[0], image_size[1])),
            transforms.ToTensor(),
        ])
    else:  # use default image size
        train_transform = transforms.Compose([
            transforms.ToTensor(),
        ])
        test_transform = transforms.ToTensor()

    return train_transform, test_transform


def get_subset_with_len(dataset, length, shuffle=False):
    """Return a ``Subset`` holding the first ``length`` items of ``dataset``.

    Args:
        dataset: any object supporting ``len()`` that ``Subset`` can wrap.
        length: number of items to keep.
        shuffle: when true, the indices are shuffled with ``np.random`` (after
            ``set_random_seed(0)``) before the first ``length`` are taken.

    Raises:
        AssertionError: if the resulting subset does not contain exactly
            ``length`` items (e.g. the dataset is smaller than ``length``).
    """
    set_random_seed(0)
    index = np.arange(len(dataset))
    if shuffle:
        np.random.shuffle(index)

    index = torch.from_numpy(index[0:length])
    subset = Subset(dataset, index)

    assert len(subset) == length

    return subset


def get_transform_imagenet():
    """Build ImageNet transforms; the train transform yields two views."""
    train_transform = transforms.Compose([
        transforms.Resize(256),
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
    ])
    test_transform = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
    ])

    train_transform = MultiDataTransform(train_transform)

    return train_transform, test_transform


def get_transform_cnmc(res, center_crop_size):
    """Build CNMC transforms; the train transform yields two views.

    Args:
        res: value passed to ``transforms.Resize``.
        center_crop_size: value passed to ``transforms.CenterCrop``.
    """
    train_transform = transforms.Compose([
        transforms.Resize(res),
        transforms.CenterCrop(center_crop_size),
        transforms.RandomVerticalFlip(),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
    ])
    test_transform = transforms.Compose([
        transforms.Resize(res),
        transforms.CenterCrop(center_crop_size),
        transforms.ToTensor(),
    ])

    train_transform = MultiDataTransform(train_transform)

    return train_transform, test_transform


def _cnmc_geometry(res_spec):
    """Turn a resolution string such as ``'450px'`` into ``(res, crop_size)``.

    Raises:
        ValueError: if ``res_spec`` is empty, is not an integer (optionally
            with a ``px`` suffix), or lies outside ``1..STD_RES``. Larger
            values would make the integer size factor zero.
    """
    if res_spec == '':
        raise ValueError("P.res must be set (e.g. '450px') for CNMC datasets")

    res = int(res_spec.replace('px', ''))
    if not 0 < res <= STD_RES:
        raise ValueError(
            'P.res must be between 1 and {} pixels, got {!r}'.format(STD_RES, res_spec)
        )

    size_factor = int(STD_RES / res)  # always remove same portion
    center_crop_size = int(STD_CENTER_CROP / size_factor)  # remove black border
    return res, center_crop_size


def get_dataset(P, dataset, test_only=False, image_size=None, download=False, eval=False):
    """Create the requested dataset(s) together with their transforms.

    Args:
        P: options object. ``P.res`` is read for CNMC datasets; ``P.ood_samples``,
            ``P.resize_factor`` and ``P.resize_fix`` are read when ``eval`` is
            true for CNMC or ImageNet-style datasets.
        dataset: dataset name.
        test_only: return only the test set. Required for the datasets that
            have no training split here.
        image_size: resize target for datasets using the default transform;
            must not be ``None`` for the test-only datasets.
        download: forwarded to the CIFAR/SVHN dataset constructors.
        eval: use the multi-sample evaluation transforms where available.

    Returns:
        ``test_set`` when ``test_only`` is true, otherwise the tuple
        ``(train_set, test_set, image_size, n_classes)``.

    Raises:
        ValueError: if a CNMC dataset is requested with an empty or
            out-of-range ``P.res``.
        AssertionError: if a test-only dataset is requested without
            ``test_only=True`` and an ``image_size``.
        NotImplementedError: for unknown dataset names.
    """
    # --- choose transforms -------------------------------------------------
    if dataset in _CNMC_TRANSFORM_DATASETS:
        # The crop geometry is only meaningful (and only used) for CNMC data.
        res, center_crop_size = _cnmc_geometry(P.res)
        if eval:
            train_transform, test_transform = get_simclr_eval_transform_cnmc(
                P.ood_samples, P.resize_factor, P.resize_fix, res, center_crop_size)
        else:
            train_transform, test_transform = get_transform_cnmc(res, center_crop_size)
    elif dataset in _IMAGENET_TRANSFORM_DATASETS:
        if eval:
            train_transform, test_transform = get_simclr_eval_transform_imagenet(
                P.ood_samples, P.resize_factor, P.resize_fix)
        else:
            train_transform, test_transform = get_transform_imagenet()
    else:
        train_transform, test_transform = get_transform(image_size=image_size)

    # --- build datasets ----------------------------------------------------
    if dataset in _CNMC_ROOTS:
        root = _CNMC_ROOTS[dataset]
        image_size = (center_crop_size, center_crop_size, 3)  # original 450,450,3
        n_classes = 2
        train_dir = os.path.join(root, '0_training')
        test_dir = os.path.join(root, '1_validation')
        train_set = datasets.ImageFolder(train_dir, transform=train_transform)
        test_set = datasets.ImageFolder(test_dir, transform=test_transform)

    elif dataset == 'cifar10':
        image_size = (32, 32, 3)
        n_classes = 10
        train_set = datasets.CIFAR10(DATA_PATH, train=True, download=download, transform=train_transform)
        test_set = datasets.CIFAR10(DATA_PATH, train=False, download=download, transform=test_transform)

    elif dataset == 'cifar100':
        image_size = (32, 32, 3)
        n_classes = 100
        train_set = datasets.CIFAR100(DATA_PATH, train=True, download=download, transform=train_transform)
        test_set = datasets.CIFAR100(DATA_PATH, train=False, download=download, transform=test_transform)

    elif dataset == 'svhn':
        assert test_only and image_size is not None
        test_set = datasets.SVHN(DATA_PATH, split='test', download=download, transform=test_transform)

    elif dataset == 'imagenet':
        image_size = (224, 224, 3)
        n_classes = 30
        train_dir = os.path.join(IMAGENET_PATH, 'one_class_train')
        test_dir = os.path.join(IMAGENET_PATH, 'one_class_test')
        train_set = datasets.ImageFolder(train_dir, transform=train_transform)
        test_set = datasets.ImageFolder(test_dir, transform=test_transform)

    elif dataset in _TEST_ONLY_FOLDERS:
        assert test_only and image_size is not None
        path_parts, subset_len = _TEST_ONLY_FOLDERS[dataset]
        test_dir = os.path.join(DATA_PATH, *path_parts)
        test_set = datasets.ImageFolder(test_dir, transform=test_transform)
        if subset_len is not None:
            test_set = get_subset_with_len(test_set, length=subset_len, shuffle=True)

    else:
        raise NotImplementedError()

    if test_only:
        return test_set
    else:
        return train_set, test_set, image_size, n_classes


def get_superclass_list(dataset):
    """Return the superclass list constant for ``dataset``.

    Raises:
        NotImplementedError: for dataset names without a superclass list.
    """
    if dataset in ('CNMC', 'CNMC_grayscale'):
        return CNMC_SUPERCLASS
    elif dataset == 'cifar10':
        return CIFAR10_SUPERCLASS
    elif dataset == 'cifar100':
        return CIFAR100_SUPERCLASS
    elif dataset == 'imagenet':
        return IMAGENET_SUPERCLASS
    else:
        raise NotImplementedError()


def get_subclass_dataset(dataset, classes):
    """Return a ``Subset`` of ``dataset`` whose targets are in ``classes``.

    Args:
        dataset: object exposing a ``targets`` sequence aligned with its items.
        classes: a single class label or a list of labels.
    """
    if not isinstance(classes, list):
        classes = [classes]

    indices = [idx for idx, tgt in enumerate(dataset.targets) if tgt in classes]

    return Subset(dataset, indices)


def get_simclr_eval_transform_imagenet(sample_num, resize_factor, resize_fix):
    """Build the ImageNet evaluation transform (returned twice, as train/test).

    Each call of the returned transform yields ``sample_num`` randomly
    augmented crops plus one center-cropped clean image.

    Args:
        sample_num: number of augmented samples per image.
        resize_factor: lower bound of the ``RandomResizedCrop`` scale range.
        resize_fix: when true, the scale is fixed to ``resize_factor``.
    """
    resize_scale = (resize_factor, 1.0)  # resize scaling factor
    if resize_fix:  # if resize_fix is True, use same scale
        resize_scale = (resize_factor, resize_factor)

    transform = transforms.Compose([
        transforms.Resize(256),
        transforms.RandomResizedCrop(224, scale=resize_scale),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
    ])
    clean_transform = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
    ])

    transform = MultiDataTransformList(transform, clean_transform, sample_num)

    return transform, transform


def get_simclr_eval_transform_cnmc(sample_num, resize_factor, resize_fix, res, center_crop_size):
    """Build the CNMC evaluation transform (returned twice, as train/test).

    Each call of the returned transform yields ``sample_num`` randomly flipped
    images plus one clean image, all resized and center-cropped.

    Args:
        sample_num: number of augmented samples per image.
        resize_factor: accepted for signature parity with the ImageNet
            variant; not used, because this pipeline has no random resized
            crop.
        resize_fix: accepted for signature parity; not used.
        res: value passed to ``transforms.Resize``.
        center_crop_size: value passed to ``transforms.CenterCrop``.
    """
    transform = transforms.Compose([
        transforms.Resize(res),
        transforms.CenterCrop(center_crop_size),
        transforms.RandomVerticalFlip(),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
    ])
    clean_transform = transforms.Compose([
        transforms.Resize(res),
        transforms.CenterCrop(center_crop_size),
        transforms.ToTensor(),
    ])

    transform = MultiDataTransformList(transform, clean_transform, sample_num)

    return transform, transform