{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# One Shot Learning\n", "\n", "Mit diesem Jupyter-Skript werden alle notwendigen Funktionen zur Umsetzung und Auswertung des OneShot-Learnings (Siamese Networks) von Ohrbildern zusammengefasst. \n", "Das Jupyter-Skript bezieht sich auf die Masterarbeit: \"Verwendung des menschlichen Ohrs zur Personenauthentifizierung an IT-Systemen mittels CNNs\"." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Bibliotheken importieren\n", "\n", "Zunächst werden alle notwendigen Bibliothken importiert" ] }, { "cell_type": "code", "execution_count": 105, "metadata": {}, "outputs": [], "source": [ "## Import Libearies ##\n", "%matplotlib inline\n", "import cv2\n", "from PIL import Image\n", "import PIL.ImageOps \n", "import numpy as np\n", "import matplotlib.pyplot as plt\n", "import matplotlib.image as mpimg\n", "import pandas as pd\n", "import os\n", "\n", "## Import TORCHVISION with popular datasets, model architectures, \n", "## and common image transformations for computer vision\n", "from torchvision import transforms\n", "import torchvision\n", "import torchvision.utils\n", "import torchvision.models as models\n", "import torchvision.datasets as dset\n", "import torchsummary\n", "\n", "# Import Debugging method for set_trace\n", "from IPython.core.debugger import set_trace\n", "import logging\n", "\n", "## Import Time Features\n", "import datetime\n", "import time\n", "\n", "## Import TORCH for Data-Structures for multi-dimensional tensors \n", "## and mathematical operations\n", "import torch\n", "from torch.utils.tensorboard import SummaryWriter\n", "import torch.nn as nn\n", "import torch.nn.functional as F\n", "from torch.utils.data import DataLoader, Dataset\n", "import torch.optim as optim\n", "import torch_utils\n", "from torch.autograd import Variable\n", "\n", "# Import Winsound to play a *.wav-File\n", "# and Counter for counting numbers in an array\n", "import winsound\n", "from collections import Counter\n", "\n", "## Import Shutil to copy or delete Files\n", "import shutil\n", "from shutil import copyfile\n", "\n", "## Import Random Libery\n", "import random\n", "from random import shuffle\n", "\n", "## Import IPYWIDGET for interactive HTML widgets for Jupyter notebooks\n", "import ipywidgets as wg\n", "from IPython.display import display" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Festlegung: Prozessor oder Graphikkarte" ] }, { "cell_type": "code", "execution_count": 106, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "cuda\n" ] } ], "source": [ "DEVICE = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", "print(DEVICE)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Funktion zur Gegenüberstellung zwei Ohren\n", "\n", "Zwei Ohren werden gegenübergestellt und dazu der Constative-Loss Wert angezeigt" ] }, { "cell_type": "code", "execution_count": 107, "metadata": {}, "outputs": [], "source": [ "# build a rectangle in axes coords\n", "left, width = 5, .5\n", "bottom, height = .25, .5\n", "right = left + width\n", "top = bottom + height\n", "\n", "## Function to show Ear images and the Constrative Loss\n", "def imshow(img,text=None,should_save=False):\n", " npimg = img.cpu().numpy()\n", " plt.axis(\"off\")\n", " if text:\n", " plt.text(75, 8, text, style='italic',fontweight='bold',\n", " bbox={'facecolor':'white', 'alpha':0.8, 'pad':10})\n", " plt.imshow(np.transpose(npimg, (1, 2, 0)))\n", " plt.show() \n", "\n", "## Function to show Learning-Curve with matplotlib\n", "def show_plot(iteration,loss):\n", " plt.plot(iteration,loss)\n", " plt.show()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Datensätze festlegen, anzeigen und einlesen" ] }, { "cell_type": "code", "execution_count": 108, "metadata": {}, "outputs": [], "source": [ "## Indirectr Path for Linux and Windows## \n", "def ChooseDir(datadir):\n", " if(datadir == 'CP'):\n", " return \"Datensaetze/CP\", \"CP\"\n", " if(datadir == 'AMI'):\n", " return \"Datensaetze/AMI\", \"AMI\"\n", " if(datadir == 'AWE'):\n", " return \"Datensaetze/AWE\", \"AWE\"\n", " if(datadir == 'EarVN'):\n", " return \"Datensaetze/EarVN_1_0\", \"EarVN_1_0\"\n", " if(datadir == 'UERC'):\n", " return \"Datensaetze/UERC\", \"UERC\"" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Voreinstellungen - Hilfsfunktionen\n", "\n", "* **secs_to_HMS():** Berechnte Sekunden in HH:MM:SS um\n", "* **WaitTime_Finished():** Gibt eine Audio-Datei aus, wenn Training und Testen beendet ist" ] }, { "cell_type": "code", "execution_count": 109, "metadata": {}, "outputs": [], "source": [ "## Convert Seconds to Hours, Minutes and Seconds ##\n", "def secs_to_HMS(secs):\n", " if (secs < 3600):\n", " return datetime.datetime.fromtimestamp(secs).strftime('%M:%S'), \"[MM:SS]\"\n", " else:\n", " return datetime.datetime.fromtimestamp(secs).strftime('%H:%M:%S'), \"[HH:MM:SS]\"\n", "\n", "## Acustic Sound if Model-Trainings finished ##\n", "def WaitTime_Finished():\n", " for i in range(2):\n", " files=os.listdir(\"Sound/\")\n", " file=random.choice(files)\n", " winsound.PlaySound(\"Sound/\"+str(file), winsound.SND_FILENAME)\n", " time.sleep(2)" ] }, { "cell_type": "code", "execution_count": 110, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "2021_02_27-21_19_36\n" ] } ], "source": [ "## Use actual Time for individual Tag ##\n", "current_time = datetime.datetime.now().strftime(\"%Y_%m_%d-%H_%M_%S\")\n", "print(current_time)" ] }, { "cell_type": "code", "execution_count": 111, "metadata": {}, "outputs": [], "source": [ "## All Networks from PyTorch wich can be pretrained ## \n", "class SiameseNetwork(nn.Module):\n", " def __init__(self):\n", " super(SiameseNetwork, self).__init__()\n", " self.cnn1 = nn.Sequential(\n", " nn.ReflectionPad2d(1),\n", " nn.Conv2d(1, 4, kernel_size=3),\n", " nn.ReLU(inplace=True),\n", " nn.BatchNorm2d(4),\n", " \n", " nn.ReflectionPad2d(1),\n", " nn.Conv2d(4, 8, kernel_size=3),\n", " nn.ReLU(inplace=True),\n", " nn.BatchNorm2d(8),\n", "\n", "\n", " nn.ReflectionPad2d(1),\n", " nn.Conv2d(8, 8, kernel_size=3),\n", " nn.ReLU(inplace=True),\n", " nn.BatchNorm2d(8),\n", " )\n", "\n", " self.fc1 = nn.Sequential(\n", " nn.Linear(8*100*100, 500),\n", " nn.ReLU(inplace=True),\n", "\n", " nn.Linear(500, 500),\n", " nn.ReLU(inplace=True),\n", "\n", " nn.Linear(500, 5))\n", " def forward_once(self, x):\n", " output = self.cnn1(x)\n", " output = output.view(output.size()[0], -1)\n", " output = self.fc1(output)\n", " return output\n", " def forward(self, input1, input2):\n", " output1 = self.forward_once(input1)\n", " output2 = self.forward_once(input2)\n", " return output1, output2\n", "\n", "\n", "def Network_Choice(netw): \n", " if(netw == 'vgg11'):\n", " return models.vgg11(pretrained=True).to(DEVICE), \"VGG11\"\n", " if(netw == 'vgg11bn'):\n", " return models.vgg11_bn(pretrained=True).to(DEVICE), \"VGG11bn\"\n", " if(netw == 'resnet18'):\n", " return models.resnet18(pretrained=True).to(DEVICE), \"ResNet18\"\n", " if(netw == 'resnet34'):\n", " return models.resnet34(pretrained=True).to(DEVICE), \"ResNet34\"\n", " if(netw == 'alexnet'):\n", " return models.alexnet(pretrained=True).to(DEVICE), \"AlexNet\"\n", " if(netw == 'squeezenet1_0'):\n", " return models.squeezenet1_0(pretrained=True).to(DEVICE), \"SqueezeNet-1-0\" \n", " if(netw == 'GoogLeNet'):\n", " return models.googlenet(pretrained=True).to(DEVICE), \"GoogLeNet\" \n", " if(netw == 'shufflenet_v2_x0_5'):\n", " return models.shufflenet_v2_x0_5(pretrained=True).to(DEVICE), \"Shufflenet-v2-x0_5\" \n", " if(netw == 'resnext101_32x8d'):\n", " return models.resnext101_32x8d(pretrained=True).to(DEVICE), \"Resnext101-32x8d\"\n", " if(netw == 'siamese'):\n", " return SiameseNetwork().to(DEVICE), 'Siamese_Network'" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Contrastive Loss\n", "\n", "Constrive-Loss Berechnung für ein Siamese-Network " ] }, { "cell_type": "code", "execution_count": 112, "metadata": {}, "outputs": [], "source": [ "class ContrastiveLoss(torch.nn.Module):\n", " def __init__(self, margin=2.0):\n", " super(ContrastiveLoss, self).__init__()\n", " self.margin = margin\n", "\n", " def forward(self, output1, output2, label):\n", " euclidean_distance = F.pairwise_distance(output1, output2, keepdim = True)\n", " loss_contrastive = torch.mean((1-label) * torch.pow(euclidean_distance, 2) +\n", " (label) * torch.pow(torch.clamp(self.margin - euclidean_distance, min=0.0), 2))\n", " return loss_contrastive" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Konfiguration\n", "\n", "Auswahl von: \n", "* Datensatz\n", "* Netzwerk\n", "* Trainshare\n", "* Batch Size Train\n", "* Batch Size Test\n", "* Learning Rate\n", "* Momentum\n", "\n", "Hinweis: Konfiguration ausführen, Auswahl festlegen und weiter zum nächsten Abschnitt" ] }, { "cell_type": "code", "execution_count": 113, "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "97cf4f37c7ed4036a2879dbb139818c5", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Dropdown(description='Dataset:', index=1, options=('CP', 'AMI', 'AWE', 'EarVN', 'UERC'), value='AMI')" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "b5002f1761e64a5bbbaec4ecd034c112", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Dropdown(description='Network:', options=('vgg11', 'vgg11bn', 'resnet18', 'resnet34', 'alexnet', 'squeezenet1_…" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "0eb51fd4b5f744e98edad81c104e7f43", "version_major": 2, "version_minor": 0 }, "text/plain": [ "BoundedFloatText(value=0.8, description='Trainshare:', max=1.0, step=0.1)" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "48e119312d5e429ea553518c6d3df1aa", "version_major": 2, "version_minor": 0 }, "text/plain": [ "BoundedIntText(value=4, description='Batch_Train:', min=1)" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "21770422b0304a63907e69a750f6bea2", "version_major": 2, "version_minor": 0 }, "text/plain": [ "BoundedIntText(value=4, description='Batch_Test:', min=1)" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "14af8e2f1cc0414d8ce34c0bd93ddd7c", "version_major": 2, "version_minor": 0 }, "text/plain": [ "BoundedFloatText(value=0.001, description='Learning_Rate', max=1.0, step=0.0001)" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "d8b9466f0d994762bd04d858d213ce0a", "version_major": 2, "version_minor": 0 }, "text/plain": [ "BoundedFloatText(value=0.8, description='Momentum: ', max=1.0, step=0.1)" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "datadir_choose = wg.Dropdown(\n", " options=['CP', 'AMI', 'AWE', 'EarVN', 'UERC'],\n", " value='AMI',\n", " description='Dataset:',\n", " disabled=False,\n", " button_style=''\n", ")\n", "\n", "network_string_choose = wg.Dropdown(\n", " options=['vgg11','vgg11bn','resnet18','resnet34','alexnet','squeezenet1_0','GoogLeNet','shufflenet_v2_x0_5','resnext101_32x8d', 'siamese'],\n", " value='vgg11',\n", " description='Network:',\n", " disabled=False,\n", " button_style=''\n", ")\n", "\n", "trainshare_choose = wg.BoundedFloatText(\n", " value=0.8,\n", " min=0,\n", " max=1,\n", " step=0.1,\n", " description='Trainshare:',\n", " disabled=False\n", ")\n", "\n", "batch_train = wg.BoundedIntText(\n", " value=4,\n", " min=1,\n", " max=100,\n", " step=1,\n", " description='Batch_Train:',\n", " disabled=False\n", ")\n", "\n", "batch_test = wg.BoundedIntText(\n", " value=4,\n", " min=1,\n", " max=100,\n", " step=1,\n", " description='Batch_Test:',\n", " disabled=False\n", ")\n", "\n", "learning_rate_choose = wg.BoundedFloatText(\n", " value=0.001,\n", " min=0,\n", " max=1,\n", " step=0.0001,\n", " description='Learning_Rate',\n", " disabled=False\n", ")\n", "\n", "momentum_choose = wg.BoundedFloatText(\n", " value=0.8,\n", " min=0,\n", " max=1,\n", " step=0.1,\n", " description='Momentum: ',\n", " disabled=False\n", ")\n", "\n", "\n", "display(datadir_choose, network_string_choose, trainshare_choose, batch_train, batch_test, learning_rate_choose, momentum_choose)" ] }, { "cell_type": "code", "execution_count": 114, "metadata": {}, "outputs": [], "source": [ "## Define Variables, Arrays for Dataset and Categories ##\n", "dataset_train = []\n", "dataset_test = []\n", "CATEGORIES = []\n", "#img_array = []\n", "network_name = 'nix'\n", "\n", "datadir = datadir_choose.value\n", "network_string = network_string_choose.value\n", "train_share = trainshare_choose.value\n", "batch_size_train = batch_train.value\n", "batch_size_test = batch_test.value\n", "learning_rate = learning_rate_choose.value\n", "momentum = momentum_choose.value" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Pfade zum OneShot-Learning\n", "\n", "Pfade zu den AMI-Bildern für das Trainieren und Testen der Siamese/PyTorch-Netzwerke" ] }, { "cell_type": "code", "execution_count": 115, "metadata": {}, "outputs": [], "source": [ "class Config():\n", " ## Define Pathes for Train- and Test-Images\n", " path_OneShot_Train = \"./Datensaetze/AMI_OneShot/OneShot_Train/\"\n", " path_OneShot_Test = \"./Datensaetze/AMI_OneShot/OneShot_Testing/\" " ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Netzwerke und Datensätze einlesen" ] }, { "cell_type": "code", "execution_count": 116, "metadata": {}, "outputs": [], "source": [ "## Get Network and Network-Name\n", "network, network_name = Network_Choice(network_string)\n", "## Define Constrative-Loss as Loss-Function\n", "criterion = ContrastiveLoss()\n", "## Define Adam as optimizer\n", "optimizer = optim.Adam(network.parameters(),lr = 0.0005 )\n", "\n", "## Choose Datadirectory to the Dataset ##\n", "DATADIR, Database = ChooseDir(datadir)\n", "## Read and list all Categories of a Dataset \n", "CATEGORIES = os.listdir(DATADIR)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Erstellung oder Löschen des Dataset-Folder für SiameseDataset_Ear\n", "\n", "* **Create_OneShot_File_Images():**\n", "Falls keine Bilder in *./Datensaetze/AMI_OneShot_Train'* oder *./Datensaetze/AMI_OneShot_Test'* vorhanden sind können diese mit *Create_OneShot_File_Images()*. \n", "\n", "* **Delete_OneShot_File_Images():**\n", "Falls andere Datensätze als der AMI-Datensatz gewünscht sind, kann der AMI-Datensatz durch *Delete_OneShot_File_Images():* gelöscht werden " ] }, { "cell_type": "code", "execution_count": 117, "metadata": {}, "outputs": [], "source": [ "## Create Dataset-Images for OneShot Learning\n", "def Create_OneShot_File_Images():\n", " for category in CATEGORIES:\n", " path = os.path.join(DATADIR, category)\n", " class_num = CATEGORIES.index(category)+1\n", " count_train_share = (len(os.listdir(path)))*train_share\n", " counter = 1\n", " os_listdir = os.listdir(path)\n", " try:\n", " ## Create Folders in Train- and Test-Directories\n", " os.mkdir(Config.path_OneShot_Train+str(class_num))\n", " os.mkdir(Config.path_OneShot_Test+str(class_num))\n", " except OSError:\n", " print (\"Creation of the directory %s failed\" % Config.path_OneShot_Train)\n", " else:\n", " print (\"Successfully created the directory %s \" % Config.path_OneShot_Test)\n", "\n", " for img in os_listdir:\n", " try:\n", " ## Copy all Images in Train- and Test-Directories\n", " if(counter <= count_train_share):\n", " shutil.copy(path+'/'+img, Config.path_OneShot_Train+str(class_num))\n", " counter += 1\n", " else:\n", " shutil.copy(path+'/'+img, Config.path_OneShot_Test+str(class_num))\n", " except Exception as e:\n", " pass\n", " \n", "## Delete Dataset-Images for OneShot Learning \n", "def Delete_OneShot_File_Images():\n", " for category in CATEGORIES:\n", " class_num = CATEGORIES.index(category)#1\n", " count_train_share = (len(os.listdir(path)))*train_share\n", " counter = 1\n", " os_listdir = os.listdir(path)\n", " ## Delete all Directories and Images of One Shot Learning\n", " try:\n", " shutil.rmtree(Config.path_OneShot_Train+str(class_num))\n", " shutil.rmtree(Config.path_OneShot_Test+str(class_num))\n", " except OSError:\n", " print (\"Deletion of the directory %s failed\" % Config.path_OneShot_Train)\n", " else:\n", " print (\"Successfully deleted the directory %s\" % Config.path_OneShot_Train)" ] }, { "cell_type": "code", "execution_count": 118, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Successfully created the directory ./Datensaetze/AMI_OneShot/OneShot_Testing/ \n", "Successfully created the directory ./Datensaetze/AMI_OneShot/OneShot_Testing/ \n", "Successfully created the directory ./Datensaetze/AMI_OneShot/OneShot_Testing/ \n", "Successfully created the directory ./Datensaetze/AMI_OneShot/OneShot_Testing/ \n", "Successfully created the directory ./Datensaetze/AMI_OneShot/OneShot_Testing/ \n", "Successfully created the directory ./Datensaetze/AMI_OneShot/OneShot_Testing/ \n", "Successfully created the directory ./Datensaetze/AMI_OneShot/OneShot_Testing/ \n", "Successfully created the directory ./Datensaetze/AMI_OneShot/OneShot_Testing/ \n", "Successfully created the directory ./Datensaetze/AMI_OneShot/OneShot_Testing/ \n", "Successfully created the directory ./Datensaetze/AMI_OneShot/OneShot_Testing/ \n", "Successfully created the directory ./Datensaetze/AMI_OneShot/OneShot_Testing/ \n", "Successfully created the directory ./Datensaetze/AMI_OneShot/OneShot_Testing/ \n", "Successfully created the directory ./Datensaetze/AMI_OneShot/OneShot_Testing/ \n", "Successfully created the directory ./Datensaetze/AMI_OneShot/OneShot_Testing/ \n", "Successfully created the directory ./Datensaetze/AMI_OneShot/OneShot_Testing/ \n", "Successfully created the directory ./Datensaetze/AMI_OneShot/OneShot_Testing/ \n", "Successfully created the directory ./Datensaetze/AMI_OneShot/OneShot_Testing/ \n", "Successfully created the directory ./Datensaetze/AMI_OneShot/OneShot_Testing/ \n", "Successfully created the directory ./Datensaetze/AMI_OneShot/OneShot_Testing/ \n", "Successfully created the directory ./Datensaetze/AMI_OneShot/OneShot_Testing/ \n" ] } ], "source": [ "#Create_OneShot_File_Images()\n", "#Delete_OneShot_File_Images()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Klasse zur Erstellung der Bildpaare\n", "\n", "* **SiameseDataset_Ears()**: Erstellt die identische Anzahl an gleichen und ungleichen Bildpaaren und transformiert die Bilder" ] }, { "cell_type": "code", "execution_count": 119, "metadata": {}, "outputs": [], "source": [ "## Create Dataset of Ear-Images for OneShot-Learning\n", "class SiameseDataset_Ears(Dataset):\n", " \n", " def __init__(self,imageFolderDataset,transform=None,should_invert=True):\n", " self.imageFolderDataset = imageFolderDataset \n", " self.transform = transform\n", " self.should_invert = should_invert\n", " \n", " def __getitem__(self,index):\n", " ## Get random Picture from Test-File\n", " img0_tuple = random.choice(self.imageFolderDataset.imgs)\n", " \n", " #we need to make sure approx 50% of images are in the same class\n", " should_get_same_class = random.randint(0,1) \n", " if should_get_same_class:\n", " while True:\n", " #keep looping till the same class image is found\n", " img1_tuple = random.choice(self.imageFolderDataset.imgs) \n", " if img0_tuple[1]==img1_tuple[1]:\n", " break\n", " else:\n", " while True:\n", " #keep looping till a different class image is found \n", " img1_tuple = random.choice(self.imageFolderDataset.imgs) \n", " if img0_tuple[1] !=img1_tuple[1]:\n", " break\n", "\n", " ## Load images and convert to RGB if PyTorch-Network is Choosen\n", " ## Load images and convert to Gray if Siamese-Network is Choosen\n", " if(network_name == ('VGG11' or 'VGG11bn' or 'ResNet18' or 'ResNet34' or 'alexnet' or 'Squeezenet1_0' or 'GoogLeNet' or 'Shufflenet_v2_x0_5' or 'Resnext101_32x8d')):\n", " img0 = Image.open(img0_tuple[0]).convert('RGB')\n", " img1 = Image.open(img1_tuple[0]).convert('RGB') \n", " elif(network_name == ('Siamese_Network')):\n", " img0 = Image.open(img0_tuple[0]).convert(\"L\")\n", " img1 = Image.open(img1_tuple[0]).convert(\"L\")\n", " #img0 = img0.convert(\"L\")\n", " #img1 = img1.convert(\"L\")\n", " else:\n", " print('False Network choosen') \n", " \n", " ## Invert loaded PIL images\n", " if self.should_invert:\n", " img0 = PIL.ImageOps.invert(img0)\n", " img1 = PIL.ImageOps.invert(img1)\n", " \n", " ## Transform \n", " if self.transform is not None:\n", " img0 = self.transform(img0)\n", " img1 = self.transform(img1)\n", " \n", " return img0, img1 , torch.from_numpy(np.array([int(img1_tuple[1]!=img0_tuple[1])],dtype=np.float32))\n", " \n", " def __len__(self):\n", " return len(self.imageFolderDataset.imgs)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Definition des Trainings- und Test-Folders" ] }, { "cell_type": "code", "execution_count": 120, "metadata": {}, "outputs": [], "source": [ "## Folder for Train-Images\n", "folder_dataset_train = dset.ImageFolder(root=Config.path_OneShot_Train)\n", "\n", "## Folder for Test-Images\n", "folder_dataset_test = dset.ImageFolder(root=Config.path_OneShot_Test)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Definition des Transformers" ] }, { "cell_type": "code", "execution_count": 121, "metadata": {}, "outputs": [], "source": [ "transformer = transforms.ToTensor()\n", "\n", "## Define Transformer for PyTorch-Networks or Siamese Network\n", "if(network_name == ('VGG11' or 'VGG11bn' or 'ResNet18' or 'ResNet34' or 'alexnet' or 'Squeezenet1_0' or 'GoogLeNet' or 'Shufflenet_v2_x0_5' or 'Resnext101_32x8d')):\n", " ## Transformer for PyTorch-Network\n", " transformer = transforms.Compose([\n", " transforms.Resize(256),\n", " transforms.CenterCrop(224),\n", " transforms.ToTensor(),\n", " transforms.Normalize((0.4318, 0.4660, 0.5889), (0.1752, 0.1893, 0.2096)),\n", "])\n", "\n", "elif(network_name == ('Siamese_Network')):\n", " ## Transformer for Siamese-Network\n", " transformer=transforms.Compose([transforms.Resize((100,100)),\n", " transforms.ToTensor()\n", " ])\n", "else:\n", " print('False Network choosen')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Laden von Train- und Testdatensätzen mit Transformation" ] }, { "cell_type": "code", "execution_count": 122, "metadata": {}, "outputs": [], "source": [ "## For PyTorch-Networks\n", "siamese_dataset_train = SiameseDataset_Ears(imageFolderDataset=folder_dataset_train,\n", " transform=transformer,\n", " should_invert=False)\n", "\n", "## Create Test-Dataset\n", "siamese_dataset_test = SiameseDataset_Ears(imageFolderDataset=folder_dataset_test,\n", " transform=transformer,\n", " should_invert=True)\n", "\n", "\n", "\n", "## Create Training-DataLoader\n", "train_dataloader = DataLoader(siamese_dataset_train, batch_size=batch_size_train, shuffle=True,)\n", "\n", "## Create Test-DataLoader\n", "test_dataloader = DataLoader(siamese_dataset_test, batch_size=batch_size_test, shuffle=True)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Visualisierung der Bildpaare des TrainLoaders\n", " \n", "Bildpaare: *Proband1 oben, Proband2 unten* \n", "0: Gleiches Paar \n", "1: Ungleiches Paar" ] }, { "cell_type": "code", "execution_count": 123, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).\n" ] }, { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "[[0.]\n", " [1.]\n", " [0.]\n", " [1.]\n", " [1.]\n", " [0.]\n", " [1.]\n", " [1.]]\n" ] } ], "source": [ "vis_dataloader = DataLoader(siamese_dataset_train, shuffle=True, batch_size=8)\n", "\n", "dataiter = iter(vis_dataloader)\n", "\n", "#example_batch = next(dataiter)\n", "example_batch = dataiter.next()\n", "concatenated = torch.cat((example_batch[0],example_batch[1]),0)\n", "imshow(torchvision.utils.make_grid(concatenated))\n", "print(example_batch[2].numpy())" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Training:" ] }, { "cell_type": "code", "execution_count": 124, "metadata": { "scrolled": true }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Start Training um: 21:20:32\n", "Epoch number 0\n", " Current loss 308.1065979003906\n", "\n", "Epoch number 0\n", " Current loss 1.138925552368164\n", "\n", "Epoch number 0\n", " Current loss 0.964200496673584\n", "\n", "Epoch number 1\n", " Current loss 1.052968978881836\n", "\n", "Epoch number 1\n", " Current loss 0.9594541788101196\n", "\n", "Epoch number 1\n", " Current loss 0.4107712507247925\n", "\n", "Epoch number 2\n", " Current loss 1.5126802921295166\n", "\n", "Epoch number 2\n", " Current loss 1.127558946609497\n", "\n", "Epoch number 2\n", " Current loss 3.7104902267456055\n", "\n", "Epoch number 3\n", " Current loss 0.7310527563095093\n", "\n", "Epoch number 3\n", " Current loss 0.7123371362686157\n", "\n", "Epoch number 3\n", " Current loss 1.3404734134674072\n", "\n", "Epoch number 4\n", " Current loss 1.0322986841201782\n", "\n", "Epoch number 4\n", " Current loss 1.4310665130615234\n", "\n", "Epoch number 4\n", " Current loss 0.8324761390686035\n", "\n", "Epoch number 5\n", " Current loss 0.8502787351608276\n", "\n", "Epoch number 5\n", " Current loss 1.3122851848602295\n", "\n", "Epoch number 5\n", " Current loss 0.921744167804718\n", "\n", "Epoch number 6\n", " Current loss 1.1442584991455078\n", "\n", "Epoch number 6\n", " Current loss 1.011681318283081\n", "\n", "Epoch number 6\n", " Current loss 0.9135341048240662\n", "\n", "Epoch number 7\n", " Current loss 1.2509589195251465\n", "\n", "Epoch number 7\n", " Current loss 0.7557617425918579\n", "\n", "Epoch number 7\n", " Current loss 0.8743131160736084\n", "\n", "Epoch number 8\n", " Current loss 0.7078773975372314\n", "\n", "Epoch number 8\n", " Current loss 2.043950080871582\n", "\n", "Epoch number 8\n", " Current loss 1.0420727729797363\n", "\n", "Epoch number 9\n", " Current loss 0.9620883464813232\n", "\n", "Epoch number 9\n", " Current loss 0.9625381827354431\n", "\n", "Epoch number 9\n", " Current loss 1.1636301279067993\n", "\n", "Epoch number 10\n", " Current loss 1.5617934465408325\n", "\n", "Epoch number 10\n", " Current loss 1.029129147529602\n", "\n", "Epoch number 10\n", " Current loss 1.2446420192718506\n", "\n", "Epoch number 11\n", " Current loss 1.0907390117645264\n", "\n", "Epoch number 11\n", " Current loss 0.9430091977119446\n", "\n", "Epoch number 11\n", " Current loss 1.0374553203582764\n", "\n", "Epoch number 12\n", " Current loss 1.2946057319641113\n", "\n", "Epoch number 12\n", " Current loss 0.8857178092002869\n", "\n", "Epoch number 12\n", " Current loss 0.9587574005126953\n", "\n", "Epoch number 13\n", " Current loss 0.446734756231308\n", "\n", "Epoch number 13\n", " Current loss 0.9940989017486572\n", "\n", "Epoch number 13\n", " Current loss 0.9222986698150635\n", "\n", "Epoch number 14\n", " Current loss 0.8071630597114563\n", "\n", "Epoch number 14\n", " Current loss 0.9962530136108398\n", "\n", "Epoch number 14\n", " Current loss 0.9935855269432068\n", "\n", "Epoch number 15\n", " Current loss 1.2180464267730713\n", "\n", "Epoch number 15\n", " Current loss 1.0259199142456055\n", "\n", "Epoch number 15\n", " Current loss 0.8863251209259033\n", "\n", "Epoch number 16\n", " Current loss 1.0609502792358398\n", "\n", "Epoch number 16\n", " Current loss 0.9913638234138489\n", "\n", "Epoch number 16\n", " Current loss 1.1235753297805786\n", "\n", "Epoch number 17\n", " Current loss 1.1986947059631348\n", "\n", "Epoch number 17\n", " Current loss 0.7384413480758667\n", "\n", "Epoch number 17\n", " Current loss 0.8233901262283325\n", "\n", "Epoch number 18\n", " Current loss 1.227494716644287\n", "\n", "Epoch number 18\n", " Current loss 0.9361811876296997\n", "\n", "Epoch number 18\n", " Current loss 0.19040855765342712\n", "\n", "Epoch number 19\n", " Current loss 0.5063803195953369\n", "\n", "Epoch number 19\n", " Current loss 1.4189283847808838\n", "\n", "Epoch number 19\n", " Current loss 0.782526969909668\n", "\n", "Ende Training um: 21:21:26\n", "Dauer Training: 00:53 [MM:SS] \n", "\n" ] }, { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "epoch_range = 20\n", "counter = []\n", "loss_history = [] \n", "iteration_number= 0\n", "\n", "\n", "## Start Training\n", "print(\"Start Training um: \", time.strftime(\"%H:%M:%S\"))\n", "start_time = time.time()\n", "\n", "for epoch in range(epoch_range):\n", " for i, data in enumerate(train_dataloader, 0):\n", " img0, img1 , label = data\n", " img0, img1 , label = img0.to(DEVICE), img1.to(DEVICE), label.to(DEVICE)\n", "\n", " optimizer.zero_grad()\n", " \n", " ## Calculate Ouputs and LossConstrative\n", " if(network_name == ('VGG11' or 'VGG11bn' or 'ResNet18' or 'ResNet34' or 'alexnet' or 'Squeezenet1_0' or 'GoogLeNet' or 'Shufflenet_v2_x0_5' or 'Resnext101_32x8d')):\n", " output1 = network(img0)\n", " output2 = network(img1)\n", " loss_contrastive = criterion(output1,output2,label)\n", " \n", " elif(network_name == ('Siamese_Network')):\n", " output1, output2 = network(img0, img1)\n", " #outputs = network(img0, img1)\n", " loss_contrastive = criterion(output1,output2,label)\n", " else:\n", " print('False Network choosen')\n", " \n", " \n", " ## Backpropagation and Optmizer\n", " loss_contrastive.backward()\n", " optimizer.step()\n", " \n", " ## Get loss\n", " if i %10 == 0 :\n", " print(\"Epoch number {}\\n Current loss {}\\n\".format(epoch,loss_contrastive.item()))\n", " iteration_number += 10\n", " counter.append(iteration_number)\n", " loss_history.append(loss_contrastive.item())\n", "\n", " \n", "## Finished Training\n", "print(\"Ende Training um: \",time.strftime(\"%H:%M:%S\"))\n", "stop_time = time.time()\n", "time_dif, time_format = secs_to_HMS(stop_time-start_time)\n", "print(\"Dauer Training: \", time_dif, \" \", time_format, \" \\n\") \n", "\n", "show_plot(counter,loss_history)\n", "WaitTime_Finished()" ] }, { "cell_type": "code", "execution_count": 63, "metadata": {}, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "show_plot(counter,loss_history)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Speichern der trainierten Netzwerke" ] }, { "cell_type": "code", "execution_count": 125, "metadata": {}, "outputs": [], "source": [ "## Path to save Network \n", "PATH = './Netzwerke_OneShot/' + current_time + '_' + Database + '_' + network_name + '_Train' + '.pth'\n", "\n", "## Save Network\n", "torch.save({\n", " 'epoch_range': epoch_range,\n", " 'model_state_dict': network.state_dict(),\n", " 'optimizer_state_dict': optimizer.state_dict(),\n", " 'loss': loss_contrastive\n", " }, PATH)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Testen" ] }, { "cell_type": "code", "execution_count": 126, "metadata": {}, "outputs": [], "source": [ "folder_dataset_test = dset.ImageFolder(root=Config.path_OneShot_Test)\n", "\n", "siamese_dataset = SiameseDataset_Ears(imageFolderDataset=folder_dataset_test,\n", " transform=transformer\n", " ,should_invert=False)\n", "\n", "test_dataloader = DataLoader(siamese_dataset,batch_size=1,shuffle=True)\n", "\n", "## Test Function for Siamese-Network\n", "def evaluation(model, test_loader):\n", " with torch.no_grad():\n", " model.eval()\n", " correct = 0\n", " count = 0\n", "\n", " for mainImg, imgSets, label in test_loader:\n", " mainImg, imgSets, label = mainImg.to(DEVICE), imgSets.to(DEVICE), label.to(DEVICE)\n", " predVal = 2.1\n", " pred = -1\n", " count += 1\n", " \n", " ## Determine which category an image belongs to\n", " for i, testImg in enumerate(imgSets):\n", " testImg = testImg.to(DEVICE)\n", "\n", " if(network_name == ('VGG11' or 'VGG11bn' or 'ResNet18' or 'ResNet34' or 'alexnet' or 'Squeezenet1_0' or 'GoogLeNet' or 'Shufflenet_v2_x0_5' or 'Resnext101_32x8d')):\n", " output1 = network(mainImg)\n", " output2 = network(imgSets)\n", " \n", " elif(network_name == ('Siamese_Network')):\n", " output1, output2 = network(mainImg, imgSets)\n", " \n", " else:\n", " print('False Network choosen')\n", "\n", " \n", " euclidean_distance = F.pairwise_distance(output1, output2)\n", " euclidean_distance = euclidean_distance.cpu().numpy()\n", " \n", "\n", " if(((euclidean_distance < predVal) and (label==0)) or ((euclidean_distance > predVal) and (label==1))):\n", " correct += 1\n", " print('Accuracy: {}'.format(correct/count))" ] }, { "cell_type": "code", "execution_count": 127, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Accuracy: 1.0\n", "Accuracy: 0.5\n", "Accuracy: 0.3333333333333333\n", "Accuracy: 0.25\n", "Accuracy: 0.4\n", "Accuracy: 0.3333333333333333\n", "Accuracy: 0.2857142857142857\n", "Accuracy: 0.375\n", "Accuracy: 0.3333333333333333\n", "Accuracy: 0.3\n", "Accuracy: 0.2727272727272727\n", "Accuracy: 0.25\n", "Accuracy: 0.3076923076923077\n", "Accuracy: 0.35714285714285715\n", "Accuracy: 0.4\n", "Accuracy: 0.375\n", "Accuracy: 0.35294117647058826\n", "Accuracy: 0.3888888888888889\n", "Accuracy: 0.42105263157894735\n", "Accuracy: 0.45\n", "Accuracy: 0.42857142857142855\n", "Accuracy: 0.4090909090909091\n", "Accuracy: 0.391304347826087\n", "Accuracy: 0.375\n", "Accuracy: 0.4\n", "Accuracy: 0.38461538461538464\n", "Accuracy: 0.37037037037037035\n", "Accuracy: 0.39285714285714285\n", "Accuracy: 0.3793103448275862\n", "Accuracy: 0.36666666666666664\n", "Accuracy: 0.3870967741935484\n", "Accuracy: 0.375\n", "Accuracy: 0.3939393939393939\n", "Accuracy: 0.4117647058823529\n", "Accuracy: 0.42857142857142855\n", "Accuracy: 0.4166666666666667\n", "Accuracy: 0.40540540540540543\n", "Accuracy: 0.42105263157894735\n", "Accuracy: 0.4358974358974359\n", "Accuracy: 0.45\n" ] } ], "source": [ "evaluation(network, test_dataloader)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.8.3" } }, "nbformat": 4, "nbformat_minor": 4 }