|
- {
- "cells": [
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "# One Shot Learning\n",
- "\n",
- "This Jupyter notebook collects all functions needed to implement and evaluate one-shot learning (Siamese networks) on ear images.\n",
- "The notebook accompanies the master's thesis \"Verwendung des menschlichen Ohrs zur Personenauthentifizierung an IT-Systemen mittels CNNs\" (\"Using the human ear for person authentication on IT systems by means of CNNs\")."
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "## Importing libraries\n",
- "\n",
- "First, all required libraries are imported."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 105,
- "metadata": {},
- "outputs": [],
- "source": [
- "## Import libraries ##\n",
- "%matplotlib inline\n",
- "import cv2\n",
- "from PIL import Image\n",
- "import PIL.ImageOps \n",
- "import numpy as np\n",
- "import matplotlib.pyplot as plt\n",
- "import matplotlib.image as mpimg\n",
- "import pandas as pd\n",
- "import os\n",
- "\n",
- "## Import TORCHVISION with popular datasets, model architectures, \n",
- "## and common image transformations for computer vision\n",
- "from torchvision import transforms\n",
- "import torchvision\n",
- "import torchvision.utils\n",
- "import torchvision.models as models\n",
- "import torchvision.datasets as dset\n",
- "import torchsummary\n",
- "\n",
- "# Import Debugging method for set_trace\n",
- "from IPython.core.debugger import set_trace\n",
- "import logging\n",
- "\n",
- "## Import Time Features\n",
- "import datetime\n",
- "import time\n",
- "\n",
- "## Import TORCH for Data-Structures for multi-dimensional tensors \n",
- "## and mathematical operations\n",
- "import torch\n",
- "from torch.utils.tensorboard import SummaryWriter\n",
- "import torch.nn as nn\n",
- "import torch.nn.functional as F\n",
- "from torch.utils.data import DataLoader, Dataset\n",
- "import torch.optim as optim\n",
- "import torch_utils\n",
- "from torch.autograd import Variable\n",
- "\n",
- "# Import winsound (Windows only) to play a *.wav file\n",
- "# and Counter for counting values in an array\n",
- "import winsound\n",
- "from collections import Counter\n",
- "\n",
- "## Import Shutil to copy or delete Files\n",
- "import shutil\n",
- "from shutil import copyfile\n",
- "\n",
- "## Import the random library\n",
- "import random\n",
- "from random import shuffle\n",
- "\n",
- "## Import IPYWIDGET for interactive HTML widgets for Jupyter notebooks\n",
- "import ipywidgets as wg\n",
- "from IPython.display import display"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "## Device selection: CPU or GPU"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 106,
- "metadata": {},
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "cuda\n"
- ]
- }
- ],
- "source": [
- "DEVICE = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
- "print(DEVICE)"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "## Function for comparing two ears side by side\n",
- "\n",
- "Two ear images are shown side by side together with their contrastive loss value."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 107,
- "metadata": {},
- "outputs": [],
- "source": [
- "# build a rectangle in axes coords\n",
- "left, width = 5, .5\n",
- "bottom, height = .25, .5\n",
- "right = left + width\n",
- "top = bottom + height\n",
- "\n",
- "## Function to show ear images and the contrastive loss\n",
- "def imshow(img,text=None,should_save=False):\n",
- "    npimg = img.cpu().numpy()\n",
- "    plt.axis(\"off\")\n",
- "    if text:\n",
- "        plt.text(75, 8, text, style='italic',fontweight='bold',\n",
- "            bbox={'facecolor':'white', 'alpha':0.8, 'pad':10})\n",
- "    ## Tensors are (C, H, W); matplotlib expects (H, W, C)\n",
- "    plt.imshow(np.transpose(npimg, (1, 2, 0)))\n",
- "    plt.show()\n",
- "\n",
- "## Function to show the learning curve with matplotlib\n",
- "def show_plot(iteration,loss):\n",
- "    plt.plot(iteration,loss)\n",
- "    plt.show()"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "## Selecting, displaying and loading datasets"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 108,
- "metadata": {},
- "outputs": [],
- "source": [
- "## Relative paths that work on Linux and Windows ##\n",
- "def ChooseDir(datadir):\n",
- "    if(datadir == 'CP'):\n",
- "        return \"Datensaetze/CP\", \"CP\"\n",
- "    if(datadir == 'AMI'):\n",
- "        return \"Datensaetze/AMI\", \"AMI\"\n",
- "    if(datadir == 'AWE'):\n",
- "        return \"Datensaetze/AWE\", \"AWE\"\n",
- "    if(datadir == 'EarVN'):\n",
- "        return \"Datensaetze/EarVN_1_0\", \"EarVN_1_0\"\n",
- "    if(datadir == 'UERC'):\n",
- "        return \"Datensaetze/UERC\", \"UERC\""
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "## Presets - helper functions\n",
- "\n",
- "* **secs_to_HMS():** Converts seconds to HH:MM:SS\n",
- "* **WaitTime_Finished():** Plays an audio file once training and testing have finished"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 109,
- "metadata": {},
- "outputs": [],
- "source": [
- "## Convert seconds to hours, minutes and seconds ##\n",
- "def secs_to_HMS(secs):\n",
- "    ## Plain integer arithmetic: datetime.fromtimestamp() would shift the hours by the local time zone\n",
- "    minutes, seconds = divmod(int(secs), 60)\n",
- "    hours, minutes = divmod(minutes, 60)\n",
- "    if (hours == 0):\n",
- "        return '%02d:%02d' % (minutes, seconds), \"[MM:SS]\"\n",
- "    else:\n",
- "        return '%02d:%02d:%02d' % (hours, minutes, seconds), \"[HH:MM:SS]\"\n",
- "\n",
- "## Acoustic signal when model training has finished ##\n",
- "def WaitTime_Finished():\n",
- "    for i in range(2):\n",
- "        ## Pick a random sound file and play it\n",
- "        files=os.listdir(\"Sound/\")\n",
- "        file=random.choice(files)\n",
- "        winsound.PlaySound(\"Sound/\"+str(file), winsound.SND_FILENAME)\n",
- "        time.sleep(2)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 110,
- "metadata": {},
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "2021_02_27-21_19_36\n"
- ]
- }
- ],
- "source": [
- "## Use the current time as a unique tag ##\n",
- "current_time = datetime.datetime.now().strftime(\"%Y_%m_%d-%H_%M_%S\")\n",
- "print(current_time)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 111,
- "metadata": {},
- "outputs": [],
- "source": [
- "## Small Siamese network: both inputs pass through the same CNN and fully connected layers ##\n",
- "class SiameseNetwork(nn.Module):\n",
- "    def __init__(self):\n",
- "        super(SiameseNetwork, self).__init__()\n",
- "        ## Reflection padding of 1 with 3x3 kernels keeps the spatial size unchanged\n",
- "        self.cnn1 = nn.Sequential(\n",
- "            nn.ReflectionPad2d(1),\n",
- "            nn.Conv2d(1, 4, kernel_size=3),\n",
- "            nn.ReLU(inplace=True),\n",
- "            nn.BatchNorm2d(4),\n",
- "\n",
- "            nn.ReflectionPad2d(1),\n",
- "            nn.Conv2d(4, 8, kernel_size=3),\n",
- "            nn.ReLU(inplace=True),\n",
- "            nn.BatchNorm2d(8),\n",
- "\n",
- "            nn.ReflectionPad2d(1),\n",
- "            nn.Conv2d(8, 8, kernel_size=3),\n",
- "            nn.ReLU(inplace=True),\n",
- "            nn.BatchNorm2d(8),\n",
- "        )\n",
- "\n",
- "        ## 8 channels x 100 x 100 pixels: expects grayscale inputs resized to 100x100\n",
- "        self.fc1 = nn.Sequential(\n",
- "            nn.Linear(8*100*100, 500),\n",
- "            nn.ReLU(inplace=True),\n",
- "\n",
- "            nn.Linear(500, 500),\n",
- "            nn.ReLU(inplace=True),\n",
- "\n",
- "            nn.Linear(500, 5))\n",
- "\n",
- "    def forward_once(self, x):\n",
- "        output = self.cnn1(x)\n",
- "        ## Flatten everything except the batch dimension\n",
- "        output = output.view(output.size()[0], -1)\n",
- "        output = self.fc1(output)\n",
- "        return output\n",
- "\n",
- "    def forward(self, input1, input2):\n",
- "        output1 = self.forward_once(input1)\n",
- "        output2 = self.forward_once(input2)\n",
- "        return output1, output2\n",
- "\n",
- "\n",
- "## All networks from PyTorch which can be pretrained, plus the Siamese network ##\n",
- "def Network_Choice(netw):\n",
- "    if(netw == 'vgg11'):\n",
- "        return models.vgg11(pretrained=True).to(DEVICE), \"VGG11\"\n",
- "    if(netw == 'vgg11bn'):\n",
- "        return models.vgg11_bn(pretrained=True).to(DEVICE), \"VGG11bn\"\n",
- "    if(netw == 'resnet18'):\n",
- "        return models.resnet18(pretrained=True).to(DEVICE), \"ResNet18\"\n",
- "    if(netw == 'resnet34'):\n",
- "        return models.resnet34(pretrained=True).to(DEVICE), \"ResNet34\"\n",
- "    if(netw == 'alexnet'):\n",
- "        return models.alexnet(pretrained=True).to(DEVICE), \"AlexNet\"\n",
- "    if(netw == 'squeezenet1_0'):\n",
- "        return models.squeezenet1_0(pretrained=True).to(DEVICE), \"SqueezeNet-1-0\"\n",
- "    if(netw == 'GoogLeNet'):\n",
- "        return models.googlenet(pretrained=True).to(DEVICE), \"GoogLeNet\"\n",
- "    if(netw == 'shufflenet_v2_x0_5'):\n",
- "        return models.shufflenet_v2_x0_5(pretrained=True).to(DEVICE), \"Shufflenet-v2-x0_5\"\n",
- "    if(netw == 'resnext101_32x8d'):\n",
- "        return models.resnext101_32x8d(pretrained=True).to(DEVICE), \"Resnext101-32x8d\"\n",
- "    if(netw == 'siamese'):\n",
- "        return SiameseNetwork().to(DEVICE), 'Siamese_Network'"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "## Contrastive Loss\n",
- "\n",
- "Contrastive loss computation for a Siamese network"
- ]
- },
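- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "As a reading aid, the `ContrastiveLoss` class in this notebook can be written as the formula below. This is a sketch derived from the code; the symbols $N$ (batch size), $f$ (the embedding network) and $x_i^{(1)}, x_i^{(2)}$ (the two images of pair $i$) are notation introduced here, not names from the notebook.\n",
- "\n",
- "$$\n",
- "L = \\frac{1}{N} \\sum_{i=1}^{N} \\left[ (1 - y_i)\\, d_i^{2} + y_i \\, \\max(0,\\; m - d_i)^{2} \\right], \\qquad d_i = \\lVert f(x_i^{(1)}) - f(x_i^{(2)}) \\rVert_2\n",
- "$$\n",
- "\n",
- "Here $y_i = 0$ marks a pair from the same subject and $y_i = 1$ a pair from different subjects, matching the labels produced by the pair dataset, and $m$ is the margin (default `2.0`). Matching pairs are pulled together by the $d_i^2$ term, while non-matching pairs are only penalized as long as their distance is below the margin. For example, with $m = 2$ a non-matching pair at distance $d_i = 0.5$ contributes $(2 - 0.5)^2 = 2.25$, whereas one at $d_i = 3$ contributes $0$. Note that `F.pairwise_distance` adds a small epsilon for numerical stability, so the computed $d_i$ differs marginally from the exact norm."
- ]
- },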
- {
- "cell_type": "code",
- "execution_count": 112,
- "metadata": {},
- "outputs": [],
- "source": [
- "class ContrastiveLoss(torch.nn.Module):\n",
- "    def __init__(self, margin=2.0):\n",
- "        super(ContrastiveLoss, self).__init__()\n",
- "        self.margin = margin\n",
- "\n",
- "    ## label = 0: pair from the same class, label = 1: pair from different classes\n",
- "    def forward(self, output1, output2, label):\n",
- "        euclidean_distance = F.pairwise_distance(output1, output2, keepdim = True)\n",
- "        loss_contrastive = torch.mean((1-label) * torch.pow(euclidean_distance, 2) +\n",
- "                                      (label) * torch.pow(torch.clamp(self.margin - euclidean_distance, min=0.0), 2))\n",
- "        return loss_contrastive"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "## Configuration\n",
- "\n",
- "Selection of:\n",
- "* Dataset\n",
- "* Network\n",
- "* Train share\n",
- "* Batch Size Train\n",
- "* Batch Size Test\n",
- "* Learning Rate\n",
- "* Momentum\n",
- "\n",
- "Note: run the configuration cell, make your selection, then continue with the next section"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 113,
- "metadata": {},
- "outputs": [
- {
- "data": {
- "application/vnd.jupyter.widget-view+json": {
- "model_id": "97cf4f37c7ed4036a2879dbb139818c5",
- "version_major": 2,
- "version_minor": 0
- },
- "text/plain": [
- "Dropdown(description='Dataset:', index=1, options=('CP', 'AMI', 'AWE', 'EarVN', 'UERC'), value='AMI')"
- ]
- },
- "metadata": {},
- "output_type": "display_data"
- },
- {
- "data": {
- "application/vnd.jupyter.widget-view+json": {
- "model_id": "b5002f1761e64a5bbbaec4ecd034c112",
- "version_major": 2,
- "version_minor": 0
- },
- "text/plain": [
- "Dropdown(description='Network:', options=('vgg11', 'vgg11bn', 'resnet18', 'resnet34', 'alexnet', 'squeezenet1_…"
- ]
- },
- "metadata": {},
- "output_type": "display_data"
- },
- {
- "data": {
- "application/vnd.jupyter.widget-view+json": {
- "model_id": "0eb51fd4b5f744e98edad81c104e7f43",
- "version_major": 2,
- "version_minor": 0
- },
- "text/plain": [
- "BoundedFloatText(value=0.8, description='Trainshare:', max=1.0, step=0.1)"
- ]
- },
- "metadata": {},
- "output_type": "display_data"
- },
- {
- "data": {
- "application/vnd.jupyter.widget-view+json": {
- "model_id": "48e119312d5e429ea553518c6d3df1aa",
- "version_major": 2,
- "version_minor": 0
- },
- "text/plain": [
- "BoundedIntText(value=4, description='Batch_Train:', min=1)"
- ]
- },
- "metadata": {},
- "output_type": "display_data"
- },
- {
- "data": {
- "application/vnd.jupyter.widget-view+json": {
- "model_id": "21770422b0304a63907e69a750f6bea2",
- "version_major": 2,
- "version_minor": 0
- },
- "text/plain": [
- "BoundedIntText(value=4, description='Batch_Test:', min=1)"
- ]
- },
- "metadata": {},
- "output_type": "display_data"
- },
- {
- "data": {
- "application/vnd.jupyter.widget-view+json": {
- "model_id": "14af8e2f1cc0414d8ce34c0bd93ddd7c",
- "version_major": 2,
- "version_minor": 0
- },
- "text/plain": [
- "BoundedFloatText(value=0.001, description='Learning_Rate', max=1.0, step=0.0001)"
- ]
- },
- "metadata": {},
- "output_type": "display_data"
- },
- {
- "data": {
- "application/vnd.jupyter.widget-view+json": {
- "model_id": "d8b9466f0d994762bd04d858d213ce0a",
- "version_major": 2,
- "version_minor": 0
- },
- "text/plain": [
- "BoundedFloatText(value=0.8, description='Momentum: ', max=1.0, step=0.1)"
- ]
- },
- "metadata": {},
- "output_type": "display_data"
- }
- ],
- "source": [
- "datadir_choose = wg.Dropdown(\n",
- " options=['CP', 'AMI', 'AWE', 'EarVN', 'UERC'],\n",
- " value='AMI',\n",
- " description='Dataset:',\n",
- " disabled=False,\n",
- " button_style=''\n",
- ")\n",
- "\n",
- "network_string_choose = wg.Dropdown(\n",
- " options=['vgg11','vgg11bn','resnet18','resnet34','alexnet','squeezenet1_0','GoogLeNet','shufflenet_v2_x0_5','resnext101_32x8d', 'siamese'],\n",
- " value='vgg11',\n",
- " description='Network:',\n",
- " disabled=False,\n",
- " button_style=''\n",
- ")\n",
- "\n",
- "trainshare_choose = wg.BoundedFloatText(\n",
- " value=0.8,\n",
- " min=0,\n",
- " max=1,\n",
- " step=0.1,\n",
- " description='Trainshare:',\n",
- " disabled=False\n",
- ")\n",
- "\n",
- "batch_train = wg.BoundedIntText(\n",
- " value=4,\n",
- " min=1,\n",
- " max=100,\n",
- " step=1,\n",
- " description='Batch_Train:',\n",
- " disabled=False\n",
- ")\n",
- "\n",
- "batch_test = wg.BoundedIntText(\n",
- " value=4,\n",
- " min=1,\n",
- " max=100,\n",
- " step=1,\n",
- " description='Batch_Test:',\n",
- " disabled=False\n",
- ")\n",
- "\n",
- "learning_rate_choose = wg.BoundedFloatText(\n",
- " value=0.001,\n",
- " min=0,\n",
- " max=1,\n",
- " step=0.0001,\n",
- " description='Learning_Rate',\n",
- " disabled=False\n",
- ")\n",
- "\n",
- "momentum_choose = wg.BoundedFloatText(\n",
- " value=0.8,\n",
- " min=0,\n",
- " max=1,\n",
- " step=0.1,\n",
- " description='Momentum: ',\n",
- " disabled=False\n",
- ")\n",
- "\n",
- "\n",
- "display(datadir_choose, network_string_choose, trainshare_choose, batch_train, batch_test, learning_rate_choose, momentum_choose)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 114,
- "metadata": {},
- "outputs": [],
- "source": [
- "## Define Variables, Arrays for Dataset and Categories ##\n",
- "dataset_train = []\n",
- "dataset_test = []\n",
- "CATEGORIES = []\n",
- "#img_array = []\n",
- "network_name = 'nix'\n",
- "\n",
- "datadir = datadir_choose.value\n",
- "network_string = network_string_choose.value\n",
- "train_share = trainshare_choose.value\n",
- "batch_size_train = batch_train.value\n",
- "batch_size_test = batch_test.value\n",
- "learning_rate = learning_rate_choose.value\n",
- "momentum = momentum_choose.value"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "## Paths for one-shot learning\n",
- "\n",
- "Paths to the AMI images used for training and testing the Siamese/PyTorch networks"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 115,
- "metadata": {},
- "outputs": [],
- "source": [
- "class Config():\n",
- "    ## Define paths for train and test images\n",
- "    path_OneShot_Train = \"./Datensaetze/AMI_OneShot/OneShot_Train/\"\n",
- "    path_OneShot_Test = \"./Datensaetze/AMI_OneShot/OneShot_Testing/\""
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "## Loading networks and datasets"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 116,
- "metadata": {},
- "outputs": [],
- "source": [
- "## Get Network and Network-Name\n",
- "network, network_name = Network_Choice(network_string)\n",
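- "## Define contrastive loss as loss function\n",
- "criterion = ContrastiveLoss()\n",
- "## Define Adam as optimizer (the learning rate is fixed here; the learning_rate and momentum values from the configuration are not used)\n",
- "optimizer = optim.Adam(network.parameters(),lr = 0.0005 )\n",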
- "\n",
- "## Choose Datadirectory to the Dataset ##\n",
- "DATADIR, Database = ChooseDir(datadir)\n",
- "## Read and list all Categories of a Dataset \n",
- "CATEGORIES = os.listdir(DATADIR)"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "## Creating or deleting the dataset folders for SiameseDataset_Ears\n",
- "\n",
- "* **Create_OneShot_File_Images():**\n",
- "If there are no images in *./Datensaetze/AMI_OneShot/OneShot_Train/* or *./Datensaetze/AMI_OneShot/OneShot_Testing/*, they can be created with *Create_OneShot_File_Images()*.\n",
- "\n",
- "* **Delete_OneShot_File_Images():**\n",
- "If a dataset other than AMI is wanted, the AMI one-shot images can be deleted with *Delete_OneShot_File_Images()*."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 117,
- "metadata": {},
- "outputs": [],
- "source": [
- "## Create dataset images for one-shot learning\n",
- "def Create_OneShot_File_Images():\n",
- "    for category in CATEGORIES:\n",
- "        path = os.path.join(DATADIR, category)\n",
- "        ## Class folders are numbered starting at 1\n",
- "        class_num = CATEGORIES.index(category)+1\n",
- "        count_train_share = (len(os.listdir(path)))*train_share\n",
- "        counter = 1\n",
- "        os_listdir = os.listdir(path)\n",
- "        try:\n",
- "            ## Create folders in train and test directories\n",
- "            os.mkdir(Config.path_OneShot_Train+str(class_num))\n",
- "            os.mkdir(Config.path_OneShot_Test+str(class_num))\n",
- "        except OSError:\n",
- "            print (\"Creation of the directory %s failed\" % Config.path_OneShot_Train)\n",
- "        else:\n",
- "            print (\"Successfully created the directory %s \" % Config.path_OneShot_Test)\n",
- "\n",
- "        for img in os_listdir:\n",
- "            try:\n",
- "                ## Copy the first train_share fraction of the images to train, the rest to test\n",
- "                if(counter <= count_train_share):\n",
- "                    shutil.copy(path+'/'+img, Config.path_OneShot_Train+str(class_num))\n",
- "                    counter += 1\n",
- "                else:\n",
- "                    shutil.copy(path+'/'+img, Config.path_OneShot_Test+str(class_num))\n",
- "            except Exception as e:\n",
- "                pass\n",
- "\n",
- "## Delete dataset images for one-shot learning\n",
- "def Delete_OneShot_File_Images():\n",
- "    for category in CATEGORIES:\n",
- "        ## Use the same 1-based folder numbering as Create_OneShot_File_Images()\n",
- "        class_num = CATEGORIES.index(category)+1\n",
- "        ## Delete all directories and images of one-shot learning\n",
- "        try:\n",
- "            shutil.rmtree(Config.path_OneShot_Train+str(class_num))\n",
- "            shutil.rmtree(Config.path_OneShot_Test+str(class_num))\n",
- "        except OSError:\n",
- "            print (\"Deletion of the directory %s failed\" % Config.path_OneShot_Train)\n",
- "        else:\n",
- "            print (\"Successfully deleted the directory %s\" % Config.path_OneShot_Train)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 118,
- "metadata": {},
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "Successfully created the directory ./Datensaetze/AMI_OneShot/OneShot_Testing/ \n",
- "Successfully created the directory ./Datensaetze/AMI_OneShot/OneShot_Testing/ \n",
- "Successfully created the directory ./Datensaetze/AMI_OneShot/OneShot_Testing/ \n",
- "Successfully created the directory ./Datensaetze/AMI_OneShot/OneShot_Testing/ \n",
- "Successfully created the directory ./Datensaetze/AMI_OneShot/OneShot_Testing/ \n",
- "Successfully created the directory ./Datensaetze/AMI_OneShot/OneShot_Testing/ \n",
- "Successfully created the directory ./Datensaetze/AMI_OneShot/OneShot_Testing/ \n",
- "Successfully created the directory ./Datensaetze/AMI_OneShot/OneShot_Testing/ \n",
- "Successfully created the directory ./Datensaetze/AMI_OneShot/OneShot_Testing/ \n",
- "Successfully created the directory ./Datensaetze/AMI_OneShot/OneShot_Testing/ \n",
- "Successfully created the directory ./Datensaetze/AMI_OneShot/OneShot_Testing/ \n",
- "Successfully created the directory ./Datensaetze/AMI_OneShot/OneShot_Testing/ \n",
- "Successfully created the directory ./Datensaetze/AMI_OneShot/OneShot_Testing/ \n",
- "Successfully created the directory ./Datensaetze/AMI_OneShot/OneShot_Testing/ \n",
- "Successfully created the directory ./Datensaetze/AMI_OneShot/OneShot_Testing/ \n",
- "Successfully created the directory ./Datensaetze/AMI_OneShot/OneShot_Testing/ \n",
- "Successfully created the directory ./Datensaetze/AMI_OneShot/OneShot_Testing/ \n",
- "Successfully created the directory ./Datensaetze/AMI_OneShot/OneShot_Testing/ \n",
- "Successfully created the directory ./Datensaetze/AMI_OneShot/OneShot_Testing/ \n",
- "Successfully created the directory ./Datensaetze/AMI_OneShot/OneShot_Testing/ \n"
- ]
- }
- ],
- "source": [
- "#Create_OneShot_File_Images()\n",
- "#Delete_OneShot_File_Images()"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "## Class for creating the image pairs\n",
- "\n",
- "* **SiameseDataset_Ears()**: Creates roughly equal numbers of matching and non-matching image pairs (each pair type is drawn with probability 0.5) and transforms the images"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 119,
- "metadata": {},
- "outputs": [],
- "source": [
- "## Create dataset of ear images for one-shot learning\n",
- "class SiameseDataset_Ears(Dataset):\n",
- "\n",
- "    def __init__(self,imageFolderDataset,transform=None,should_invert=True):\n",
- "        self.imageFolderDataset = imageFolderDataset\n",
- "        self.transform = transform\n",
- "        self.should_invert = should_invert\n",
- "\n",
- "    def __getitem__(self,index):\n",
- "        ## Get a random picture from the image folder (the index is ignored)\n",
- "        img0_tuple = random.choice(self.imageFolderDataset.imgs)\n",
- "\n",
- "        #we need to make sure approx 50% of images are in the same class\n",
- "        should_get_same_class = random.randint(0,1)\n",
- "        if should_get_same_class:\n",
- "            while True:\n",
- "                #keep looping till the same class image is found\n",
- "                img1_tuple = random.choice(self.imageFolderDataset.imgs)\n",
- "                if img0_tuple[1]==img1_tuple[1]:\n",
- "                    break\n",
- "        else:\n",
- "            while True:\n",
- "                #keep looping till a different class image is found\n",
- "                img1_tuple = random.choice(self.imageFolderDataset.imgs)\n",
- "                if img0_tuple[1] !=img1_tuple[1]:\n",
- "                    break\n",
- "\n",
- "        ## Load images and convert to RGB if a PyTorch network is chosen\n",
- "        ## Load images and convert to gray if the Siamese network is chosen\n",
- "        ## (membership test against the names returned by Network_Choice; 'a' == ('x' or 'y') would only compare with 'x')\n",
- "        if(network_name in ('VGG11', 'VGG11bn', 'ResNet18', 'ResNet34', 'AlexNet', 'SqueezeNet-1-0', 'GoogLeNet', 'Shufflenet-v2-x0_5', 'Resnext101-32x8d')):\n",
- "            img0 = Image.open(img0_tuple[0]).convert('RGB')\n",
- "            img1 = Image.open(img1_tuple[0]).convert('RGB')\n",
- "        elif(network_name == ('Siamese_Network')):\n",
- "            img0 = Image.open(img0_tuple[0]).convert(\"L\")\n",
- "            img1 = Image.open(img1_tuple[0]).convert(\"L\")\n",
- "        else:\n",
- "            ## Without a valid network no image can be loaded, so fail early\n",
- "            raise ValueError('Unknown network chosen: ' + str(network_name))\n",
- "\n",
- "        ## Invert loaded PIL images\n",
- "        if self.should_invert:\n",
- "            img0 = PIL.ImageOps.invert(img0)\n",
- "            img1 = PIL.ImageOps.invert(img1)\n",
- "\n",
- "        ## Transform\n",
- "        if self.transform is not None:\n",
- "            img0 = self.transform(img0)\n",
- "            img1 = self.transform(img1)\n",
- "\n",
- "        ## Label: 0 = same class, 1 = different class\n",
- "        return img0, img1 , torch.from_numpy(np.array([int(img1_tuple[1]!=img0_tuple[1])],dtype=np.float32))\n",
- "\n",
- "    def __len__(self):\n",
- "        return len(self.imageFolderDataset.imgs)"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "## Defining the training and test folders"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 120,
- "metadata": {},
- "outputs": [],
- "source": [
- "## Folder for Train-Images\n",
- "folder_dataset_train = dset.ImageFolder(root=Config.path_OneShot_Train)\n",
- "\n",
- "## Folder for Test-Images\n",
- "folder_dataset_test = dset.ImageFolder(root=Config.path_OneShot_Test)"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "## Defining the transformer"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 121,
- "metadata": {},
- "outputs": [],
- "source": [
- "transformer = transforms.ToTensor()\n",
- "\n",
- "## Define transformer for PyTorch networks or the Siamese network\n",
- "## (membership test against the names returned by Network_Choice)\n",
- "if(network_name in ('VGG11', 'VGG11bn', 'ResNet18', 'ResNet34', 'AlexNet', 'SqueezeNet-1-0', 'GoogLeNet', 'Shufflenet-v2-x0_5', 'Resnext101-32x8d')):\n",
- "    ## Transformer for PyTorch networks\n",
- "    transformer = transforms.Compose([\n",
- "        transforms.Resize(256),\n",
- "        transforms.CenterCrop(224),\n",
- "        transforms.ToTensor(),\n",
- "        transforms.Normalize((0.4318, 0.4660, 0.5889), (0.1752, 0.1893, 0.2096)),\n",
- "    ])\n",
- "\n",
- "elif(network_name == ('Siamese_Network')):\n",
- "    ## Transformer for the Siamese network\n",
- "    transformer=transforms.Compose([transforms.Resize((100,100)),\n",
- "                                    transforms.ToTensor()\n",
- "                                    ])\n",
- "else:\n",
- "    print('Unknown network chosen')"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "## Loading the training and test datasets with transformation"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 122,
- "metadata": {},
- "outputs": [],
- "source": [
- "## Create Training-Dataset\n",
- "siamese_dataset_train = SiameseDataset_Ears(imageFolderDataset=folder_dataset_train,\n",
- " transform=transformer,\n",
- " should_invert=False)\n",
- "\n",
- "## Create Test-Dataset\n",
- "siamese_dataset_test = SiameseDataset_Ears(imageFolderDataset=folder_dataset_test,\n",
- " transform=transformer,\n",
- " should_invert=True)\n",
- "\n",
- "\n",
- "\n",
- "## Create Training-DataLoader\n",
- "train_dataloader = DataLoader(siamese_dataset_train, batch_size=batch_size_train, shuffle=True,)\n",
- "\n",
- "## Create Test-DataLoader\n",
- "test_dataloader = DataLoader(siamese_dataset_test, batch_size=batch_size_test, shuffle=True)"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "## Visualizing the image pairs of the training loader\n",
- "\n",
- "Image pairs: *subject 1 on top, subject 2 below*  \n",
- "0: matching pair  \n",
- "1: non-matching pair"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 123,
- "metadata": {},
- "outputs": [
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- "Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).\n"
- ]
- },
- {
- "data": {
- "image/png": "\n",
- "text/plain": [
- "<Figure size 432x288 with 1 Axes>"
- ]
- },
- "metadata": {
- "needs_background": "light"
- },
- "output_type": "display_data"
- },
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "[[0.]\n",
- " [1.]\n",
- " [0.]\n",
- " [1.]\n",
- " [1.]\n",
- " [0.]\n",
- " [1.]\n",
- " [1.]]\n"
- ]
- }
- ],
- "source": [
- "vis_dataloader = DataLoader(siamese_dataset_train, shuffle=True, batch_size=8)\n",
- "\n",
- "dataiter = iter(vis_dataloader)\n",
- "\n",
- "## Use the built-in next(); iterator.next() is not available in newer PyTorch versions\n",
- "example_batch = next(dataiter)\n",
- "concatenated = torch.cat((example_batch[0],example_batch[1]),0)\n",
- "imshow(torchvision.utils.make_grid(concatenated))\n",
- "print(example_batch[2].numpy())"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "## Training:"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 124,
- "metadata": {
- "scrolled": true
- },
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "Start Training um: 21:20:32\n",
- "Epoch number 0\n",
- " Current loss 308.1065979003906\n",
- "\n",
- "Epoch number 0\n",
- " Current loss 1.138925552368164\n",
- "\n",
- "Epoch number 0\n",
- " Current loss 0.964200496673584\n",
- "\n",
- "Epoch number 1\n",
- " Current loss 1.052968978881836\n",
- "\n",
- "Epoch number 1\n",
- " Current loss 0.9594541788101196\n",
- "\n",
- "Epoch number 1\n",
- " Current loss 0.4107712507247925\n",
- "\n",
- "Epoch number 2\n",
- " Current loss 1.5126802921295166\n",
- "\n",
- "Epoch number 2\n",
- " Current loss 1.127558946609497\n",
- "\n",
- "Epoch number 2\n",
- " Current loss 3.7104902267456055\n",
- "\n",
- "Epoch number 3\n",
- " Current loss 0.7310527563095093\n",
- "\n",
- "Epoch number 3\n",
- " Current loss 0.7123371362686157\n",
- "\n",
- "Epoch number 3\n",
- " Current loss 1.3404734134674072\n",
- "\n",
- "Epoch number 4\n",
- " Current loss 1.0322986841201782\n",
- "\n",
- "Epoch number 4\n",
- " Current loss 1.4310665130615234\n",
- "\n",
- "Epoch number 4\n",
- " Current loss 0.8324761390686035\n",
- "\n",
- "Epoch number 5\n",
- " Current loss 0.8502787351608276\n",
- "\n",
- "Epoch number 5\n",
- " Current loss 1.3122851848602295\n",
- "\n",
- "Epoch number 5\n",
- " Current loss 0.921744167804718\n",
- "\n",
- "Epoch number 6\n",
- " Current loss 1.1442584991455078\n",
- "\n",
- "Epoch number 6\n",
- " Current loss 1.011681318283081\n",
- "\n",
- "Epoch number 6\n",
- " Current loss 0.9135341048240662\n",
- "\n",
- "Epoch number 7\n",
- " Current loss 1.2509589195251465\n",
- "\n",
- "Epoch number 7\n",
- " Current loss 0.7557617425918579\n",
- "\n",
- "Epoch number 7\n",
- " Current loss 0.8743131160736084\n",
- "\n",
- "Epoch number 8\n",
- " Current loss 0.7078773975372314\n",
- "\n",
- "Epoch number 8\n",
- " Current loss 2.043950080871582\n",
- "\n",
- "Epoch number 8\n",
- " Current loss 1.0420727729797363\n",
- "\n",
- "Epoch number 9\n",
- " Current loss 0.9620883464813232\n",
- "\n",
- "Epoch number 9\n",
- " Current loss 0.9625381827354431\n",
- "\n",
- "Epoch number 9\n",
- " Current loss 1.1636301279067993\n",
- "\n",
- "Epoch number 10\n",
- " Current loss 1.5617934465408325\n",
- "\n",
- "Epoch number 10\n",
- " Current loss 1.029129147529602\n",
- "\n",
- "Epoch number 10\n",
- " Current loss 1.2446420192718506\n",
- "\n",
- "Epoch number 11\n",
- " Current loss 1.0907390117645264\n",
- "\n",
- "Epoch number 11\n",
- " Current loss 0.9430091977119446\n",
- "\n",
- "Epoch number 11\n",
- " Current loss 1.0374553203582764\n",
- "\n",
- "Epoch number 12\n",
- " Current loss 1.2946057319641113\n",
- "\n",
- "Epoch number 12\n",
- " Current loss 0.8857178092002869\n",
- "\n",
- "Epoch number 12\n",
- " Current loss 0.9587574005126953\n",
- "\n",
- "Epoch number 13\n",
- " Current loss 0.446734756231308\n",
- "\n",
- "Epoch number 13\n",
- " Current loss 0.9940989017486572\n",
- "\n",
- "Epoch number 13\n",
- " Current loss 0.9222986698150635\n",
- "\n",
- "Epoch number 14\n",
- " Current loss 0.8071630597114563\n",
- "\n",
- "Epoch number 14\n",
- " Current loss 0.9962530136108398\n",
- "\n",
- "Epoch number 14\n",
- " Current loss 0.9935855269432068\n",
- "\n",
- "Epoch number 15\n",
- " Current loss 1.2180464267730713\n",
- "\n",
- "Epoch number 15\n",
- " Current loss 1.0259199142456055\n",
- "\n",
- "Epoch number 15\n",
- " Current loss 0.8863251209259033\n",
- "\n",
- "Epoch number 16\n",
- " Current loss 1.0609502792358398\n",
- "\n",
- "Epoch number 16\n",
- " Current loss 0.9913638234138489\n",
- "\n",
- "Epoch number 16\n",
- " Current loss 1.1235753297805786\n",
- "\n",
- "Epoch number 17\n",
- " Current loss 1.1986947059631348\n",
- "\n",
- "Epoch number 17\n",
- " Current loss 0.7384413480758667\n",
- "\n",
- "Epoch number 17\n",
- " Current loss 0.8233901262283325\n",
- "\n",
- "Epoch number 18\n",
- " Current loss 1.227494716644287\n",
- "\n",
- "Epoch number 18\n",
- " Current loss 0.9361811876296997\n",
- "\n",
- "Epoch number 18\n",
- " Current loss 0.19040855765342712\n",
- "\n",
- "Epoch number 19\n",
- " Current loss 0.5063803195953369\n",
- "\n",
- "Epoch number 19\n",
- " Current loss 1.4189283847808838\n",
- "\n",
- "Epoch number 19\n",
- " Current loss 0.782526969909668\n",
- "\n",
- "Ende Training um: 21:21:26\n",
- "Dauer Training: 00:53 [MM:SS] \n",
- "\n"
- ]
- },
- {
- "data": {
- "image/png": "\n",
- "text/plain": [
- "<Figure size 432x288 with 1 Axes>"
- ]
- },
- "metadata": {
- "needs_background": "light"
- },
- "output_type": "display_data"
- }
- ],
- "source": [
- "epoch_range = 20\n",
- "counter = []\n",
- "loss_history = [] \n",
- "iteration_number= 0\n",
- "\n",
- "\n",
- "## Start Training\n",
- "print(\"Start Training um: \", time.strftime(\"%H:%M:%S\"))\n",
- "start_time = time.time()\n",
- "\n",
- "## Architectures that embed one image per forward pass (shared weights, two passes)\n",
- "single_input_networks = ('VGG11', 'VGG11bn', 'ResNet18', 'ResNet34', 'alexnet', 'Squeezenet1_0',\n",
- "                         'GoogLeNet', 'Shufflenet_v2_x0_5', 'Resnext101_32x8d')\n",
- "\n",
- "for epoch in range(epoch_range):\n",
- "    for i, data in enumerate(train_dataloader, 0):\n",
- "        img0, img1, label = data\n",
- "        img0, img1, label = img0.to(DEVICE), img1.to(DEVICE), label.to(DEVICE)\n",
- "\n",
- "        optimizer.zero_grad()\n",
- "\n",
- "        ## Calculate outputs and contrastive loss\n",
- "        ## Note: `name == ('A' or 'B')` only compares against 'A', hence the membership test\n",
- "        if network_name in single_input_networks:\n",
- "            output1 = network(img0)\n",
- "            output2 = network(img1)\n",
- "        elif network_name == 'Siamese_Network':\n",
- "            output1, output2 = network(img0, img1)\n",
- "        else:\n",
- "            ## Stop here instead of continuing with undefined or stale outputs\n",
- "            raise ValueError('Unknown network chosen: {}'.format(network_name))\n",
- "\n",
- "        loss_contrastive = criterion(output1, output2, label)\n",
- "\n",
- "        ## Backpropagation and optimizer step\n",
- "        loss_contrastive.backward()\n",
- "        optimizer.step()\n",
- "\n",
- "        ## Record the loss every 10 iterations\n",
- "        if i % 10 == 0:\n",
- "            print(\"Epoch number {}\\n Current loss {}\\n\".format(epoch, loss_contrastive.item()))\n",
- "            iteration_number += 10\n",
- "            counter.append(iteration_number)\n",
- "            loss_history.append(loss_contrastive.item())\n",
- "\n",
- " \n",
- "## Finished Training\n",
- "print(\"Training finished at: \", time.strftime(\"%H:%M:%S\"))\n",
- "stop_time = time.time()\n",
- "time_dif, time_format = secs_to_HMS(stop_time - start_time)\n",
- "print(\"Training duration: \", time_dif, \" \", time_format, \" \\n\")\n",
- "\n",
- "show_plot(counter,loss_history)\n",
- "WaitTime_Finished()"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 63,
- "metadata": {},
- "outputs": [
- {
- "data": {
- "image/png": "\n",
- "text/plain": [
- "<Figure size 432x288 with 1 Axes>"
- ]
- },
- "metadata": {
- "needs_background": "light"
- },
- "output_type": "display_data"
- }
- ],
- "source": [
- "show_plot(counter,loss_history)"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "## Saving the trained networks"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 125,
- "metadata": {},
- "outputs": [],
- "source": [
- "## Path to save Network \n",
- "PATH = './Netzwerke_OneShot/' + current_time + '_' + Database + '_' + network_name + '_Train' + '.pth'\n",
- "\n",
- "## Save Network\n",
- "torch.save({\n",
- "    'epoch_range': epoch_range,\n",
- "    'model_state_dict': network.state_dict(),\n",
- "    'optimizer_state_dict': optimizer.state_dict(),\n",
- "    'loss': loss_contrastive\n",
- "}, PATH)"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "## Testing"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 126,
- "metadata": {},
- "outputs": [],
- "source": [
- "folder_dataset_test = dset.ImageFolder(root=Config.path_OneShot_Test)\n",
- "\n",
- "siamese_dataset = SiameseDataset_Ears(imageFolderDataset=folder_dataset_test,\n",
- "                                      transform=transformer,\n",
- "                                      should_invert=False)\n",
- "\n",
- "test_dataloader = DataLoader(siamese_dataset, batch_size=1, shuffle=True)\n",
- "\n",
- "## Test function for the Siamese network (expects a test loader with batch_size=1)\n",
- "def evaluation(model, test_loader):\n",
- "    ## Architectures that embed one image per forward pass\n",
- "    single_input_networks = ('VGG11', 'VGG11bn', 'ResNet18', 'ResNet34', 'alexnet', 'Squeezenet1_0',\n",
- "                             'GoogLeNet', 'Shufflenet_v2_x0_5', 'Resnext101_32x8d')\n",
- "    ## Distance threshold: below counts as a label-0 pair, above as a label-1 pair\n",
- "    predVal = 2.1\n",
- "    correct = 0\n",
- "    count = 0\n",
- "\n",
- "    model.eval()\n",
- "    with torch.no_grad():\n",
- "        for mainImg, imgSets, label in test_loader:\n",
- "            mainImg, imgSets, label = mainImg.to(DEVICE), imgSets.to(DEVICE), label.to(DEVICE)\n",
- "            count += 1\n",
- "\n",
- "            ## Embed both images of the pair with the model passed in (not the global `network`)\n",
- "            if network_name in single_input_networks:\n",
- "                output1 = model(mainImg)\n",
- "                output2 = model(imgSets)\n",
- "            elif network_name == 'Siamese_Network':\n",
- "                output1, output2 = model(mainImg, imgSets)\n",
- "            else:\n",
- "                raise ValueError('Unknown network chosen: {}'.format(network_name))\n",
- "\n",
- "            ## Single pair per batch, so both values reduce to plain Python numbers\n",
- "            euclidean_distance = F.pairwise_distance(output1, output2).item()\n",
- "            label_value = label.item()\n",
- "\n",
- "            if (euclidean_distance < predVal and label_value == 0) or (euclidean_distance > predVal and label_value == 1):\n",
- "                correct += 1\n",
- "\n",
- "            ## Running accuracy over the pairs evaluated so far\n",
- "            print('Accuracy: {}'.format(correct / count))"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 127,
- "metadata": {},
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "Accuracy: 1.0\n",
- "Accuracy: 0.5\n",
- "Accuracy: 0.3333333333333333\n",
- "Accuracy: 0.25\n",
- "Accuracy: 0.4\n",
- "Accuracy: 0.3333333333333333\n",
- "Accuracy: 0.2857142857142857\n",
- "Accuracy: 0.375\n",
- "Accuracy: 0.3333333333333333\n",
- "Accuracy: 0.3\n",
- "Accuracy: 0.2727272727272727\n",
- "Accuracy: 0.25\n",
- "Accuracy: 0.3076923076923077\n",
- "Accuracy: 0.35714285714285715\n",
- "Accuracy: 0.4\n",
- "Accuracy: 0.375\n",
- "Accuracy: 0.35294117647058826\n",
- "Accuracy: 0.3888888888888889\n",
- "Accuracy: 0.42105263157894735\n",
- "Accuracy: 0.45\n",
- "Accuracy: 0.42857142857142855\n",
- "Accuracy: 0.4090909090909091\n",
- "Accuracy: 0.391304347826087\n",
- "Accuracy: 0.375\n",
- "Accuracy: 0.4\n",
- "Accuracy: 0.38461538461538464\n",
- "Accuracy: 0.37037037037037035\n",
- "Accuracy: 0.39285714285714285\n",
- "Accuracy: 0.3793103448275862\n",
- "Accuracy: 0.36666666666666664\n",
- "Accuracy: 0.3870967741935484\n",
- "Accuracy: 0.375\n",
- "Accuracy: 0.3939393939393939\n",
- "Accuracy: 0.4117647058823529\n",
- "Accuracy: 0.42857142857142855\n",
- "Accuracy: 0.4166666666666667\n",
- "Accuracy: 0.40540540540540543\n",
- "Accuracy: 0.42105263157894735\n",
- "Accuracy: 0.4358974358974359\n",
- "Accuracy: 0.45\n"
- ]
- }
- ],
- "source": [
- "evaluation(network, test_dataloader)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": []
- }
- ],
- "metadata": {
- "kernelspec": {
- "display_name": "Python 3",
- "language": "python",
- "name": "python3"
- },
- "language_info": {
- "codemirror_mode": {
- "name": "ipython",
- "version": 3
- },
- "file_extension": ".py",
- "mimetype": "text/x-python",
- "name": "python",
- "nbconvert_exporter": "python",
- "pygments_lexer": "ipython3",
- "version": "3.8.3"
- }
- },
- "nbformat": 4,
- "nbformat_minor": 4
- }
|