In der Masterarbeit „Anomalie-Detektion in Zellbildern zur Anwendung der Leukämieerkennung“ verwendete CSI-Methode.
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

datasets.py 13KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361
  1. import os
  2. import numpy as np
  3. import torch
  4. from torch.utils.data.dataset import Subset
  5. from torchvision import datasets, transforms
  6. from utils.utils import set_random_seed
  7. DATA_PATH = '~/data/'
  8. IMAGENET_PATH = '~/data/ImageNet'
  9. CNMC_PATH = r'~/data/CSI/CNMC_orig'
  10. CNMC_GRAY_PATH = r'~/data/CSI/CNMC_orig_gray'
  11. CNMC_ROT4_PATH = r'~/data/CSI/CNMC_rotated_4'
  12. CIFAR10_SUPERCLASS = list(range(10)) # one class
  13. IMAGENET_SUPERCLASS = list(range(30)) # one class
  14. CNMC_SUPERCLASS = list(range(2)) # one class
  15. STD_RES = 450
  16. STD_CENTER_CROP = 300
  17. CIFAR100_SUPERCLASS = [
  18. [4, 31, 55, 72, 95],
  19. [1, 33, 67, 73, 91],
  20. [54, 62, 70, 82, 92],
  21. [9, 10, 16, 29, 61],
  22. [0, 51, 53, 57, 83],
  23. [22, 25, 40, 86, 87],
  24. [5, 20, 26, 84, 94],
  25. [6, 7, 14, 18, 24],
  26. [3, 42, 43, 88, 97],
  27. [12, 17, 38, 68, 76],
  28. [23, 34, 49, 60, 71],
  29. [15, 19, 21, 32, 39],
  30. [35, 63, 64, 66, 75],
  31. [27, 45, 77, 79, 99],
  32. [2, 11, 36, 46, 98],
  33. [28, 30, 44, 78, 93],
  34. [37, 50, 65, 74, 80],
  35. [47, 52, 56, 59, 96],
  36. [8, 13, 48, 58, 90],
  37. [41, 69, 81, 85, 89],
  38. ]
  39. class MultiDataTransform(object):
  40. def __init__(self, transform):
  41. self.transform1 = transform
  42. self.transform2 = transform
  43. def __call__(self, sample):
  44. x1 = self.transform1(sample)
  45. x2 = self.transform2(sample)
  46. return x1, x2
  47. class MultiDataTransformList(object):
  48. def __init__(self, transform, clean_trasform, sample_num):
  49. self.transform = transform
  50. self.clean_transform = clean_trasform
  51. self.sample_num = sample_num
  52. def __call__(self, sample):
  53. set_random_seed(0)
  54. sample_list = []
  55. for i in range(self.sample_num):
  56. sample_list.append(self.transform(sample))
  57. return sample_list, self.clean_transform(sample)
  58. def get_transform(image_size=None):
  59. # Note: data augmentation is implemented in the layers
  60. # Hence, we only define the identity transformation here
  61. if image_size: # use pre-specified image size
  62. train_transform = transforms.Compose([
  63. transforms.Resize((image_size[0], image_size[1])),
  64. transforms.RandomHorizontalFlip(),
  65. transforms.ToTensor(),
  66. ])
  67. test_transform = transforms.Compose([
  68. transforms.Resize((image_size[0], image_size[1])),
  69. transforms.ToTensor(),
  70. ])
  71. else: # use default image size
  72. train_transform = transforms.Compose([
  73. transforms.ToTensor(),
  74. ])
  75. test_transform = transforms.ToTensor()
  76. return train_transform, test_transform
  77. def get_subset_with_len(dataset, length, shuffle=False):
  78. set_random_seed(0)
  79. dataset_size = len(dataset)
  80. index = np.arange(dataset_size)
  81. if shuffle:
  82. np.random.shuffle(index)
  83. index = torch.from_numpy(index[0:length])
  84. subset = Subset(dataset, index)
  85. assert len(subset) == length
  86. return subset
  87. def get_transform_imagenet():
  88. train_transform = transforms.Compose([
  89. transforms.Resize(256),
  90. transforms.RandomResizedCrop(224),
  91. transforms.RandomHorizontalFlip(),
  92. transforms.ToTensor(),
  93. ])
  94. test_transform = transforms.Compose([
  95. transforms.Resize(256),
  96. transforms.CenterCrop(224),
  97. transforms.ToTensor(),
  98. ])
  99. train_transform = MultiDataTransform(train_transform)
  100. return train_transform, test_transform
  101. def get_transform_cnmc(res, center_crop_size):
  102. train_transform = transforms.Compose([
  103. transforms.Resize(res),
  104. transforms.CenterCrop(center_crop_size),
  105. transforms.RandomVerticalFlip(),
  106. transforms.RandomHorizontalFlip(),
  107. transforms.ToTensor(),
  108. ])
  109. test_transform = transforms.Compose([
  110. transforms.Resize(res),
  111. transforms.CenterCrop(center_crop_size),
  112. transforms.ToTensor(),
  113. ])
  114. train_transform = MultiDataTransform(train_transform)
  115. return train_transform, test_transform
  116. def get_dataset(P, dataset, test_only=False, image_size=None, download=False, eval=False):
  117. if P.res != '':
  118. res = int(P.res.replace('px', ''))
  119. size_factor = int(STD_RES/res) # always remove same portion
  120. center_crop_size = int(STD_CENTER_CROP/size_factor) # remove black border
  121. if dataset in ['CNMC', 'CNMC_grayscale', 'CNMC_ROT4_PATH']:
  122. if eval:
  123. train_transform, test_transform = get_simclr_eval_transform_cnmc(P.ood_samples,
  124. P.resize_factor, P.resize_fix, res, center_crop_size)
  125. else:
  126. train_transform, test_transform = get_transform_cnmc(res, center_crop_size)
  127. elif dataset in ['imagenet', 'cub', 'stanford_dogs', 'flowers102',
  128. 'places365', 'food_101', 'caltech_256', 'dtd', 'pets']:
  129. if eval:
  130. train_transform, test_transform = get_simclr_eval_transform_imagenet(P.ood_samples,
  131. P.resize_factor, P.resize_fix)
  132. else:
  133. train_transform, test_transform = get_transform_imagenet()
  134. else:
  135. train_transform, test_transform = get_transform(image_size=image_size)
  136. if dataset == 'CNMC':
  137. image_size = (center_crop_size, center_crop_size, 3) #original 450,450,3
  138. n_classes = 2
  139. train_dir = os.path.join(CNMC_PATH, '0_training')
  140. test_dir = os.path.join(CNMC_PATH, '1_validation')
  141. train_set = datasets.ImageFolder(train_dir, transform=train_transform)
  142. test_set = datasets.ImageFolder(test_dir, transform=test_transform)
  143. elif dataset == 'CNMC_grayscale':
  144. image_size = (center_crop_size, center_crop_size, 3) #original 450,450,3
  145. n_classes = 2
  146. train_dir = os.path.join(CNMC_GRAY_PATH, '0_training')
  147. test_dir = os.path.join(CNMC_GRAY_PATH, '1_validation')
  148. train_set = datasets.ImageFolder(train_dir, transform=train_transform)
  149. test_set = datasets.ImageFolder(test_dir, transform=test_transform)
  150. elif dataset == 'cifar10':
  151. image_size = (32, 32, 3)
  152. n_classes = 10
  153. train_set = datasets.CIFAR10(DATA_PATH, train=True, download=download, transform=train_transform)
  154. test_set = datasets.CIFAR10(DATA_PATH, train=False, download=download, transform=test_transform)
  155. elif dataset == 'cifar100':
  156. image_size = (32, 32, 3)
  157. n_classes = 100
  158. train_set = datasets.CIFAR100(DATA_PATH, train=True, download=download, transform=train_transform)
  159. test_set = datasets.CIFAR100(DATA_PATH, train=False, download=download, transform=test_transform)
  160. elif dataset == 'svhn':
  161. assert test_only and image_size is not None
  162. test_set = datasets.SVHN(DATA_PATH, split='test', download=download, transform=test_transform)
  163. elif dataset == 'lsun_resize':
  164. assert test_only and image_size is not None
  165. test_dir = os.path.join(DATA_PATH, 'LSUN_resize')
  166. test_set = datasets.ImageFolder(test_dir, transform=test_transform)
  167. elif dataset == 'lsun_fix':
  168. assert test_only and image_size is not None
  169. test_dir = os.path.join(DATA_PATH, 'LSUN_fix')
  170. test_set = datasets.ImageFolder(test_dir, transform=test_transform)
  171. elif dataset == 'imagenet_resize':
  172. assert test_only and image_size is not None
  173. test_dir = os.path.join(DATA_PATH, 'Imagenet_resize')
  174. test_set = datasets.ImageFolder(test_dir, transform=test_transform)
  175. elif dataset == 'imagenet_fix':
  176. assert test_only and image_size is not None
  177. test_dir = os.path.join(DATA_PATH, 'Imagenet_fix')
  178. test_set = datasets.ImageFolder(test_dir, transform=test_transform)
  179. elif dataset == 'imagenet':
  180. image_size = (224, 224, 3)
  181. n_classes = 30
  182. train_dir = os.path.join(IMAGENET_PATH, 'one_class_train')
  183. test_dir = os.path.join(IMAGENET_PATH, 'one_class_test')
  184. train_set = datasets.ImageFolder(train_dir, transform=train_transform)
  185. test_set = datasets.ImageFolder(test_dir, transform=test_transform)
  186. elif dataset == 'stanford_dogs':
  187. assert test_only and image_size is not None
  188. test_dir = os.path.join(DATA_PATH, 'stanford_dogs')
  189. test_set = datasets.ImageFolder(test_dir, transform=test_transform)
  190. test_set = get_subset_with_len(test_set, length=3000, shuffle=True)
  191. elif dataset == 'cub':
  192. assert test_only and image_size is not None
  193. test_dir = os.path.join(DATA_PATH, 'cub200')
  194. test_set = datasets.ImageFolder(test_dir, transform=test_transform)
  195. test_set = get_subset_with_len(test_set, length=3000, shuffle=True)
  196. elif dataset == 'flowers102':
  197. assert test_only and image_size is not None
  198. test_dir = os.path.join(DATA_PATH, 'flowers102')
  199. test_set = datasets.ImageFolder(test_dir, transform=test_transform)
  200. test_set = get_subset_with_len(test_set, length=3000, shuffle=True)
  201. elif dataset == 'places365':
  202. assert test_only and image_size is not None
  203. test_dir = os.path.join(DATA_PATH, 'places365')
  204. test_set = datasets.ImageFolder(test_dir, transform=test_transform)
  205. test_set = get_subset_with_len(test_set, length=3000, shuffle=True)
  206. elif dataset == 'food_101':
  207. assert test_only and image_size is not None
  208. test_dir = os.path.join(DATA_PATH, 'food-101', 'images')
  209. test_set = datasets.ImageFolder(test_dir, transform=test_transform)
  210. test_set = get_subset_with_len(test_set, length=3000, shuffle=True)
  211. elif dataset == 'caltech_256':
  212. assert test_only and image_size is not None
  213. test_dir = os.path.join(DATA_PATH, 'caltech-256')
  214. test_set = datasets.ImageFolder(test_dir, transform=test_transform)
  215. test_set = get_subset_with_len(test_set, length=3000, shuffle=True)
  216. elif dataset == 'dtd':
  217. assert test_only and image_size is not None
  218. test_dir = os.path.join(DATA_PATH, 'dtd', 'images')
  219. test_set = datasets.ImageFolder(test_dir, transform=test_transform)
  220. test_set = get_subset_with_len(test_set, length=3000, shuffle=True)
  221. elif dataset == 'pets':
  222. assert test_only and image_size is not None
  223. test_dir = os.path.join(DATA_PATH, 'pets')
  224. test_set = datasets.ImageFolder(test_dir, transform=test_transform)
  225. test_set = get_subset_with_len(test_set, length=3000, shuffle=True)
  226. else:
  227. raise NotImplementedError()
  228. if test_only:
  229. return test_set
  230. else:
  231. return train_set, test_set, image_size, n_classes
  232. def get_superclass_list(dataset):
  233. if dataset == 'CNMC':
  234. return CNMC_SUPERCLASS
  235. if dataset == 'CNMC_grayscale':
  236. return CNMC_SUPERCLASS
  237. elif dataset == 'cifar10':
  238. return CIFAR10_SUPERCLASS
  239. elif dataset == 'cifar100':
  240. return CIFAR100_SUPERCLASS
  241. elif dataset == 'imagenet':
  242. return IMAGENET_SUPERCLASS
  243. else:
  244. raise NotImplementedError()
  245. def get_subclass_dataset(dataset, classes):
  246. if not isinstance(classes, list):
  247. classes = [classes]
  248. indices = []
  249. for idx, tgt in enumerate(dataset.targets):
  250. if tgt in classes:
  251. indices.append(idx)
  252. dataset = Subset(dataset, indices)
  253. return dataset
  254. def get_simclr_eval_transform_imagenet(sample_num, resize_factor, resize_fix):
  255. resize_scale = (resize_factor, 1.0) # resize scaling factor
  256. if resize_fix: # if resize_fix is True, use same scale
  257. resize_scale = (resize_factor, resize_factor)
  258. transform = transforms.Compose([
  259. transforms.Resize(256),
  260. transforms.RandomResizedCrop(224, scale=resize_scale),
  261. transforms.RandomHorizontalFlip(),
  262. transforms.ToTensor(),
  263. ])
  264. clean_trasform = transforms.Compose([
  265. transforms.Resize(256),
  266. transforms.CenterCrop(224),
  267. transforms.ToTensor(),
  268. ])
  269. transform = MultiDataTransformList(transform, clean_trasform, sample_num)
  270. return transform, transform
  271. def get_simclr_eval_transform_cnmc(sample_num, resize_factor, resize_fix, res, center_crop_size):
  272. resize_scale = (resize_factor, 1.0) # resize scaling factor
  273. if resize_fix: # if resize_fix is True, use same scale
  274. resize_scale = (resize_factor, resize_factor)
  275. transform = transforms.Compose([
  276. transforms.Resize(res),
  277. transforms.CenterCrop(center_crop_size),
  278. transforms.RandomVerticalFlip(),
  279. transforms.RandomHorizontalFlip(),
  280. transforms.ToTensor(),
  281. ])
  282. clean_trasform = transforms.Compose([
  283. transforms.Resize(res),
  284. transforms.CenterCrop(center_crop_size),
  285. transforms.ToTensor(),
  286. ])
  287. transform = MultiDataTransformList(transform, clean_trasform, sample_num)
  288. return transform, transform