Masterarbeit zur Untersuchung des Ohrs zur Personenauthentifizierung an IT-Systemen mittels CNNs
You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

OneShot_Ears.ipynb 94KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
712781279128012811282128312841285128612871288128912901291129212931294129512961297129812991300130113021303130413051306130713081309131013111312131313141315131613171318131913201321132213231324132513261327132813291330133113321333133413351336133713381339134013411342134313441345134613471348134913501351135213531354135513561357135813591360
  1. {
  2. "cells": [
  3. {
  4. "cell_type": "markdown",
  5. "metadata": {},
  6. "source": [
  7. "# One Shot Learning\n",
  8. "\n",
  9. "Mit diesem Jupyter-Skript werden alle notwendigen Funktionen zur Umsetzung und Auswertung des OneShot-Learnings (Siamese Networks) von Ohrbildern zusammengefasst. \n",
  10. "Das Jupyter-Skript bezieht sich auf die Masterarbeit: \"Verwendung des menschlichen Ohrs zur Personenauthentifizierung an IT-Systemen mittels CNNs\"."
  11. ]
  12. },
  13. {
  14. "cell_type": "markdown",
  15. "metadata": {},
  16. "source": [
  17. "## Bibliotheken importieren\n",
  18. "\n",
  19. "Zunächst werden alle notwendigen Bibliotheken importiert."
  20. ]
  21. },
  22. {
  23. "cell_type": "code",
  24. "execution_count": 105,
  25. "metadata": {},
  26. "outputs": [],
  27. "source": [
  28. "## Import Libearies ##\n",
  29. "%matplotlib inline\n",
  30. "import cv2\n",
  31. "from PIL import Image\n",
  32. "import PIL.ImageOps \n",
  33. "import numpy as np\n",
  34. "import matplotlib.pyplot as plt\n",
  35. "import matplotlib.image as mpimg\n",
  36. "import pandas as pd\n",
  37. "import os\n",
  38. "\n",
  39. "## Import TORCHVISION with popular datasets, model architectures, \n",
  40. "## and common image transformations for computer vision\n",
  41. "from torchvision import transforms\n",
  42. "import torchvision\n",
  43. "import torchvision.utils\n",
  44. "import torchvision.models as models\n",
  45. "import torchvision.datasets as dset\n",
  46. "import torchsummary\n",
  47. "\n",
  48. "# Import Debugging method for set_trace\n",
  49. "from IPython.core.debugger import set_trace\n",
  50. "import logging\n",
  51. "\n",
  52. "## Import Time Features\n",
  53. "import datetime\n",
  54. "import time\n",
  55. "\n",
  56. "## Import TORCH for Data-Structures for multi-dimensional tensors \n",
  57. "## and mathematical operations\n",
  58. "import torch\n",
  59. "from torch.utils.tensorboard import SummaryWriter\n",
  60. "import torch.nn as nn\n",
  61. "import torch.nn.functional as F\n",
  62. "from torch.utils.data import DataLoader, Dataset\n",
  63. "import torch.optim as optim\n",
  64. "import torch_utils\n",
  65. "from torch.autograd import Variable\n",
  66. "\n",
  67. "# Import Winsound to play a *.wav-File\n",
  68. "# and Counter for counting numbers in an array\n",
  69. "import winsound\n",
  70. "from collections import Counter\n",
  71. "\n",
  72. "## Import Shutil to copy or delete Files\n",
  73. "import shutil\n",
  74. "from shutil import copyfile\n",
  75. "\n",
  76. "## Import Random Libery\n",
  77. "import random\n",
  78. "from random import shuffle\n",
  79. "\n",
  80. "## Import IPYWIDGET for interactive HTML widgets for Jupyter notebooks\n",
  81. "import ipywidgets as wg\n",
  82. "from IPython.display import display"
  83. ]
  84. },
  85. {
  86. "cell_type": "markdown",
  87. "metadata": {},
  88. "source": [
  89. "## Festlegung: Prozessor oder Graphikkarte"
  90. ]
  91. },
  92. {
  93. "cell_type": "code",
  94. "execution_count": 106,
  95. "metadata": {},
  96. "outputs": [
  97. {
  98. "name": "stdout",
  99. "output_type": "stream",
  100. "text": [
  101. "cuda\n"
  102. ]
  103. }
  104. ],
  105. "source": [
  106. "DEVICE = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
  107. "print(DEVICE)"
  108. ]
  109. },
  110. {
  111. "cell_type": "markdown",
  112. "metadata": {},
  113. "source": [
  114. "## Funktion zur Gegenüberstellung zweier Ohren\n",
  115. "\n",
  116. "Zwei Ohren werden gegenübergestellt und dazu der Contrastive-Loss-Wert angezeigt."
  117. ]
  118. },
  119. {
  120. "cell_type": "code",
  121. "execution_count": 107,
  122. "metadata": {},
  123. "outputs": [],
  124. "source": [
  125. "# build a rectangle in axes coords\n",
  126. "left, width = 5, .5\n",
  127. "bottom, height = .25, .5\n",
  128. "right = left + width\n",
  129. "top = bottom + height\n",
  130. "\n",
  131. "## Function to show Ear images and the Constrative Loss\n",
  132. "def imshow(img,text=None,should_save=False):\n",
  133. " npimg = img.cpu().numpy()\n",
  134. " plt.axis(\"off\")\n",
  135. " if text:\n",
  136. " plt.text(75, 8, text, style='italic',fontweight='bold',\n",
  137. " bbox={'facecolor':'white', 'alpha':0.8, 'pad':10})\n",
  138. " plt.imshow(np.transpose(npimg, (1, 2, 0)))\n",
  139. " plt.show() \n",
  140. "\n",
  141. "## Function to show Learning-Curve with matplotlib\n",
  142. "def show_plot(iteration,loss):\n",
  143. " plt.plot(iteration,loss)\n",
  144. " plt.show()"
  145. ]
  146. },
  147. {
  148. "cell_type": "markdown",
  149. "metadata": {},
  150. "source": [
  151. "## Datensätze festlegen, anzeigen und einlesen"
  152. ]
  153. },
  154. {
  155. "cell_type": "code",
  156. "execution_count": 108,
  157. "metadata": {},
  158. "outputs": [],
  159. "source": [
  160. "## Indirectr Path for Linux and Windows## \n",
  161. "def ChooseDir(datadir):\n",
  162. " if(datadir == 'CP'):\n",
  163. " return \"Datensaetze/CP\", \"CP\"\n",
  164. " if(datadir == 'AMI'):\n",
  165. " return \"Datensaetze/AMI\", \"AMI\"\n",
  166. " if(datadir == 'AWE'):\n",
  167. " return \"Datensaetze/AWE\", \"AWE\"\n",
  168. " if(datadir == 'EarVN'):\n",
  169. " return \"Datensaetze/EarVN_1_0\", \"EarVN_1_0\"\n",
  170. " if(datadir == 'UERC'):\n",
  171. " return \"Datensaetze/UERC\", \"UERC\""
  172. ]
  173. },
  174. {
  175. "cell_type": "markdown",
  176. "metadata": {},
  177. "source": [
  178. "## Voreinstellungen - Hilfsfunktionen\n",
  179. "\n",
  180. "* **secs_to_HMS():** Rechnet Sekunden in HH:MM:SS um (unter einer Stunde in MM:SS)\n",
  181. "* **WaitTime_Finished():** Spielt eine Audio-Datei ab, wenn Training und Testen beendet sind"
  182. ]
  183. },
  184. {
  185. "cell_type": "code",
  186. "execution_count": 109,
  187. "metadata": {},
  188. "outputs": [],
  189. "source": [
  190. "## Convert Seconds to Hours, Minutes and Seconds ##\n",
  191. "def secs_to_HMS(secs):\n",
  192. " if (secs < 3600):\n",
  193. " return datetime.datetime.fromtimestamp(secs).strftime('%M:%S'), \"[MM:SS]\"\n",
  194. " else:\n",
  195. " return datetime.datetime.fromtimestamp(secs).strftime('%H:%M:%S'), \"[HH:MM:SS]\"\n",
  196. "\n",
  197. "## Acustic Sound if Model-Trainings finished ##\n",
  198. "def WaitTime_Finished():\n",
  199. " for i in range(2):\n",
  200. " files=os.listdir(\"Sound/\")\n",
  201. " file=random.choice(files)\n",
  202. " winsound.PlaySound(\"Sound/\"+str(file), winsound.SND_FILENAME)\n",
  203. " time.sleep(2)"
  204. ]
  205. },
  206. {
  207. "cell_type": "code",
  208. "execution_count": 110,
  209. "metadata": {},
  210. "outputs": [
  211. {
  212. "name": "stdout",
  213. "output_type": "stream",
  214. "text": [
  215. "2021_02_27-21_19_36\n"
  216. ]
  217. }
  218. ],
  219. "source": [
  220. "## Use actual Time for individual Tag ##\n",
  221. "current_time = datetime.datetime.now().strftime(\"%Y_%m_%d-%H_%M_%S\")\n",
  222. "print(current_time)"
  223. ]
  224. },
  225. {
  226. "cell_type": "code",
  227. "execution_count": 111,
  228. "metadata": {},
  229. "outputs": [],
  230. "source": [
  231. "## All Networks from PyTorch wich can be pretrained ## \n",
  232. "class SiameseNetwork(nn.Module):\n",
  233. " def __init__(self):\n",
  234. " super(SiameseNetwork, self).__init__()\n",
  235. " self.cnn1 = nn.Sequential(\n",
  236. " nn.ReflectionPad2d(1),\n",
  237. " nn.Conv2d(1, 4, kernel_size=3),\n",
  238. " nn.ReLU(inplace=True),\n",
  239. " nn.BatchNorm2d(4),\n",
  240. " \n",
  241. " nn.ReflectionPad2d(1),\n",
  242. " nn.Conv2d(4, 8, kernel_size=3),\n",
  243. " nn.ReLU(inplace=True),\n",
  244. " nn.BatchNorm2d(8),\n",
  245. "\n",
  246. "\n",
  247. " nn.ReflectionPad2d(1),\n",
  248. " nn.Conv2d(8, 8, kernel_size=3),\n",
  249. " nn.ReLU(inplace=True),\n",
  250. " nn.BatchNorm2d(8),\n",
  251. " )\n",
  252. "\n",
  253. " self.fc1 = nn.Sequential(\n",
  254. " nn.Linear(8*100*100, 500),\n",
  255. " nn.ReLU(inplace=True),\n",
  256. "\n",
  257. " nn.Linear(500, 500),\n",
  258. " nn.ReLU(inplace=True),\n",
  259. "\n",
  260. " nn.Linear(500, 5))\n",
  261. " def forward_once(self, x):\n",
  262. " output = self.cnn1(x)\n",
  263. " output = output.view(output.size()[0], -1)\n",
  264. " output = self.fc1(output)\n",
  265. " return output\n",
  266. " def forward(self, input1, input2):\n",
  267. " output1 = self.forward_once(input1)\n",
  268. " output2 = self.forward_once(input2)\n",
  269. " return output1, output2\n",
  270. "\n",
  271. "\n",
  272. "def Network_Choice(netw): \n",
  273. " if(netw == 'vgg11'):\n",
  274. " return models.vgg11(pretrained=True).to(DEVICE), \"VGG11\"\n",
  275. " if(netw == 'vgg11bn'):\n",
  276. " return models.vgg11_bn(pretrained=True).to(DEVICE), \"VGG11bn\"\n",
  277. " if(netw == 'resnet18'):\n",
  278. " return models.resnet18(pretrained=True).to(DEVICE), \"ResNet18\"\n",
  279. " if(netw == 'resnet34'):\n",
  280. " return models.resnet34(pretrained=True).to(DEVICE), \"ResNet34\"\n",
  281. " if(netw == 'alexnet'):\n",
  282. " return models.alexnet(pretrained=True).to(DEVICE), \"AlexNet\"\n",
  283. " if(netw == 'squeezenet1_0'):\n",
  284. " return models.squeezenet1_0(pretrained=True).to(DEVICE), \"SqueezeNet-1-0\" \n",
  285. " if(netw == 'GoogLeNet'):\n",
  286. " return models.googlenet(pretrained=True).to(DEVICE), \"GoogLeNet\" \n",
  287. " if(netw == 'shufflenet_v2_x0_5'):\n",
  288. " return models.shufflenet_v2_x0_5(pretrained=True).to(DEVICE), \"Shufflenet-v2-x0_5\" \n",
  289. " if(netw == 'resnext101_32x8d'):\n",
  290. " return models.resnext101_32x8d(pretrained=True).to(DEVICE), \"Resnext101-32x8d\"\n",
  291. " if(netw == 'siamese'):\n",
  292. " return SiameseNetwork().to(DEVICE), 'Siamese_Network'"
  293. ]
  294. },
  295. {
  296. "cell_type": "markdown",
  297. "metadata": {},
  298. "source": [
  299. "## Contrastive Loss\n",
  300. "\n",
  301. "Contrastive-Loss-Berechnung für ein Siamese-Network"
  302. ]
  303. },
  304. {
  305. "cell_type": "code",
  306. "execution_count": 112,
  307. "metadata": {},
  308. "outputs": [],
  309. "source": [
  310. "class ContrastiveLoss(torch.nn.Module):\n",
  311. " def __init__(self, margin=2.0):\n",
  312. " super(ContrastiveLoss, self).__init__()\n",
  313. " self.margin = margin\n",
  314. "\n",
  315. " def forward(self, output1, output2, label):\n",
  316. " euclidean_distance = F.pairwise_distance(output1, output2, keepdim = True)\n",
  317. " loss_contrastive = torch.mean((1-label) * torch.pow(euclidean_distance, 2) +\n",
  318. " (label) * torch.pow(torch.clamp(self.margin - euclidean_distance, min=0.0), 2))\n",
  319. " return loss_contrastive"
  320. ]
  321. },
  322. {
  323. "cell_type": "markdown",
  324. "metadata": {},
  325. "source": [
  326. "## Konfiguration\n",
  327. "\n",
  328. "Auswahl von: \n",
  329. "* Datensatz\n",
  330. "* Netzwerk\n",
  331. "* Trainshare\n",
  332. "* Batch Size Train\n",
  333. "* Batch Size Test\n",
  334. "* Learning Rate\n",
  335. "* Momentum\n",
  336. "\n",
  337. "Hinweis: Konfiguration ausführen, Auswahl festlegen und weiter zum nächsten Abschnitt"
  338. ]
  339. },
  340. {
  341. "cell_type": "code",
  342. "execution_count": 113,
  343. "metadata": {},
  344. "outputs": [
  345. {
  346. "data": {
  347. "application/vnd.jupyter.widget-view+json": {
  348. "model_id": "97cf4f37c7ed4036a2879dbb139818c5",
  349. "version_major": 2,
  350. "version_minor": 0
  351. },
  352. "text/plain": [
  353. "Dropdown(description='Dataset:', index=1, options=('CP', 'AMI', 'AWE', 'EarVN', 'UERC'), value='AMI')"
  354. ]
  355. },
  356. "metadata": {},
  357. "output_type": "display_data"
  358. },
  359. {
  360. "data": {
  361. "application/vnd.jupyter.widget-view+json": {
  362. "model_id": "b5002f1761e64a5bbbaec4ecd034c112",
  363. "version_major": 2,
  364. "version_minor": 0
  365. },
  366. "text/plain": [
  367. "Dropdown(description='Network:', options=('vgg11', 'vgg11bn', 'resnet18', 'resnet34', 'alexnet', 'squeezenet1_…"
  368. ]
  369. },
  370. "metadata": {},
  371. "output_type": "display_data"
  372. },
  373. {
  374. "data": {
  375. "application/vnd.jupyter.widget-view+json": {
  376. "model_id": "0eb51fd4b5f744e98edad81c104e7f43",
  377. "version_major": 2,
  378. "version_minor": 0
  379. },
  380. "text/plain": [
  381. "BoundedFloatText(value=0.8, description='Trainshare:', max=1.0, step=0.1)"
  382. ]
  383. },
  384. "metadata": {},
  385. "output_type": "display_data"
  386. },
  387. {
  388. "data": {
  389. "application/vnd.jupyter.widget-view+json": {
  390. "model_id": "48e119312d5e429ea553518c6d3df1aa",
  391. "version_major": 2,
  392. "version_minor": 0
  393. },
  394. "text/plain": [
  395. "BoundedIntText(value=4, description='Batch_Train:', min=1)"
  396. ]
  397. },
  398. "metadata": {},
  399. "output_type": "display_data"
  400. },
  401. {
  402. "data": {
  403. "application/vnd.jupyter.widget-view+json": {
  404. "model_id": "21770422b0304a63907e69a750f6bea2",
  405. "version_major": 2,
  406. "version_minor": 0
  407. },
  408. "text/plain": [
  409. "BoundedIntText(value=4, description='Batch_Test:', min=1)"
  410. ]
  411. },
  412. "metadata": {},
  413. "output_type": "display_data"
  414. },
  415. {
  416. "data": {
  417. "application/vnd.jupyter.widget-view+json": {
  418. "model_id": "14af8e2f1cc0414d8ce34c0bd93ddd7c",
  419. "version_major": 2,
  420. "version_minor": 0
  421. },
  422. "text/plain": [
  423. "BoundedFloatText(value=0.001, description='Learning_Rate', max=1.0, step=0.0001)"
  424. ]
  425. },
  426. "metadata": {},
  427. "output_type": "display_data"
  428. },
  429. {
  430. "data": {
  431. "application/vnd.jupyter.widget-view+json": {
  432. "model_id": "d8b9466f0d994762bd04d858d213ce0a",
  433. "version_major": 2,
  434. "version_minor": 0
  435. },
  436. "text/plain": [
  437. "BoundedFloatText(value=0.8, description='Momentum: ', max=1.0, step=0.1)"
  438. ]
  439. },
  440. "metadata": {},
  441. "output_type": "display_data"
  442. }
  443. ],
  444. "source": [
  445. "datadir_choose = wg.Dropdown(\n",
  446. " options=['CP', 'AMI', 'AWE', 'EarVN', 'UERC'],\n",
  447. " value='AMI',\n",
  448. " description='Dataset:',\n",
  449. " disabled=False,\n",
  450. " button_style=''\n",
  451. ")\n",
  452. "\n",
  453. "network_string_choose = wg.Dropdown(\n",
  454. " options=['vgg11','vgg11bn','resnet18','resnet34','alexnet','squeezenet1_0','GoogLeNet','shufflenet_v2_x0_5','resnext101_32x8d', 'siamese'],\n",
  455. " value='vgg11',\n",
  456. " description='Network:',\n",
  457. " disabled=False,\n",
  458. " button_style=''\n",
  459. ")\n",
  460. "\n",
  461. "trainshare_choose = wg.BoundedFloatText(\n",
  462. " value=0.8,\n",
  463. " min=0,\n",
  464. " max=1,\n",
  465. " step=0.1,\n",
  466. " description='Trainshare:',\n",
  467. " disabled=False\n",
  468. ")\n",
  469. "\n",
  470. "batch_train = wg.BoundedIntText(\n",
  471. " value=4,\n",
  472. " min=1,\n",
  473. " max=100,\n",
  474. " step=1,\n",
  475. " description='Batch_Train:',\n",
  476. " disabled=False\n",
  477. ")\n",
  478. "\n",
  479. "batch_test = wg.BoundedIntText(\n",
  480. " value=4,\n",
  481. " min=1,\n",
  482. " max=100,\n",
  483. " step=1,\n",
  484. " description='Batch_Test:',\n",
  485. " disabled=False\n",
  486. ")\n",
  487. "\n",
  488. "learning_rate_choose = wg.BoundedFloatText(\n",
  489. " value=0.001,\n",
  490. " min=0,\n",
  491. " max=1,\n",
  492. " step=0.0001,\n",
  493. " description='Learning_Rate',\n",
  494. " disabled=False\n",
  495. ")\n",
  496. "\n",
  497. "momentum_choose = wg.BoundedFloatText(\n",
  498. " value=0.8,\n",
  499. " min=0,\n",
  500. " max=1,\n",
  501. " step=0.1,\n",
  502. " description='Momentum: ',\n",
  503. " disabled=False\n",
  504. ")\n",
  505. "\n",
  506. "\n",
  507. "display(datadir_choose, network_string_choose, trainshare_choose, batch_train, batch_test, learning_rate_choose, momentum_choose)"
  508. ]
  509. },
  510. {
  511. "cell_type": "code",
  512. "execution_count": 114,
  513. "metadata": {},
  514. "outputs": [],
  515. "source": [
  516. "## Define Variables, Arrays for Dataset and Categories ##\n",
  517. "dataset_train = []\n",
  518. "dataset_test = []\n",
  519. "CATEGORIES = []\n",
  520. "#img_array = []\n",
  521. "network_name = 'nix'\n",
  522. "\n",
  523. "datadir = datadir_choose.value\n",
  524. "network_string = network_string_choose.value\n",
  525. "train_share = trainshare_choose.value\n",
  526. "batch_size_train = batch_train.value\n",
  527. "batch_size_test = batch_test.value\n",
  528. "learning_rate = learning_rate_choose.value\n",
  529. "momentum = momentum_choose.value"
  530. ]
  531. },
  532. {
  533. "cell_type": "markdown",
  534. "metadata": {},
  535. "source": [
  536. "## Pfade zum OneShot-Learning\n",
  537. "\n",
  538. "Pfade zu den AMI-Bildern für das Trainieren und Testen der Siamese/PyTorch-Netzwerke"
  539. ]
  540. },
  541. {
  542. "cell_type": "code",
  543. "execution_count": 115,
  544. "metadata": {},
  545. "outputs": [],
  546. "source": [
  547. "class Config():\n",
  548. " ## Define Pathes for Train- and Test-Images\n",
  549. " path_OneShot_Train = \"./Datensaetze/AMI_OneShot/OneShot_Train/\"\n",
  550. " path_OneShot_Test = \"./Datensaetze/AMI_OneShot/OneShot_Testing/\" "
  551. ]
  552. },
  553. {
  554. "cell_type": "markdown",
  555. "metadata": {},
  556. "source": [
  557. "## Netzwerke und Datensätze einlesen"
  558. ]
  559. },
  560. {
  561. "cell_type": "code",
  562. "execution_count": 116,
  563. "metadata": {},
  564. "outputs": [],
  565. "source": [
  566. "## Get Network and Network-Name\n",
  567. "network, network_name = Network_Choice(network_string)\n",
  568. "## Define Constrative-Loss as Loss-Function\n",
  569. "criterion = ContrastiveLoss()\n",
  570. "## Define Adam as optimizer\n",
  571. "optimizer = optim.Adam(network.parameters(),lr = 0.0005 )\n",
  572. "\n",
  573. "## Choose Datadirectory to the Dataset ##\n",
  574. "DATADIR, Database = ChooseDir(datadir)\n",
  575. "## Read and list all Categories of a Dataset \n",
  576. "CATEGORIES = os.listdir(DATADIR)"
  577. ]
  578. },
  579. {
  580. "cell_type": "markdown",
  581. "metadata": {},
  582. "source": [
  583. "## Erstellen oder Löschen der Dataset-Ordner für SiameseDataset_Ears\n",
  584. "\n",
  585. "* **Create_OneShot_File_Images():**\n",
  586. "Falls in *./Datensaetze/AMI_OneShot/OneShot_Train/* oder *./Datensaetze/AMI_OneShot/OneShot_Testing/* keine Bilder vorhanden sind, können diese mit *Create_OneShot_File_Images()* aus dem gewählten Datensatz erstellt werden. \n",
  587. "\n",
  588. "* **Delete_OneShot_File_Images():**\n",
  589. "Falls andere Datensätze als der AMI-Datensatz gewünscht sind, können die erzeugten AMI-Ordner mit *Delete_OneShot_File_Images()* gelöscht werden."
  590. ]
  591. },
  592. {
  593. "cell_type": "code",
  594. "execution_count": 117,
  595. "metadata": {},
  596. "outputs": [],
  597. "source": [
  598. "## Create Dataset-Images for OneShot Learning\n",
  599. "def Create_OneShot_File_Images():\n",
  600. " for category in CATEGORIES:\n",
  601. " path = os.path.join(DATADIR, category)\n",
  602. " class_num = CATEGORIES.index(category)+1\n",
  603. " count_train_share = (len(os.listdir(path)))*train_share\n",
  604. " counter = 1\n",
  605. " os_listdir = os.listdir(path)\n",
  606. " try:\n",
  607. " ## Create Folders in Train- and Test-Directories\n",
  608. " os.mkdir(Config.path_OneShot_Train+str(class_num))\n",
  609. " os.mkdir(Config.path_OneShot_Test+str(class_num))\n",
  610. " except OSError:\n",
  611. " print (\"Creation of the directory %s failed\" % Config.path_OneShot_Train)\n",
  612. " else:\n",
  613. " print (\"Successfully created the directory %s \" % Config.path_OneShot_Test)\n",
  614. "\n",
  615. " for img in os_listdir:\n",
  616. " try:\n",
  617. " ## Copy all Images in Train- and Test-Directories\n",
  618. " if(counter <= count_train_share):\n",
  619. " shutil.copy(path+'/'+img, Config.path_OneShot_Train+str(class_num))\n",
  620. " counter += 1\n",
  621. " else:\n",
  622. " shutil.copy(path+'/'+img, Config.path_OneShot_Test+str(class_num))\n",
  623. " except Exception as e:\n",
  624. " pass\n",
  625. " \n",
  626. "## Delete Dataset-Images for OneShot Learning \n",
  627. "def Delete_OneShot_File_Images():\n",
  628. " for category in CATEGORIES:\n",
  629. " class_num = CATEGORIES.index(category)#1\n",
  630. " count_train_share = (len(os.listdir(path)))*train_share\n",
  631. " counter = 1\n",
  632. " os_listdir = os.listdir(path)\n",
  633. " ## Delete all Directories and Images of One Shot Learning\n",
  634. " try:\n",
  635. " shutil.rmtree(Config.path_OneShot_Train+str(class_num))\n",
  636. " shutil.rmtree(Config.path_OneShot_Test+str(class_num))\n",
  637. " except OSError:\n",
  638. " print (\"Deletion of the directory %s failed\" % Config.path_OneShot_Train)\n",
  639. " else:\n",
  640. " print (\"Successfully deleted the directory %s\" % Config.path_OneShot_Train)"
  641. ]
  642. },
  643. {
  644. "cell_type": "code",
  645. "execution_count": 118,
  646. "metadata": {},
  647. "outputs": [
  648. {
  649. "name": "stdout",
  650. "output_type": "stream",
  651. "text": [
  652. "Successfully created the directory ./Datensaetze/AMI_OneShot/OneShot_Testing/ \n",
  653. "Successfully created the directory ./Datensaetze/AMI_OneShot/OneShot_Testing/ \n",
  654. "Successfully created the directory ./Datensaetze/AMI_OneShot/OneShot_Testing/ \n",
  655. "Successfully created the directory ./Datensaetze/AMI_OneShot/OneShot_Testing/ \n",
  656. "Successfully created the directory ./Datensaetze/AMI_OneShot/OneShot_Testing/ \n",
  657. "Successfully created the directory ./Datensaetze/AMI_OneShot/OneShot_Testing/ \n",
  658. "Successfully created the directory ./Datensaetze/AMI_OneShot/OneShot_Testing/ \n",
  659. "Successfully created the directory ./Datensaetze/AMI_OneShot/OneShot_Testing/ \n",
  660. "Successfully created the directory ./Datensaetze/AMI_OneShot/OneShot_Testing/ \n",
  661. "Successfully created the directory ./Datensaetze/AMI_OneShot/OneShot_Testing/ \n",
  662. "Successfully created the directory ./Datensaetze/AMI_OneShot/OneShot_Testing/ \n",
  663. "Successfully created the directory ./Datensaetze/AMI_OneShot/OneShot_Testing/ \n",
  664. "Successfully created the directory ./Datensaetze/AMI_OneShot/OneShot_Testing/ \n",
  665. "Successfully created the directory ./Datensaetze/AMI_OneShot/OneShot_Testing/ \n",
  666. "Successfully created the directory ./Datensaetze/AMI_OneShot/OneShot_Testing/ \n",
  667. "Successfully created the directory ./Datensaetze/AMI_OneShot/OneShot_Testing/ \n",
  668. "Successfully created the directory ./Datensaetze/AMI_OneShot/OneShot_Testing/ \n",
  669. "Successfully created the directory ./Datensaetze/AMI_OneShot/OneShot_Testing/ \n",
  670. "Successfully created the directory ./Datensaetze/AMI_OneShot/OneShot_Testing/ \n",
  671. "Successfully created the directory ./Datensaetze/AMI_OneShot/OneShot_Testing/ \n"
  672. ]
  673. }
  674. ],
  675. "source": [
  676. "#Create_OneShot_File_Images()\n",
  677. "#Delete_OneShot_File_Images()"
  678. ]
  679. },
  680. {
  681. "cell_type": "markdown",
  682. "metadata": {},
  683. "source": [
  684. "## Klasse zur Erstellung der Bildpaare\n",
  685. "\n",
  686. "* **SiameseDataset_Ears()**: Erstellt zufällig gleiche und ungleiche Bildpaare (jeweils mit ca. 50 % Wahrscheinlichkeit) und transformiert die Bilder"
  687. ]
  688. },
  689. {
  690. "cell_type": "code",
  691. "execution_count": 119,
  692. "metadata": {},
  693. "outputs": [],
  694. "source": [
  695. "## Create Dataset of Ear-Images for OneShot-Learning\n",
  696. "class SiameseDataset_Ears(Dataset):\n",
  697. " \n",
  698. " def __init__(self,imageFolderDataset,transform=None,should_invert=True):\n",
  699. " self.imageFolderDataset = imageFolderDataset \n",
  700. " self.transform = transform\n",
  701. " self.should_invert = should_invert\n",
  702. " \n",
  703. " def __getitem__(self,index):\n",
  704. " ## Get random Picture from Test-File\n",
  705. " img0_tuple = random.choice(self.imageFolderDataset.imgs)\n",
  706. " \n",
  707. " #we need to make sure approx 50% of images are in the same class\n",
  708. " should_get_same_class = random.randint(0,1) \n",
  709. " if should_get_same_class:\n",
  710. " while True:\n",
  711. " #keep looping till the same class image is found\n",
  712. " img1_tuple = random.choice(self.imageFolderDataset.imgs) \n",
  713. " if img0_tuple[1]==img1_tuple[1]:\n",
  714. " break\n",
  715. " else:\n",
  716. " while True:\n",
  717. " #keep looping till a different class image is found \n",
  718. " img1_tuple = random.choice(self.imageFolderDataset.imgs) \n",
  719. " if img0_tuple[1] !=img1_tuple[1]:\n",
  720. " break\n",
  721. "\n",
  722. " ## Load images and convert to RGB if PyTorch-Network is Choosen\n",
  723. " ## Load images and convert to Gray if Siamese-Network is Choosen\n",
  724. " if(network_name == ('VGG11' or 'VGG11bn' or 'ResNet18' or 'ResNet34' or 'alexnet' or 'Squeezenet1_0' or 'GoogLeNet' or 'Shufflenet_v2_x0_5' or 'Resnext101_32x8d')):\n",
  725. " img0 = Image.open(img0_tuple[0]).convert('RGB')\n",
  726. " img1 = Image.open(img1_tuple[0]).convert('RGB') \n",
  727. " elif(network_name == ('Siamese_Network')):\n",
  728. " img0 = Image.open(img0_tuple[0]).convert(\"L\")\n",
  729. " img1 = Image.open(img1_tuple[0]).convert(\"L\")\n",
  730. " #img0 = img0.convert(\"L\")\n",
  731. " #img1 = img1.convert(\"L\")\n",
  732. " else:\n",
  733. " print('False Network choosen') \n",
  734. " \n",
  735. " ## Invert loaded PIL images\n",
  736. " if self.should_invert:\n",
  737. " img0 = PIL.ImageOps.invert(img0)\n",
  738. " img1 = PIL.ImageOps.invert(img1)\n",
  739. " \n",
  740. " ## Transform \n",
  741. " if self.transform is not None:\n",
  742. " img0 = self.transform(img0)\n",
  743. " img1 = self.transform(img1)\n",
  744. " \n",
  745. " return img0, img1 , torch.from_numpy(np.array([int(img1_tuple[1]!=img0_tuple[1])],dtype=np.float32))\n",
  746. " \n",
  747. " def __len__(self):\n",
  748. " return len(self.imageFolderDataset.imgs)"
  749. ]
  750. },
  751. {
  752. "cell_type": "markdown",
  753. "metadata": {},
  754. "source": [
  755. "## Definition des Trainings- und Test-Folders"
  756. ]
  757. },
  758. {
  759. "cell_type": "code",
  760. "execution_count": 120,
  761. "metadata": {},
  762. "outputs": [],
  763. "source": [
  764. "## Folder for Train-Images\n",
  765. "folder_dataset_train = dset.ImageFolder(root=Config.path_OneShot_Train)\n",
  766. "\n",
  767. "## Folder for Test-Images\n",
  768. "folder_dataset_test = dset.ImageFolder(root=Config.path_OneShot_Test)"
  769. ]
  770. },
  771. {
  772. "cell_type": "markdown",
  773. "metadata": {},
  774. "source": [
  775. "## Definition des Transformers"
  776. ]
  777. },
  778. {
  779. "cell_type": "code",
  780. "execution_count": 121,
  781. "metadata": {},
  782. "outputs": [],
  783. "source": [
  784. "transformer = transforms.ToTensor()\n",
  785. "\n",
  786. "## Define Transformer for PyTorch-Networks or Siamese Network\n",
  787. "if(network_name == ('VGG11' or 'VGG11bn' or 'ResNet18' or 'ResNet34' or 'alexnet' or 'Squeezenet1_0' or 'GoogLeNet' or 'Shufflenet_v2_x0_5' or 'Resnext101_32x8d')):\n",
  788. " ## Transformer for PyTorch-Network\n",
  789. " transformer = transforms.Compose([\n",
  790. " transforms.Resize(256),\n",
  791. " transforms.CenterCrop(224),\n",
  792. " transforms.ToTensor(),\n",
  793. " transforms.Normalize((0.4318, 0.4660, 0.5889), (0.1752, 0.1893, 0.2096)),\n",
  794. "])\n",
  795. "\n",
  796. "elif(network_name == ('Siamese_Network')):\n",
  797. " ## Transformer for Siamese-Network\n",
  798. " transformer=transforms.Compose([transforms.Resize((100,100)),\n",
  799. " transforms.ToTensor()\n",
  800. " ])\n",
  801. "else:\n",
  802. " print('False Network choosen')"
  803. ]
  804. },
  805. {
  806. "cell_type": "markdown",
  807. "metadata": {},
  808. "source": [
  809. "## Laden von Train- und Testdatensätzen mit Transformation"
  810. ]
  811. },
  812. {
  813. "cell_type": "code",
  814. "execution_count": 122,
  815. "metadata": {},
  816. "outputs": [],
  817. "source": [
  818. "## For PyTorch-Networks\n",
  819. "siamese_dataset_train = SiameseDataset_Ears(imageFolderDataset=folder_dataset_train,\n",
  820. " transform=transformer,\n",
  821. " should_invert=False)\n",
  822. "\n",
  823. "## Create Test-Dataset\n",
  824. "siamese_dataset_test = SiameseDataset_Ears(imageFolderDataset=folder_dataset_test,\n",
  825. " transform=transformer,\n",
  826. " should_invert=True)\n",
  827. "\n",
  828. "\n",
  829. "\n",
  830. "## Create Training-DataLoader\n",
  831. "train_dataloader = DataLoader(siamese_dataset_train, batch_size=batch_size_train, shuffle=True,)\n",
  832. "\n",
  833. "## Create Test-DataLoader\n",
  834. "test_dataloader = DataLoader(siamese_dataset_test, batch_size=batch_size_test, shuffle=True)"
  835. ]
  836. },
  837. {
  838. "cell_type": "markdown",
  839. "metadata": {},
  840. "source": [
  841. "## Visualisierung der Bildpaare des TrainLoaders\n",
  842. "\n",
  843. "Bildpaare: *Proband 1 oben, Proband 2 unten*\n",
  844. "- 0: Gleiches Paar\n",
  845. "- 1: Ungleiches Paar"
  846. ]
  847. },
  848. {
  849. "cell_type": "code",
  850. "execution_count": 123,
  851. "metadata": {},
  852. "outputs": [
  853. {
  854. "name": "stderr",
  855. "output_type": "stream",
  856. "text": [
  857. "Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).\n"
  858. ]
  859. },
  860. {
  861. "data": {
  862. "image/png": "\n",
  863. "text/plain": [
  864. "<Figure size 432x288 with 1 Axes>"
  865. ]
  866. },
  867. "metadata": {
  868. "needs_background": "light"
  869. },
  870. "output_type": "display_data"
  871. },
  872. {
  873. "name": "stdout",
  874. "output_type": "stream",
  875. "text": [
  876. "[[0.]\n",
  877. " [1.]\n",
  878. " [0.]\n",
  879. " [1.]\n",
  880. " [1.]\n",
  881. " [0.]\n",
  882. " [1.]\n",
  883. " [1.]]\n"
  884. ]
  885. }
  886. ],
  887. "source": [
  888. "vis_dataloader = DataLoader(siamese_dataset_train, shuffle=True, batch_size=8)\n",
  889. "\n",
  890. "dataiter = iter(vis_dataloader)\n",
  891. "\n",
  892. "#example_batch = next(dataiter)\n",
  893. "example_batch = dataiter.next()\n",
  894. "concatenated = torch.cat((example_batch[0],example_batch[1]),0)\n",
  895. "imshow(torchvision.utils.make_grid(concatenated))\n",
  896. "print(example_batch[2].numpy())"
  897. ]
  898. },
  899. {
  900. "cell_type": "markdown",
  901. "metadata": {},
  902. "source": [
  903. "## Training:"
  904. ]
  905. },
  906. {
  907. "cell_type": "code",
  908. "execution_count": 124,
  909. "metadata": {
  910. "scrolled": true
  911. },
  912. "outputs": [
  913. {
  914. "name": "stdout",
  915. "output_type": "stream",
  916. "text": [
  917. "Start Training um: 21:20:32\n",
  918. "Epoch number 0\n",
  919. " Current loss 308.1065979003906\n",
  920. "\n",
  921. "Epoch number 0\n",
  922. " Current loss 1.138925552368164\n",
  923. "\n",
  924. "Epoch number 0\n",
  925. " Current loss 0.964200496673584\n",
  926. "\n",
  927. "Epoch number 1\n",
  928. " Current loss 1.052968978881836\n",
  929. "\n",
  930. "Epoch number 1\n",
  931. " Current loss 0.9594541788101196\n",
  932. "\n",
  933. "Epoch number 1\n",
  934. " Current loss 0.4107712507247925\n",
  935. "\n",
  936. "Epoch number 2\n",
  937. " Current loss 1.5126802921295166\n",
  938. "\n",
  939. "Epoch number 2\n",
  940. " Current loss 1.127558946609497\n",
  941. "\n",
  942. "Epoch number 2\n",
  943. " Current loss 3.7104902267456055\n",
  944. "\n",
  945. "Epoch number 3\n",
  946. " Current loss 0.7310527563095093\n",
  947. "\n",
  948. "Epoch number 3\n",
  949. " Current loss 0.7123371362686157\n",
  950. "\n",
  951. "Epoch number 3\n",
  952. " Current loss 1.3404734134674072\n",
  953. "\n",
  954. "Epoch number 4\n",
  955. " Current loss 1.0322986841201782\n",
  956. "\n",
  957. "Epoch number 4\n",
  958. " Current loss 1.4310665130615234\n",
  959. "\n",
  960. "Epoch number 4\n",
  961. " Current loss 0.8324761390686035\n",
  962. "\n",
  963. "Epoch number 5\n",
  964. " Current loss 0.8502787351608276\n",
  965. "\n",
  966. "Epoch number 5\n",
  967. " Current loss 1.3122851848602295\n",
  968. "\n",
  969. "Epoch number 5\n",
  970. " Current loss 0.921744167804718\n",
  971. "\n",
  972. "Epoch number 6\n",
  973. " Current loss 1.1442584991455078\n",
  974. "\n",
  975. "Epoch number 6\n",
  976. " Current loss 1.011681318283081\n",
  977. "\n",
  978. "Epoch number 6\n",
  979. " Current loss 0.9135341048240662\n",
  980. "\n",
  981. "Epoch number 7\n",
  982. " Current loss 1.2509589195251465\n",
  983. "\n",
  984. "Epoch number 7\n",
  985. " Current loss 0.7557617425918579\n",
  986. "\n",
  987. "Epoch number 7\n",
  988. " Current loss 0.8743131160736084\n",
  989. "\n",
  990. "Epoch number 8\n",
  991. " Current loss 0.7078773975372314\n",
  992. "\n",
  993. "Epoch number 8\n",
  994. " Current loss 2.043950080871582\n",
  995. "\n",
  996. "Epoch number 8\n",
  997. " Current loss 1.0420727729797363\n",
  998. "\n",
  999. "Epoch number 9\n",
  1000. " Current loss 0.9620883464813232\n",
  1001. "\n",
  1002. "Epoch number 9\n",
  1003. " Current loss 0.9625381827354431\n",
  1004. "\n",
  1005. "Epoch number 9\n",
  1006. " Current loss 1.1636301279067993\n",
  1007. "\n",
  1008. "Epoch number 10\n",
  1009. " Current loss 1.5617934465408325\n",
  1010. "\n",
  1011. "Epoch number 10\n",
  1012. " Current loss 1.029129147529602\n",
  1013. "\n",
  1014. "Epoch number 10\n",
  1015. " Current loss 1.2446420192718506\n",
  1016. "\n",
  1017. "Epoch number 11\n",
  1018. " Current loss 1.0907390117645264\n",
  1019. "\n",
  1020. "Epoch number 11\n",
  1021. " Current loss 0.9430091977119446\n",
  1022. "\n",
  1023. "Epoch number 11\n",
  1024. " Current loss 1.0374553203582764\n",
  1025. "\n",
  1026. "Epoch number 12\n",
  1027. " Current loss 1.2946057319641113\n",
  1028. "\n",
  1029. "Epoch number 12\n",
  1030. " Current loss 0.8857178092002869\n",
  1031. "\n",
  1032. "Epoch number 12\n",
  1033. " Current loss 0.9587574005126953\n",
  1034. "\n",
  1035. "Epoch number 13\n",
  1036. " Current loss 0.446734756231308\n",
  1037. "\n",
  1038. "Epoch number 13\n",
  1039. " Current loss 0.9940989017486572\n",
  1040. "\n",
  1041. "Epoch number 13\n",
  1042. " Current loss 0.9222986698150635\n",
  1043. "\n",
  1044. "Epoch number 14\n",
  1045. " Current loss 0.8071630597114563\n",
  1046. "\n",
  1047. "Epoch number 14\n",
  1048. " Current loss 0.9962530136108398\n",
  1049. "\n",
  1050. "Epoch number 14\n",
  1051. " Current loss 0.9935855269432068\n",
  1052. "\n",
  1053. "Epoch number 15\n",
  1054. " Current loss 1.2180464267730713\n",
  1055. "\n",
  1056. "Epoch number 15\n",
  1057. " Current loss 1.0259199142456055\n",
  1058. "\n",
  1059. "Epoch number 15\n",
  1060. " Current loss 0.8863251209259033\n",
  1061. "\n",
  1062. "Epoch number 16\n",
  1063. " Current loss 1.0609502792358398\n",
  1064. "\n",
  1065. "Epoch number 16\n",
  1066. " Current loss 0.9913638234138489\n",
  1067. "\n",
  1068. "Epoch number 16\n",
  1069. " Current loss 1.1235753297805786\n",
  1070. "\n",
  1071. "Epoch number 17\n",
  1072. " Current loss 1.1986947059631348\n",
  1073. "\n",
  1074. "Epoch number 17\n",
  1075. " Current loss 0.7384413480758667\n",
  1076. "\n",
  1077. "Epoch number 17\n",
  1078. " Current loss 0.8233901262283325\n",
  1079. "\n",
  1080. "Epoch number 18\n",
  1081. " Current loss 1.227494716644287\n",
  1082. "\n",
  1083. "Epoch number 18\n",
  1084. " Current loss 0.9361811876296997\n",
  1085. "\n",
  1086. "Epoch number 18\n",
  1087. " Current loss 0.19040855765342712\n",
  1088. "\n",
  1089. "Epoch number 19\n",
  1090. " Current loss 0.5063803195953369\n",
  1091. "\n",
  1092. "Epoch number 19\n",
  1093. " Current loss 1.4189283847808838\n",
  1094. "\n",
  1095. "Epoch number 19\n",
  1096. " Current loss 0.782526969909668\n",
  1097. "\n",
  1098. "Ende Training um: 21:21:26\n",
  1099. "Dauer Training: 00:53 [MM:SS] \n",
  1100. "\n"
  1101. ]
  1102. },
  1103. {
  1104. "data": {
  1105. "image/png": "\n",
  1106. "text/plain": [
  1107. "<Figure size 432x288 with 1 Axes>"
  1108. ]
  1109. },
  1110. "metadata": {
  1111. "needs_background": "light"
  1112. },
  1113. "output_type": "display_data"
  1114. }
  1115. ],
  1116. "source": [
  1117. "epoch_range = 20\n",
  1118. "counter = []\n",
  1119. "loss_history = [] \n",
  1120. "iteration_number= 0\n",
  1121. "\n",
  1122. "\n",
  1123. "## Start Training\n",
  1124. "print(\"Start Training um: \", time.strftime(\"%H:%M:%S\"))\n",
  1125. "start_time = time.time()\n",
  1126. "\n",
  1127. "for epoch in range(epoch_range):\n",
  1128. " for i, data in enumerate(train_dataloader, 0):\n",
  1129. " img0, img1 , label = data\n",
  1130. " img0, img1 , label = img0.to(DEVICE), img1.to(DEVICE), label.to(DEVICE)\n",
  1131. "\n",
  1132. " optimizer.zero_grad()\n",
  1133. " \n",
  1134. " ## Calculate Ouputs and LossConstrative\n",
  1135. " if(network_name == ('VGG11' or 'VGG11bn' or 'ResNet18' or 'ResNet34' or 'alexnet' or 'Squeezenet1_0' or 'GoogLeNet' or 'Shufflenet_v2_x0_5' or 'Resnext101_32x8d')):\n",
  1136. " output1 = network(img0)\n",
  1137. " output2 = network(img1)\n",
  1138. " loss_contrastive = criterion(output1,output2,label)\n",
  1139. " \n",
  1140. " elif(network_name == ('Siamese_Network')):\n",
  1141. " output1, output2 = network(img0, img1)\n",
  1142. " #outputs = network(img0, img1)\n",
  1143. " loss_contrastive = criterion(output1,output2,label)\n",
  1144. " else:\n",
  1145. " print('False Network choosen')\n",
  1146. " \n",
  1147. " \n",
  1148. " ## Backpropagation and Optmizer\n",
  1149. " loss_contrastive.backward()\n",
  1150. " optimizer.step()\n",
  1151. " \n",
  1152. " ## Get loss\n",
  1153. " if i %10 == 0 :\n",
  1154. " print(\"Epoch number {}\\n Current loss {}\\n\".format(epoch,loss_contrastive.item()))\n",
  1155. " iteration_number += 10\n",
  1156. " counter.append(iteration_number)\n",
  1157. " loss_history.append(loss_contrastive.item())\n",
  1158. "\n",
  1159. " \n",
  1160. "## Finished Training\n",
  1161. "print(\"Ende Training um: \",time.strftime(\"%H:%M:%S\"))\n",
  1162. "stop_time = time.time()\n",
  1163. "time_dif, time_format = secs_to_HMS(stop_time-start_time)\n",
  1164. "print(\"Dauer Training: \", time_dif, \" \", time_format, \" \\n\") \n",
  1165. "\n",
  1166. "show_plot(counter,loss_history)\n",
  1167. "WaitTime_Finished()"
  1168. ]
  1169. },
  1170. {
  1171. "cell_type": "code",
  1172. "execution_count": 63,
  1173. "metadata": {},
  1174. "outputs": [
  1175. {
  1176. "data": {
  1177. "image/png": "\n",
  1178. "text/plain": [
  1179. "<Figure size 432x288 with 1 Axes>"
  1180. ]
  1181. },
  1182. "metadata": {
  1183. "needs_background": "light"
  1184. },
  1185. "output_type": "display_data"
  1186. }
  1187. ],
  1188. "source": [
  1189. "show_plot(counter,loss_history)"
  1190. ]
  1191. },
  1192. {
  1193. "cell_type": "markdown",
  1194. "metadata": {},
  1195. "source": [
  1196. "## Speichern der trainierten Netzwerke"
  1197. ]
  1198. },
  1199. {
  1200. "cell_type": "code",
  1201. "execution_count": 125,
  1202. "metadata": {},
  1203. "outputs": [],
  1204. "source": [
  1205. "## Path to save Network \n",
  1206. "PATH = './Netzwerke_OneShot/' + current_time + '_' + Database + '_' + network_name + '_Train' + '.pth'\n",
  1207. "\n",
  1208. "## Save Network\n",
  1209. "torch.save({\n",
  1210. " 'epoch_range': epoch_range,\n",
  1211. " 'model_state_dict': network.state_dict(),\n",
  1212. " 'optimizer_state_dict': optimizer.state_dict(),\n",
  1213. " 'loss': loss_contrastive\n",
  1214. " }, PATH)"
  1215. ]
  1216. },
  1217. {
  1218. "cell_type": "markdown",
  1219. "metadata": {},
  1220. "source": [
  1221. "## Testen"
  1222. ]
  1223. },
  1224. {
  1225. "cell_type": "code",
  1226. "execution_count": 126,
  1227. "metadata": {},
  1228. "outputs": [],
  1229. "source": [
  1230. "folder_dataset_test = dset.ImageFolder(root=Config.path_OneShot_Test)\n",
  1231. "\n",
  1232. "siamese_dataset = SiameseDataset_Ears(imageFolderDataset=folder_dataset_test,\n",
  1233. " transform=transformer\n",
  1234. " ,should_invert=False)\n",
  1235. "\n",
  1236. "test_dataloader = DataLoader(siamese_dataset,batch_size=1,shuffle=True)\n",
  1237. "\n",
  1238. "## Test Function for Siamese-Network\n",
  1239. "def evaluation(model, test_loader):\n",
  1240. " with torch.no_grad():\n",
  1241. " model.eval()\n",
  1242. " correct = 0\n",
  1243. " count = 0\n",
  1244. "\n",
  1245. " for mainImg, imgSets, label in test_loader:\n",
  1246. " mainImg, imgSets, label = mainImg.to(DEVICE), imgSets.to(DEVICE), label.to(DEVICE)\n",
  1247. " predVal = 2.1\n",
  1248. " pred = -1\n",
  1249. " count += 1\n",
  1250. " \n",
  1251. " ## Determine which category an image belongs to\n",
  1252. " for i, testImg in enumerate(imgSets):\n",
  1253. " testImg = testImg.to(DEVICE)\n",
  1254. "\n",
  1255. " if(network_name == ('VGG11' or 'VGG11bn' or 'ResNet18' or 'ResNet34' or 'alexnet' or 'Squeezenet1_0' or 'GoogLeNet' or 'Shufflenet_v2_x0_5' or 'Resnext101_32x8d')):\n",
  1256. " output1 = network(mainImg)\n",
  1257. " output2 = network(imgSets)\n",
  1258. " \n",
  1259. " elif(network_name == ('Siamese_Network')):\n",
  1260. " output1, output2 = network(mainImg, imgSets)\n",
  1261. " \n",
  1262. " else:\n",
  1263. " print('False Network choosen')\n",
  1264. "\n",
  1265. " \n",
  1266. " euclidean_distance = F.pairwise_distance(output1, output2)\n",
  1267. " euclidean_distance = euclidean_distance.cpu().numpy()\n",
  1268. " \n",
  1269. "\n",
  1270. " if(((euclidean_distance < predVal) and (label==0)) or ((euclidean_distance > predVal) and (label==1))):\n",
  1271. " correct += 1\n",
  1272. " print('Accuracy: {}'.format(correct/count))"
  1273. ]
  1274. },
  1275. {
  1276. "cell_type": "code",
  1277. "execution_count": 127,
  1278. "metadata": {},
  1279. "outputs": [
  1280. {
  1281. "name": "stdout",
  1282. "output_type": "stream",
  1283. "text": [
  1284. "Accuracy: 1.0\n",
  1285. "Accuracy: 0.5\n",
  1286. "Accuracy: 0.3333333333333333\n",
  1287. "Accuracy: 0.25\n",
  1288. "Accuracy: 0.4\n",
  1289. "Accuracy: 0.3333333333333333\n",
  1290. "Accuracy: 0.2857142857142857\n",
  1291. "Accuracy: 0.375\n",
  1292. "Accuracy: 0.3333333333333333\n",
  1293. "Accuracy: 0.3\n",
  1294. "Accuracy: 0.2727272727272727\n",
  1295. "Accuracy: 0.25\n",
  1296. "Accuracy: 0.3076923076923077\n",
  1297. "Accuracy: 0.35714285714285715\n",
  1298. "Accuracy: 0.4\n",
  1299. "Accuracy: 0.375\n",
  1300. "Accuracy: 0.35294117647058826\n",
  1301. "Accuracy: 0.3888888888888889\n",
  1302. "Accuracy: 0.42105263157894735\n",
  1303. "Accuracy: 0.45\n",
  1304. "Accuracy: 0.42857142857142855\n",
  1305. "Accuracy: 0.4090909090909091\n",
  1306. "Accuracy: 0.391304347826087\n",
  1307. "Accuracy: 0.375\n",
  1308. "Accuracy: 0.4\n",
  1309. "Accuracy: 0.38461538461538464\n",
  1310. "Accuracy: 0.37037037037037035\n",
  1311. "Accuracy: 0.39285714285714285\n",
  1312. "Accuracy: 0.3793103448275862\n",
  1313. "Accuracy: 0.36666666666666664\n",
  1314. "Accuracy: 0.3870967741935484\n",
  1315. "Accuracy: 0.375\n",
  1316. "Accuracy: 0.3939393939393939\n",
  1317. "Accuracy: 0.4117647058823529\n",
  1318. "Accuracy: 0.42857142857142855\n",
  1319. "Accuracy: 0.4166666666666667\n",
  1320. "Accuracy: 0.40540540540540543\n",
  1321. "Accuracy: 0.42105263157894735\n",
  1322. "Accuracy: 0.4358974358974359\n",
  1323. "Accuracy: 0.45\n"
  1324. ]
  1325. }
  1326. ],
  1327. "source": [
  1328. "evaluation(network, test_dataloader)"
  1329. ]
  1330. },
  1331. {
  1332. "cell_type": "code",
  1333. "execution_count": null,
  1334. "metadata": {},
  1335. "outputs": [],
  1336. "source": []
  1337. }
  1338. ],
  1339. "metadata": {
  1340. "kernelspec": {
  1341. "display_name": "Python 3",
  1342. "language": "python",
  1343. "name": "python3"
  1344. },
  1345. "language_info": {
  1346. "codemirror_mode": {
  1347. "name": "ipython",
  1348. "version": 3
  1349. },
  1350. "file_extension": ".py",
  1351. "mimetype": "text/x-python",
  1352. "name": "python",
  1353. "nbconvert_exporter": "python",
  1354. "pygments_lexer": "ipython3",
  1355. "version": "3.8.3"
  1356. }
  1357. },
  1358. "nbformat": 4,
  1359. "nbformat_minor": 4
  1360. }