In der Masterarbeit „Anomalie-Detektion in Zellbildern zur Anwendung der Leukämieerkennung“ verwendete CSI-Methode.
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

datasets.py 13KB


  1. import os
  2. import numpy as np
  3. import torch
  4. from torch.utils.data.dataset import Subset
  5. from torchvision import datasets, transforms
  6. from utils.utils import set_random_seed
  7. DATA_PATH = '~/data/'
  8. IMAGENET_PATH = '~/data/ImageNet'
  9. CNMC_PATH = r'~/data/CSI/CNMC_orig'
  10. CNMC_GRAY_PATH = r'~/data/CSI/CNMC_orig_gray'
  11. CNMC_ROT4_PATH = r'~/data/CSI/CNMC_rotated_4'
  12. CIFAR10_SUPERCLASS = list(range(10)) # one class
  13. IMAGENET_SUPERCLASS = list(range(30)) # one class
  14. CNMC_SUPERCLASS = list(range(2)) # one class
  15. STD_RES = 450
  16. STD_CENTER_CROP = 300
  17. CIFAR100_SUPERCLASS = [
  18. [4, 31, 55, 72, 95],
  19. [1, 33, 67, 73, 91],
  20. [54, 62, 70, 82, 92],
  21. [9, 10, 16, 29, 61],
  22. [0, 51, 53, 57, 83],
  23. [22, 25, 40, 86, 87],
  24. [5, 20, 26, 84, 94],
  25. [6, 7, 14, 18, 24],
  26. [3, 42, 43, 88, 97],
  27. [12, 17, 38, 68, 76],
  28. [23, 34, 49, 60, 71],
  29. [15, 19, 21, 32, 39],
  30. [35, 63, 64, 66, 75],
  31. [27, 45, 77, 79, 99],
  32. [2, 11, 36, 46, 98],
  33. [28, 30, 44, 78, 93],
  34. [37, 50, 65, 74, 80],
  35. [47, 52, 56, 59, 96],
  36. [8, 13, 48, 58, 90],
  37. [41, 69, 81, 85, 89],
  38. ]
  39. class MultiDataTransform(object):
  40. def __init__(self, transform):
  41. self.transform1 = transform
  42. self.transform2 = transform
  43. def __call__(self, sample):
  44. x1 = self.transform1(sample)
  45. x2 = self.transform2(sample)
  46. return x1, x2
  47. class MultiDataTransformList(object):
  48. def __init__(self, transform, clean_trasform, sample_num):
  49. self.transform = transform
  50. self.clean_transform = clean_trasform
  51. self.sample_num = sample_num
  52. def __call__(self, sample):
  53. set_random_seed(0)
  54. sample_list = []
  55. for i in range(self.sample_num):
  56. sample_list.append(self.transform(sample))
  57. return sample_list, self.clean_transform(sample)
  58. def get_transform(image_size=None):
  59. # Note: data augmentation is implemented in the layers
  60. # Hence, we only define the identity transformation here
  61. if image_size: # use pre-specified image size
  62. train_transform = transforms.Compose([
  63. transforms.Resize((image_size[0], image_size[1])),
  64. transforms.RandomHorizontalFlip(),
  65. transforms.ToTensor(),
  66. ])
  67. test_transform = transforms.Compose([
  68. transforms.Resize((image_size[0], image_size[1])),
  69. transforms.ToTensor(),
  70. ])
  71. else: # use default image size
  72. train_transform = transforms.Compose([
  73. transforms.ToTensor(),
  74. ])
  75. test_transform = transforms.ToTensor()
  76. return train_transform, test_transform
  77. def get_subset_with_len(dataset, length, shuffle=False):
  78. set_random_seed(0)
  79. dataset_size = len(dataset)
  80. index = np.arange(dataset_size)
  81. if shuffle:
  82. np.random.shuffle(index)
  83. index = torch.from_numpy(index[0:length])
  84. subset = Subset(dataset, index)
  85. assert len(subset) == length
  86. return subset
  87. def get_transform_imagenet():
  88. train_transform = transforms.Compose([
  89. transforms.Resize(256),
  90. transforms.RandomResizedCrop(224),
  91. transforms.RandomHorizontalFlip(),
  92. transforms.ToTensor(),
  93. ])
  94. test_transform = transforms.Compose([
  95. transforms.Resize(256),
  96. transforms.CenterCrop(224),
  97. transforms.ToTensor(),
  98. ])
  99. train_transform = MultiDataTransform(train_transform)
  100. return train_transform, test_transform
  101. def get_transform_cnmc(res, center_crop_size):
  102. train_transform = transforms.Compose([
  103. transforms.Resize(res),
  104. transforms.CenterCrop(center_crop_size),
  105. transforms.RandomVerticalFlip(),
  106. transforms.RandomHorizontalFlip(),
  107. transforms.ToTensor(),
  108. ])
  109. test_transform = transforms.Compose([
  110. transforms.Resize(res),
  111. transforms.CenterCrop(center_crop_size),
  112. transforms.ToTensor(),
  113. ])
  114. train_transform = MultiDataTransform(train_transform)
  115. return train_transform, test_transform
  116. def get_dataset(P, dataset, test_only=False, image_size=None, download=False, eval=False):
  117. if P.res != '':
  118. res = int(P.res.replace('px', ''))
  119. size_factor = int(STD_RES/res) # always remove same portion
  120. center_crop_size = int(STD_CENTER_CROP/size_factor) # remove black border
  121. if dataset in ['CNMC', 'CNMC_grayscale', 'CNMC_ROT4_PATH']:
  122. if eval:
  123. train_transform, test_transform = get_simclr_eval_transform_cnmc(P.ood_samples,
  124. P.resize_factor, P.resize_fix, res, center_crop_size)
  125. else:
  126. train_transform, test_transform = get_transform_cnmc(res, center_crop_size)
  127. elif dataset in ['imagenet', 'cub', 'stanford_dogs', 'flowers102',
  128. 'places365', 'food_101', 'caltech_256', 'dtd', 'pets']:
  129. if eval:
  130. train_transform, test_transform = get_simclr_eval_transform_imagenet(P.ood_samples,
  131. P.resize_factor, P.resize_fix)
  132. else:
  133. train_transform, test_transform = get_transform_imagenet()
  134. else:
  135. train_transform, test_transform = get_transform(image_size=image_size)
  136. if dataset == 'CNMC':
  137. image_size = (center_crop_size, center_crop_size, 3) #original 450,450,3
  138. n_classes = 2
  139. train_dir = os.path.join(CNMC_PATH, '0_training')
  140. test_dir = os.path.join(CNMC_PATH, '1_validation')
  141. train_set = datasets.ImageFolder(train_dir, transform=train_transform)
  142. test_set = datasets.ImageFolder(test_dir, transform=test_transform)
  143. elif dataset == 'CNMC_grayscale':
  144. image_size = (center_crop_size, center_crop_size, 3) #original 450,450,3
  145. n_classes = 2
  146. train_dir = os.path.join(CNMC_GRAY_PATH, '0_training')
  147. test_dir = os.path.join(CNMC_GRAY_PATH, '1_validation')
  148. train_set = datasets.ImageFolder(train_dir, transform=train_transform)
  149. test_set = datasets.ImageFolder(test_dir, transform=test_transform)
  150. elif dataset == 'cifar10':
  151. image_size = (32, 32, 3)
  152. n_classes = 10
  153. train_set = datasets.CIFAR10(DATA_PATH, train=True, download=download, transform=train_transform)
  154. test_set = datasets.CIFAR10(DATA_PATH, train=False, download=download, transform=test_transform)
  155. elif dataset == 'cifar100':
  156. image_size = (32, 32, 3)
  157. n_classes = 100
  158. train_set = datasets.CIFAR100(DATA_PATH, train=True, download=download, transform=train_transform)
  159. test_set = datasets.CIFAR100(DATA_PATH, train=False, download=download, transform=test_transform)
  160. elif dataset == 'svhn':
  161. assert test_only and image_size is not None
  162. test_set = datasets.SVHN(DATA_PATH, split='test', download=download, transform=test_transform)
  163. elif dataset == 'lsun_resize':
  164. assert test_only and image_size is not None
  165. test_dir = os.path.join(DATA_PATH, 'LSUN_resize')
  166. test_set = datasets.ImageFolder(test_dir, transform=test_transform)
  167. elif dataset == 'lsun_fix':
  168. assert test_only and image_size is not None
  169. test_dir = os.path.join(DATA_PATH, 'LSUN_fix')
  170. test_set = datasets.ImageFolder(test_dir, transform=test_transform)
  171. elif dataset == 'imagenet_resize':
  172. assert test_only and image_size is not None
  173. test_dir = os.path.join(DATA_PATH, 'Imagenet_resize')
  174. test_set = datasets.ImageFolder(test_dir, transform=test_transform)
  175. elif dataset == 'imagenet_fix':
  176. assert test_only and image_size is not None
  177. test_dir = os.path.join(DATA_PATH, 'Imagenet_fix')
  178. test_set = datasets.ImageFolder(test_dir, transform=test_transform)
  179. elif dataset == 'imagenet':
  180. image_size = (224, 224, 3)
  181. n_classes = 30
  182. train_dir = os.path.join(IMAGENET_PATH, 'one_class_train')
  183. test_dir = os.path.join(IMAGENET_PATH, 'one_class_test')
  184. train_set = datasets.ImageFolder(train_dir, transform=train_transform)
  185. test_set = datasets.ImageFolder(test_dir, transform=test_transform)
  186. elif dataset == 'stanford_dogs':
  187. assert test_only and image_size is not None
  188. test_dir = os.path.join(DATA_PATH, 'stanford_dogs')
  189. test_set = datasets.ImageFolder(test_dir, transform=test_transform)
  190. test_set = get_subset_with_len(test_set, length=3000, shuffle=True)
  191. elif dataset == 'cub':
  192. assert test_only and image_size is not None
  193. test_dir = os.path.join(DATA_PATH, 'cub200')
  194. test_set = datasets.ImageFolder(test_dir, transform=test_transform)
  195. test_set = get_subset_with_len(test_set, length=3000, shuffle=True)
  196. elif dataset == 'flowers102':
  197. assert test_only and image_size is not None
  198. test_dir = os.path.join(DATA_PATH, 'flowers102')
  199. test_set = datasets.ImageFolder(test_dir, transform=test_transform)
  200. test_set = get_subset_with_len(test_set, length=3000, shuffle=True)
  201. elif dataset == 'places365':
  202. assert test_only and image_size is not None
  203. test_dir = os.path.join(DATA_PATH, 'places365')
  204. test_set = datasets.ImageFolder(test_dir, transform=test_transform)
  205. test_set = get_subset_with_len(test_set, length=3000, shuffle=True)
  206. elif dataset == 'food_101':
  207. assert test_only and image_size is not None
  208. test_dir = os.path.join(DATA_PATH, 'food-101', 'images')
  209. test_set = datasets.ImageFolder(test_dir, transform=test_transform)
  210. test_set = get_subset_with_len(test_set, length=3000, shuffle=True)
  211. elif dataset == 'caltech_256':
  212. assert test_only and image_size is not None
  213. test_dir = os.path.join(DATA_PATH, 'caltech-256')
  214. test_set = datasets.ImageFolder(test_dir, transform=test_transform)
  215. test_set = get_subset_with_len(test_set, length=3000, shuffle=True)
  216. elif dataset == 'dtd':
  217. assert test_only and image_size is not None
  218. test_dir = os.path.join(DATA_PATH, 'dtd', 'images')
  219. test_set = datasets.ImageFolder(test_dir, transform=test_transform)
  220. test_set = get_subset_with_len(test_set, length=3000, shuffle=True)
  221. elif dataset == 'pets':
  222. assert test_only and image_size is not None
  223. test_dir = os.path.join(DATA_PATH, 'pets')
  224. test_set = datasets.ImageFolder(test_dir, transform=test_transform)
  225. test_set = get_subset_with_len(test_set, length=3000, shuffle=True)
  226. else:
  227. raise NotImplementedError()
  228. if test_only:
  229. return test_set
  230. else:
  231. return train_set, test_set, image_size, n_classes
  232. def get_superclass_list(dataset):
  233. if dataset == 'CNMC':
  234. return CNMC_SUPERCLASS
  235. if dataset == 'CNMC_grayscale':
  236. return CNMC_SUPERCLASS
  237. elif dataset == 'cifar10':
  238. return CIFAR10_SUPERCLASS
  239. elif dataset == 'cifar100':
  240. return CIFAR100_SUPERCLASS
  241. elif dataset == 'imagenet':
  242. return IMAGENET_SUPERCLASS
  243. else:
  244. raise NotImplementedError()
  245. def get_subclass_dataset(dataset, classes):
  246. if not isinstance(classes, list):
  247. classes = [classes]
  248. indices = []
  249. for idx, tgt in enumerate(dataset.targets):
  250. if tgt in classes:
  251. indices.append(idx)
  252. dataset = Subset(dataset, indices)
  253. return dataset
  254. def get_simclr_eval_transform_imagenet(sample_num, resize_factor, resize_fix):
  255. resize_scale = (resize_factor, 1.0) # resize scaling factor
  256. if resize_fix: # if resize_fix is True, use same scale
  257. resize_scale = (resize_factor, resize_factor)
  258. transform = transforms.Compose([
  259. transforms.Resize(256),
  260. transforms.RandomResizedCrop(224, scale=resize_scale),
  261. transforms.RandomHorizontalFlip(),
  262. transforms.ToTensor(),
  263. ])
  264. clean_trasform = transforms.Compose([
  265. transforms.Resize(256),
  266. transforms.CenterCrop(224),
  267. transforms.ToTensor(),
  268. ])
  269. transform = MultiDataTransformList(transform, clean_trasform, sample_num)
  270. return transform, transform
  271. def get_simclr_eval_transform_cnmc(sample_num, resize_factor, resize_fix, res, center_crop_size):
  272. resize_scale = (resize_factor, 1.0) # resize scaling factor
  273. if resize_fix: # if resize_fix is True, use same scale
  274. resize_scale = (resize_factor, resize_factor)
  275. transform = transforms.Compose([
  276. transforms.Resize(res),
  277. transforms.CenterCrop(center_crop_size),
  278. transforms.RandomVerticalFlip(),
  279. transforms.RandomHorizontalFlip(),
  280. transforms.ToTensor(),
  281. ])
  282. clean_trasform = transforms.Compose([
  283. transforms.Resize(res),
  284. transforms.CenterCrop(center_crop_size),
  285. transforms.ToTensor(),
  286. ])
  287. transform = MultiDataTransformList(transform, clean_trasform, sample_num)
  288. return transform, transform