import re
import sys
from collections import defaultdict
from glob import glob
from os.path import join

import pandas as pd
import torch
import torchvision.transforms.functional as TF
from PIL import Image
from torch.utils.data import Dataset
from torchvision import transforms

STD_RES = 450
STD_CENTER_CROP = 300

def file_iter(dataroot):
    # Yield every file path found under dataroot/<fold>/<class>/.
    for file in glob(join(dataroot, '*', '*', '*')):
        yield file

def file_match_iter(dataroot):
    # Match paths of the form <fold>/<hem|all>/UID_<subject>_<image>_<cell>_<all|hem>.bmp
    # and yield (path, match) pairs; files that do not match are skipped.
    pattern = re.compile(r'(?P<file>.*(?P<fold>[a-zA-Z0-9_]+)/'
                         r'(?P<class>hem|all)/'
                         r'UID_(?P<subject>H?\d+)_(?P<image>\d+)_(?P<cell>\d+)_(all|hem).bmp)')
    for file in file_iter(dataroot):
        match = pattern.match(file)
        if match is not None:
            yield file, match

def to_dataframe(dataroot):
    data = defaultdict(list)
    keys = ['file', 'fold', 'subject', 'class', 'image', 'cell']
    # Load data from the three training folds
    for file, match in file_match_iter(dataroot):
        for key in keys:
            data[key].append(match.group(key))
    # Load data from the phase2 validation set
    phase2 = pd.read_csv(join(dataroot, 'phase2.csv'), header=0, names=['file_id', 'file', 'class'])
    pattern = re.compile(r'UID_(?P<subject>H?\d+)_(?P<image>\d+)_(?P<cell>\d+)_(all|hem).bmp')
    for i, row in phase2.iterrows():
        match = pattern.match(row['file_id'])
        data['file'].append(join(dataroot, f'phase2/{i+1}.bmp'))
        data['fold'].append('3')
        data['subject'].append(match.group('subject'))
        data['class'].append('hem' if row['class'] == 0 else 'all')
        data['image'].append(match.group('image'))
        data['cell'].append(match.group('cell'))
    # Convert to a dataframe with one row per cell image and the columns
    # [file, fold, subject, class, image, cell]; columns that look numeric are
    # converted by pd.to_numeric, the rest stay as strings.
    df = pd.DataFrame(data)
    df = df.apply(pd.to_numeric, errors='ignore')
    return df

class ISBI2019(Dataset):
    def __init__(self, df, transform=None):
        super().__init__()
        self.transform = transform
        self.df = df

    def __len__(self):
        return len(self.df)

    def __getitem__(self, index):
        # Convert tensors to int because pandas screws up otherwise
        index = int(index)
        file, cls = self.df.iloc[index][['file', 'class']]
        img = Image.open(file)  # .convert('RGB')
        cls = 0 if cls == 'hem' else 1
        if self.transform is not None:
            img = self.transform(img)
        return img, cls
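
# Note: with transform=None, __getitem__ returns a (PIL.Image, label) pair where
# label is 0 for 'hem' and 1 for 'all'; with one of the transforms defined below
# it returns a (tensor, label) pair instead.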

def get_class_weights(df):
    # Fraction of examples per class, ordered [hem, all].
    class_weights = torch.FloatTensor([
        df.loc[df['class'] == 'hem']['file'].count() / len(df),
        df.loc[df['class'] == 'all']['file'].count() / len(df),
    ]).to(dtype=torch.float32)
    return class_weights
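
# Example (sketch, not part of the original training code): the tensor above
# holds raw class frequencies in the order [hem, all]; a caller that wants a
# frequency-balanced loss could, for instance, invert them first:
#   weights = get_class_weights(df)
#   criterion = torch.nn.CrossEntropyLoss(weight=1.0 / weights)
# How the weights are actually consumed is up to the training script.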

def tf_rotation_stack(x, num_rotations=8):
    xs = []
    for i in range(num_rotations):
        angle = 360 * i / num_rotations
        xrot = TF.rotate(x, angle)
        xrot = TF.to_tensor(xrot)
        xs.append(xrot)
    xs = torch.stack(xs)
    return xs
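
# Note: when wrapped in transforms.Lambda, as in get_tf_vaild_rot_transform below,
# this turns one PIL image into a (num_rotations, C, H, W) tensor, so a DataLoader
# batch over such a dataset has shape (batch, num_rotations, C, H, W).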

def get_tf_train_transform(res):
    size_factor = int(STD_RES / res)
    center_crop = int(STD_CENTER_CROP / size_factor)  # only used by the disabled CenterCrop below
    tf_train = transforms.Compose([
        transforms.Resize(res),
        # transforms.CenterCrop(center_crop),
        transforms.RandomVerticalFlip(),
        transforms.RandomHorizontalFlip(),
        transforms.RandomAffine(degrees=360, translate=(0.2, 0.2)),
        # transforms.Lambda(tf_rotation_stack),
        transforms.ToTensor(),
    ])
    return tf_train

def get_tf_vaild_rot_transform(res):
    size_factor = int(STD_RES / res)
    center_crop = int(STD_CENTER_CROP / size_factor)
    tf_valid_rot = transforms.Compose([
        transforms.Resize(res),
        # transforms.CenterCrop(center_crop),
        transforms.Lambda(tf_rotation_stack),
    ])
    return tf_valid_rot

def get_tf_valid_norot_transform(res):
    size_factor = int(STD_RES / res)
    center_crop = int(STD_CENTER_CROP / size_factor)
    tf_valid_norot = transforms.Compose([
        transforms.Resize(res),
        # transforms.CenterCrop(center_crop),
        transforms.ToTensor(),
    ])
    return tf_valid_norot

def get_dataset(dataroot, folds_train=(0, 1, 2), folds_valid=(3,), tf_train=None, tf_valid=None):
    if tf_train is None or tf_valid is None:
        sys.exit("Transformation is None")
    df = to_dataframe(dataroot)
    df_trainset = df.loc[df['fold'].isin(folds_train)]
    trainset = ISBI2019(df_trainset, transform=tf_train)
    class_weights = get_class_weights(df_trainset)
    if folds_valid is not None:
        df_validset = df.loc[df['fold'].isin(folds_valid)]
        validset_subjects = df_validset['subject'].values
        validset = ISBI2019(df_validset, transform=tf_valid)
        return trainset, validset, validset_subjects, class_weights
    else:
        return trainset, class_weights
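
# Example usage (sketch; assumes the ISBI 2019 images sit under `data/` in the
# <fold>/<hem|all>/ layout parsed above, together with data/phase2.csv):
#   tf_train = get_tf_train_transform(STD_RES)
#   tf_valid = get_tf_valid_norot_transform(STD_RES)
#   trainset, validset, subjects, weights = get_dataset(
#       'data', tf_train=tf_train, tf_valid=tf_valid)
#   loader = torch.utils.data.DataLoader(trainset, batch_size=32, shuffle=True)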

if __name__ == '__main__':
    import math

    from tqdm import tqdm

    df = to_dataframe('data')
    print(df)
    print("Examples by fold and class")
    print(df.groupby(['fold', 'class'])['file'].count())
    dataset = ISBI2019(df)
    mean_height, mean_width = 0, 0
    weird_files = []
    bound_left, bound_upper, bound_right, bound_lower = math.inf, math.inf, 0, 0
    for i, (img, label) in tqdm(enumerate(dataset), total=len(dataset)):
        left, upper, right, lower = img.getbbox()
        if left == 0 or upper == 0 or right == 450 or lower == 450:
            weird_files.append(df.iloc[i]['file'])
        height = lower - upper
        width = right - left
        mean_height = mean_height + (height - mean_height) / (i + 1)
        mean_width = mean_width + (width - mean_width) / (i + 1)
        bound_left = min(bound_left, left)
        bound_upper = min(bound_upper, upper)
        bound_right = max(bound_right, right)
        bound_lower = max(bound_lower, lower)
    print(f"mean_height = {mean_height:.2f}")
    print(f"mean_width = {mean_width:.2f}")
    print(f"bound_left = {bound_left:d}")
    print(f"bound_upper = {bound_upper:d}")
    print(f"bound_right = {bound_right:d}")
    print(f"bound_lower = {bound_lower:d}")
    print("Files that max out at least one border:")
    for f in weird_files:
        print(f)
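
# Running this file directly prints the per-fold/per-class counts and the cell
# bounding-box statistics computed above; it expects the dataset under ./data.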