In Masterarbeit: "Anomalie-Detektion in Zellbildern zur Anwendung der Leukämieerkennung" verwendete Methode des 3. Platzes der ISBI 2019.
You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

dataset.py 6.0KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184
  1. import re
  2. from collections import defaultdict
  3. from glob import glob
  4. from os.path import join
  5. import pandas as pd
  6. import torch
  7. import torchvision.transforms.functional as TF
  8. from PIL import Image
  9. from torch.utils.data import Dataset
  10. from torchvision import transforms
# Edge length (pixels) of the challenge images; the border check in the
# __main__ sanity script below also uses 450 as the full frame size.
STD_RES = 450
# Reference center-crop size at the native 450 px resolution (the crop is
# currently disabled in the transform builders below).
STD_CENTER_CROP = 300
  13. def file_iter(dataroot):
  14. for file in glob(join(dataroot, '*', '*', '*')):
  15. yield file
  16. def file_match_iter(dataroot):
  17. pattern = re.compile(r'(?P<file>.*(?P<fold>[a-zA-Z0-9_]+)/'
  18. r'(?P<class>hem|all)/'
  19. r'UID_(?P<subject>H?\d+)_(?P<image>\d+)_(?P<cell>\d+)_(all|hem).bmp)')
  20. for file in file_iter(dataroot):
  21. match = pattern.match(file)
  22. if match is not None:
  23. yield file, match
  24. def to_dataframe(dataroot):
  25. data = defaultdict(list)
  26. keys = ['file', 'fold', 'subject', 'class', 'image', 'cell']
  27. # Load data from the three training folds
  28. for file, match in file_match_iter(dataroot):
  29. for key in keys:
  30. data[key].append(match.group(key))
  31. # Load data from the phase2 validation set
  32. phase2 = pd.read_csv(join(dataroot, 'phase2.csv'), header=0, names=['file_id', 'file', 'class'])
  33. pattern = re.compile(r'UID_(?P<subject>H?\d+)_(?P<image>\d+)_(?P<cell>\d+)_(all|hem).bmp')
  34. for i, row in phase2.iterrows():
  35. match = pattern.match(row['file_id'])
  36. data['file'].append(join(dataroot, f'phase2/{i+1}.bmp'))
  37. data['fold'].append('3')
  38. data['subject'].append(match.group('subject'))
  39. data['class'].append('hem' if row['class'] == 0 else 'all')
  40. data['image'].append(match.group('image'))
  41. data['cell'].append(match.group('cell'))
  42. # Convert to dataframe
  43. df = pd.DataFrame(data)
  44. df = df.apply(pd.to_numeric, errors='ignore')
  45. return df
  46. class ISBI2019(Dataset):
  47. def __init__(self, df, transform=None):
  48. super().__init__()
  49. self.transform = transform
  50. self.df = df
  51. def __len__(self):
  52. return len(self.df)
  53. def __getitem__(self, index):
  54. # Convert tensors to int because pandas screws up otherwise
  55. index = int(index)
  56. file, cls = self.df.iloc[index][['file', 'class']]
  57. img = Image.open(file)#.convert('RGB')
  58. cls = 0 if cls == 'hem' else 1
  59. if self.transform is not None:
  60. img = self.transform(img)
  61. return img, cls
  62. def get_class_weights(df):
  63. class_weights = torch.FloatTensor([
  64. df.loc[df['class'] == 'hem']['file'].count() / len(df),
  65. df.loc[df['class'] == 'all']['file'].count() / len(df),
  66. ]).to(dtype=torch.float32)
  67. return class_weights
  68. def tf_rotation_stack(x, num_rotations=8):
  69. xs = []
  70. for i in range(num_rotations):
  71. angle = 360 * i / num_rotations
  72. xrot = TF.rotate(x, angle)
  73. xrot = TF.to_tensor(xrot)
  74. xs.append(xrot)
  75. xs = torch.stack(xs)
  76. return xs
  77. def get_tf_train_transform(res):
  78. size_factor = int(STD_RES/res)
  79. center_crop = int(STD_CENTER_CROP/size_factor)
  80. tf_train = transforms.Compose([
  81. transforms.Resize(res),
  82. #transforms.CenterCrop(center_crop),
  83. transforms.RandomVerticalFlip(),
  84. transforms.RandomHorizontalFlip(),
  85. transforms.RandomAffine(degrees=360, translate=(0.2, 0.2)),
  86. # transforms.Lambda(tf_rotation_stack),
  87. transforms.ToTensor(),
  88. ])
  89. return tf_train
  90. def get_tf_vaild_rot_transform(res):
  91. size_factor = int(STD_RES/res)
  92. center_crop = int(STD_CENTER_CROP/size_factor)
  93. tf_valid_rot = transforms.Compose([
  94. transforms.Resize(res),
  95. #transforms.CenterCrop(center_crop),
  96. transforms.Lambda(tf_rotation_stack),
  97. ])
  98. return tf_valid_rot
  99. def get_tf_valid_norot_transform(res):
  100. size_factor = int(STD_RES/res)
  101. center_crop = int(STD_CENTER_CROP/size_factor)
  102. tf_valid_norot = transforms.Compose([
  103. transforms.Resize(res),
  104. #transforms.CenterCrop(center_crop),
  105. transforms.ToTensor(),
  106. ])
  107. return tf_valid_norot
  108. def get_dataset(dataroot, folds_train=(0, 1, 2), folds_valid=(3,), tf_train=None, tf_valid=None):
  109. if tf_train is None or tf_valid is None:
  110. sys.exit("Tranformation is None")
  111. df = to_dataframe(dataroot)
  112. df_trainset = df.loc[df['fold'].isin(folds_train)]
  113. trainset = ISBI2019(df_trainset, transform=tf_train)
  114. class_weights = get_class_weights(df_trainset)
  115. if folds_valid is not None:
  116. df_validset = df.loc[df['fold'].isin(folds_valid)]
  117. validset_subjects = df_validset['subject'].values
  118. validset = ISBI2019(df_validset, transform=tf_valid)
  119. return trainset, validset, validset_subjects, class_weights
  120. else:
  121. return trainset, class_weights
if __name__ == '__main__':
    import math
    from tqdm import tqdm

    # Sanity-check script: print per-fold/class example counts and statistics
    # of each image's content bounding box (via PIL's getbbox).
    df = to_dataframe('data')
    print(df)
    print("Examples by fold and class")
    print(df.groupby(['fold', 'class'])['file'].count())
    dataset = ISBI2019(df)  # no transform: items are raw PIL images
    mean_height, mean_width = 0, 0
    weird_files = []  # images whose content touches the image border
    bound_left, bound_upper, bound_right, bound_lower = math.inf, math.inf, 0, 0
    for i, (img, label) in tqdm(enumerate(dataset), total=len(dataset)):
        left, upper, right, lower = img.getbbox()
        # 450 is the full frame size (STD_RES); content reaching any border
        # suggests the cell may be clipped.
        if left == 0 or upper == 0 or right == 450 or lower == 450:
            weird_files.append(df.iloc[i]['file'])
        height = lower - upper
        width = right - left
        # Streaming (running) mean update — avoids keeping all sizes in memory.
        mean_height = mean_height + (height - mean_height) / (i + 1)
        mean_width = mean_width + (width - mean_width) / (i + 1)
        # Track the tightest frame enclosing all content boxes.
        bound_left = min(bound_left, left)
        bound_upper = min(bound_upper, upper)
        bound_right = max(bound_right, right)
        bound_lower = max(bound_lower, lower)
    print(f"mean_height = {mean_height:.2f}")
    print(f"mean_width = {mean_width:.2f}")
    print(f"bound_left = {bound_left:d}")
    print(f"bound_upper = {bound_upper:d}")
    print(f"bound_right = {bound_right:d}")
    print(f"bound_lower = {bound_lower:d}")
    print("Files that max out at least one border:")
    for f in weird_files:
        print(f)