# CSI/datasets/prepare_data.py
# 2022-04-29 19:26:47 +02:00
#
# 196 lines
# 6.6 KiB
# Python

import csv
import os
from PIL import Image
from torchvision import transforms
from torchvision.utils import save_image
import torch
def transform_image(img_in, target_dir, transformation, suffix):
    """
    Transform one image and save the result as ``<target_dir>/<suffix>.jpg``.

    Some suffixes select a built-in transformation and ignore
    ``transformation``:

    * ``'rot'``    -- rotate by 270 degrees with ``PIL.Image.rotate``.
    * ``'sobel'``  -- convolve with the 3x3 kernel
      ``[[1, 2, 1], [0, 0, 0], [-1, -2, -1]]`` applied across three input
      channels (so a 3-channel image is assumed -- TODO confirm for the data).
    * ``'noise'``  -- add zero-mean Gaussian noise with standard deviation 0.2.
    * ``'cutout'`` -- not implemented: a ``UserWarning`` is issued and nothing
      is written.

    Any other suffix calls ``transformation`` on the opened PIL image and saves
    the returned object with its ``save`` method.

    Parameters:
        img_in (path): Image to transform.
        target_dir (path): Directory the result is written into.
        transformation (callable): Applied only for suffixes not listed above.
        suffix (str): Selects the transformation and names the output file.
    Returns:
        None: the result is only written to disk.
    """
    if suffix == 'cutout':
        # Placeholder branch: make the missing implementation visible instead
        # of printing a meaningless debug string.
        import warnings
        warnings.warn("transform_image: 'cutout' is not implemented; no image was written")
        return
    out_path = os.path.join(target_dir, suffix + '.jpg')
    # The context manager releases the file handle PIL keeps open. Saving
    # happens inside the block because PIL loads pixel data lazily.
    with Image.open(img_in) as im:
        if suffix == 'rot':
            tensor = transforms.ToTensor()(im.rotate(270))
            save_image(tensor, out_path)
        elif suffix == 'sobel':
            tensor = transforms.ToTensor()(im)
            sobel_filter = torch.tensor([[1., 2., 1.], [0., 0., 0.], [-1., -2., -1.]])
            # Weight shape (out_channels=1, in_channels=3, 3, 3).
            weight = sobel_filter.expand(1, 3, 3, 3)
            tensor = torch.conv2d(tensor, weight, stride=1, padding=1)
            save_image(tensor, out_path)
        elif suffix == 'noise':
            tensor = transforms.ToTensor()(im)
            tensor = tensor + torch.randn(tensor.size()) * 0.2
            save_image(tensor, out_path)
        else:
            transformation(im).save(out_path)
def sort_and_rename_images(excel_path: str):
    """Move and rename images into ``all``/``hem`` sub-folders according to a CSV.

    The images are expected in the same directory as the CSV. Each row is read
    as ``(new_name, current_name, label)``:

    * label ``'1'`` moves ``<base>/<current_name>`` to ``<base>/all/<new_name>``;
    * label ``'0'`` moves it to ``<base>/hem/<new_name>``.

    Rows whose third column is anything else (e.g. a header row) are skipped,
    as are rows with fewer than three columns (e.g. blank lines). Every row is
    printed as it is read.

    Parameters:
        excel_path (str): Path to the CSV file; its directory is the base dir.
    """
    # dirname (unlike splitting on os.sep) copes with a bare file name and with
    # '/' separators on Windows.
    base_dir = os.path.dirname(excel_path)
    targets = {
        '1': os.path.join(base_dir, 'all'),
        '0': os.path.join(base_dir, 'hem'),
    }
    for directory in targets.values():
        os.makedirs(directory, exist_ok=True)
    # newline='' is what the csv module documentation requires for readers.
    with open(excel_path, mode='r', newline='') as file:
        for row in csv.reader(file):
            print(row)
            if len(row) < 3:
                continue
            destination = targets.get(row[2])
            if destination is not None:
                os.rename(os.path.join(base_dir, row[1]),
                          os.path.join(destination, row[0]))
def drop_color_channels(source_dir, target_dir, rgb):
    """Zero out colour channels of every image file in ``source_dir``.

    ``rgb`` selects the variant (channel indices 0/1/2 are taken as R/G/B):

        0 -> keep red only,   1 -> keep green only,   2 -> keep blue only,
        3 -> drop red,        4 -> drop green,        5 -> drop blue.

    Any other value prints ``"Invalid RGB-channel"`` and processes nothing.

    Each result is saved in BMP format to ``target_dir`` (created if missing)
    under the source file name. Images are assumed to decode to at least three
    channels -- TODO confirm for the dataset.
    """
    channels_to_zero = {
        0: (1, 2),  # red only
        1: (0, 2),  # green only
        2: (0, 1),  # blue only
        3: (0,),    # no red
        4: (1,),    # no green
        5: (2,),    # no blue
    }.get(rgb)
    if channels_to_zero is None:
        print("Invalid RGB-channel")
        return
    os.makedirs(target_dir, exist_ok=True)
    for item in os.listdir(source_dir):
        src = os.path.join(source_dir, item)
        if not os.path.isfile(src):
            continue
        # Close the source file once its pixels are in the tensor.
        with Image.open(src) as im:
            tensor = transforms.ToTensor()(im)
        for channel in channels_to_zero:
            tensor[channel, :, :] = 0
        save_image(tensor, os.path.join(target_dir, item), 'bmp')
def rotate_images(target_dir, source_dir, rotate, theta):
    """Write ``rotate`` rotated copies of every image file in ``source_dir``.

    Copy ``i`` (for ``i`` in ``0 .. rotate - 1``) is rotated by ``i * theta``
    degrees with ``PIL.Image.rotate`` and saved as BMP to
    ``<target_dir>/<i>_<file name>``. Copy 0 is therefore unrotated.

    Note the argument order: the target directory comes first.
    """
    if rotate < 1:
        # No copies requested; there is no need to open any file.
        return
    for item in os.listdir(source_dir):
        src = os.path.join(source_dir, item)
        if not os.path.isfile(src):
            continue
        # Open each source once (not once per rotation) and close it afterwards.
        with Image.open(src) as im:
            for i in range(rotate):
                tensor = transforms.ToTensor()(im.rotate(i * theta))
                save_image(tensor, os.path.join(target_dir, str(i) + '_' + item), 'bmp')
def grayscale_image(source_dir, target_dir):
    """Save a grayscale BMP copy of every image file in ``source_dir``.

    Each image is converted to RGB first, so that palette/RGBA inputs take the
    same path, and then passed through ``transforms.Grayscale()`` (default
    single output channel). Results keep their source file name and are
    written to ``target_dir``, which is created (with parents) if missing.
    """
    to_gray = transforms.Grayscale()
    items = os.listdir(source_dir)
    os.makedirs(target_dir, exist_ok=True)
    for item in items:
        src = os.path.join(source_dir, item)
        if os.path.isfile(src):
            # Save inside the block; closing releases PIL's file handle.
            with Image.open(src) as im:
                to_gray(im.convert('RGB')).save(os.path.join(target_dir, item), 'bmp')
def resize(source_dir, size=(128, 128)):
    """Save a resized BMP copy of every image file in ``source_dir``.

    Copies are written to ``<source_dir>/resized`` (created if missing) under
    the source file name. Sub-directories of ``source_dir`` are skipped.

    Parameters:
        source_dir (path): Directory whose files are resized.
        size: Output size passed to ``transforms.Resize``. Defaults to
            ``(128, 128)``, the previously hard-coded value.
    """
    resizer = transforms.Resize(size)
    items = os.listdir(source_dir)
    target_dir = os.path.join(source_dir, 'resized')
    os.makedirs(target_dir, exist_ok=True)
    for item in items:
        src = os.path.join(source_dir, item)
        if os.path.isfile(src):
            # Save inside the block; the transform may return a lazily loaded image.
            with Image.open(src) as im:
                resizer(im).save(os.path.join(target_dir, item), 'bmp')
def crop_image(source_dir, size=(224, 224)):
    """Save a center-cropped BMP copy of every image file in ``source_dir``.

    Copies are written to ``<source_dir>/cropped`` (created if missing) under
    the source file name. Sub-directories of ``source_dir`` are skipped.

    Parameters:
        source_dir (path): Directory whose files are cropped.
        size: Crop size passed to ``transforms.CenterCrop``. Defaults to
            ``(224, 224)``, the previously hard-coded value.
    """
    cropper = transforms.CenterCrop(size)
    items = os.listdir(source_dir)
    target_dir = os.path.join(source_dir, 'cropped')
    os.makedirs(target_dir, exist_ok=True)
    for item in items:
        src = os.path.join(source_dir, item)
        if os.path.isfile(src):
            # Save inside the block; closing releases PIL's file handle.
            with Image.open(src) as im:
                cropper(im).save(os.path.join(target_dir, item), 'bmp')
def mk_dirs(target_dir):
    """Create the fold/phase directory layout under ``target_dir``.

    Layout::

        target_dir/fold_0/{all,hem}
        target_dir/fold_1/{all,hem}
        target_dir/fold_2/{all,hem}
        target_dir/phase2
        target_dir/phase3

    Missing parents (including ``target_dir`` itself) are created. Existing
    directories are left untouched. Paths are built with ``os.path.join`` so
    the layout is correct on any OS, not just Windows.

    Returns:
        tuple[str, ...]: (fold_0/all, fold_0/hem, fold_1/all, fold_1/hem,
        fold_2/all, fold_2/hem, phase2, phase3).
    """
    fold_dirs = [
        os.path.join(target_dir, fold, label)
        for fold in ('fold_0', 'fold_1', 'fold_2')
        for label in ('all', 'hem')
    ]
    phase_dirs = [os.path.join(target_dir, phase) for phase in ('phase2', 'phase3')]
    result = tuple(fold_dirs + phase_dirs)
    for path in result:
        os.makedirs(path, exist_ok=True)
    return result