"""Data loading utilities for the ISBI 2019 cell classification dataset ('hem' vs 'all')."""
import re
import sys
from collections import defaultdict
from glob import glob
from os.path import join

import pandas as pd
import torch
import torchvision.transforms.functional as TF
from PIL import Image
from torch.utils.data import Dataset
from torchvision import transforms

STD_RES = 450
STD_CENTER_CROP = 300


def file_iter(dataroot):
    # Yield every file found three directory levels below dataroot
    for file in glob(join(dataroot, '*', '*', '*')):
        yield file


def file_match_iter(dataroot):
    # Match paths of the form <fold>/<class>/UID_<subject>_<image>_<cell>_<class>.bmp
    pattern = re.compile(r'(?P<file>.*(?P<fold>[a-zA-Z0-9_]+)/'
                         r'(?P<class>hem|all)/'
                         r'UID_(?P<subject>H?\d+)_(?P<image>\d+)_(?P<cell>\d+)_(all|hem).bmp)')
    for file in file_iter(dataroot):
        match = pattern.match(file)
        if match is not None:
            yield file, match


def to_dataframe(dataroot):
    data = defaultdict(list)
    keys = ['file', 'fold', 'subject', 'class', 'image', 'cell']

    # Load data from the three training folds
    for file, match in file_match_iter(dataroot):
        for key in keys:
            data[key].append(match.group(key))

    # Load data from the phase2 validation set (treated as fold 3)
    phase2 = pd.read_csv(join(dataroot, 'phase2.csv'), header=0,
                         names=['file_id', 'file', 'class'])
    pattern = re.compile(r'UID_(?P<subject>H?\d+)_(?P<image>\d+)_(?P<cell>\d+)_(all|hem).bmp')
    for i, row in phase2.iterrows():
        match = pattern.match(row['file_id'])
        data['file'].append(join(dataroot, f'phase2/{i+1}.bmp'))
        data['fold'].append('3')
        data['subject'].append(match.group('subject'))
        data['class'].append('hem' if row['class'] == 0 else 'all')
        data['image'].append(match.group('image'))
        data['cell'].append(match.group('cell'))

    # Convert to dataframe; numeric-looking columns (fold, subject, image, cell) become numbers
    df = pd.DataFrame(data)
    df = df.apply(pd.to_numeric, errors='ignore')
    return df


class ISBI2019(Dataset):
    def __init__(self, df, transform=None):
        super().__init__()
        self.transform = transform
        self.df = df

    def __len__(self):
        return len(self.df)

    def __getitem__(self, index):
        # Convert tensors to int because pandas screws up otherwise
        index = int(index)
        file, cls = self.df.iloc[index][['file', 'class']]
        img = Image.open(file)  # .convert('RGB')
        cls = 0 if cls == 'hem' else 1
        if self.transform is not None:
            img = self.transform(img)
        return img, cls


def get_class_weights(df):
    # Fraction of examples per class (index 0 = 'hem', index 1 = 'all')
    class_weights = torch.FloatTensor([
        df.loc[df['class'] == 'hem']['file'].count() / len(df),
        df.loc[df['class'] == 'all']['file'].count() / len(df),
    ]).to(dtype=torch.float32)
    return class_weights


def tf_rotation_stack(x, num_rotations=8):
    # Stack of evenly spaced rotations of the input image, shape (num_rotations, C, H, W)
    xs = []
    for i in range(num_rotations):
        angle = 360 * i / num_rotations
        xrot = TF.rotate(x, angle)
        xrot = TF.to_tensor(xrot)
        xs.append(xrot)
    xs = torch.stack(xs)
    return xs


def get_tf_train_transform(res):
    size_factor = int(STD_RES / res)
    center_crop = int(STD_CENTER_CROP / size_factor)
    tf_train = transforms.Compose([
        transforms.Resize(res),
        # transforms.CenterCrop(center_crop),
        transforms.RandomVerticalFlip(),
        transforms.RandomHorizontalFlip(),
        transforms.RandomAffine(degrees=360, translate=(0.2, 0.2)),
        # transforms.Lambda(tf_rotation_stack),
        transforms.ToTensor(),
    ])
    return tf_train


def get_tf_vaild_rot_transform(res):
    size_factor = int(STD_RES / res)
    center_crop = int(STD_CENTER_CROP / size_factor)
    tf_valid_rot = transforms.Compose([
        transforms.Resize(res),
        # transforms.CenterCrop(center_crop),
        transforms.Lambda(tf_rotation_stack),
    ])
    return tf_valid_rot


def get_tf_valid_norot_transform(res):
    size_factor = int(STD_RES / res)
    center_crop = int(STD_CENTER_CROP / size_factor)
    tf_valid_norot = transforms.Compose([
        transforms.Resize(res),
        # transforms.CenterCrop(center_crop),
        transforms.ToTensor(),
    ])
    return tf_valid_norot


def get_dataset(dataroot, folds_train=(0, 1, 2), folds_valid=(3,),
                tf_train=None, tf_valid=None):
    if tf_train is None or tf_valid is None:
        sys.exit("Transformation is None")
    df = to_dataframe(dataroot)
    df_trainset = df.loc[df['fold'].isin(folds_train)]
    trainset = ISBI2019(df_trainset, transform=tf_train)
    class_weights = get_class_weights(df_trainset)
    if folds_valid is not None:
        df_validset = df.loc[df['fold'].isin(folds_valid)]
        validset_subjects = df_validset['subject'].values
        validset = ISBI2019(df_validset, transform=tf_valid)
        return trainset, validset, validset_subjects, class_weights
    else:
        return trainset, class_weights


if __name__ == '__main__':
    import math

    from tqdm import tqdm

    df = to_dataframe('data')
    print(df)
    print("Examples by fold and class")
    print(df.groupby(['fold', 'class'])['file'].count())

    # Scan the raw (untransformed) images and collect bounding-box statistics
    dataset = ISBI2019(df)
    mean_height, mean_width = 0, 0
    weird_files = []
    bound_left, bound_upper, bound_right, bound_lower = math.inf, math.inf, 0, 0
    for i, (img, label) in tqdm(enumerate(dataset), total=len(dataset)):
        left, upper, right, lower = img.getbbox()
        if left == 0 or upper == 0 or right == 450 or lower == 450:
            weird_files.append(df.iloc[i]['file'])
        height = lower - upper
        width = right - left
        mean_height = mean_height + (height - mean_height) / (i + 1)
        mean_width = mean_width + (width - mean_width) / (i + 1)
        bound_left = min(bound_left, left)
        bound_upper = min(bound_upper, upper)
        bound_right = max(bound_right, right)
        bound_lower = max(bound_lower, lower)
    print(f"mean_height = {mean_height:.2f}")
    print(f"mean_width = {mean_width:.2f}")
    print(f"bound_left = {bound_left:d}")
    print(f"bound_upper = {bound_upper:d}")
    print(f"bound_right = {bound_right:d}")
    print(f"bound_lower = {bound_lower:d}")
    print("Files that max out at least one border:")
    for f in weird_files:
        print(f)
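# Example usage (a sketch, not exercised by this module): combining the helpers
# above into train/validation loaders. The dataroot 'data', batch size, and
# worker count are illustrative assumptions, not values fixed by this project.
#
#     from torch.utils.data import DataLoader
#
#     trainset, validset, valid_subjects, class_weights = get_dataset(
#         'data',
#         tf_train=get_tf_train_transform(STD_RES),
#         tf_valid=get_tf_valid_norot_transform(STD_RES))
#     train_loader = DataLoader(trainset, batch_size=16, shuffle=True, num_workers=4)
#     valid_loader = DataLoader(validset, batch_size=16, shuffle=False, num_workers=4)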