CSI-Methode, verwendet in der Masterarbeit: "Anomalie-Detektion in Zellbildern zur Anwendung der Leukämieerkennung".
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

resnet.py 6.5KB

2 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189
  1. '''ResNet in PyTorch.
  2. BasicBlock and Bottleneck module is from the original ResNet paper:
  3. [1] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun
  4. Deep Residual Learning for Image Recognition. arXiv:1512.03385
  5. PreActBlock and PreActBottleneck module is from the later paper:
  6. [2] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun
  7. Identity Mappings in Deep Residual Networks. arXiv:1603.05027
  8. '''
  9. import torch
  10. import torch.nn as nn
  11. import torch.nn.functional as F
  12. from models.base_model import BaseModel
  13. from models.transform_layers import NormalizeLayer
  14. from torch.nn.utils import spectral_norm
  15. def conv3x3(in_planes, out_planes, stride=1):
  16. return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False)
  17. class BasicBlock(nn.Module):
  18. expansion = 1
  19. def __init__(self, in_planes, planes, stride=1):
  20. super(BasicBlock, self).__init__()
  21. self.conv1 = conv3x3(in_planes, planes, stride)
  22. self.conv2 = conv3x3(planes, planes)
  23. self.bn1 = nn.BatchNorm2d(planes)
  24. self.bn2 = nn.BatchNorm2d(planes)
  25. self.shortcut = nn.Sequential()
  26. if stride != 1 or in_planes != self.expansion*planes:
  27. self.shortcut = nn.Sequential(
  28. nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False),
  29. nn.BatchNorm2d(self.expansion*planes)
  30. )
  31. def forward(self, x):
  32. out = F.relu(self.bn1(self.conv1(x)))
  33. out = self.bn2(self.conv2(out))
  34. out += self.shortcut(x)
  35. out = F.relu(out)
  36. return out
  37. class PreActBlock(nn.Module):
  38. '''Pre-activation version of the BasicBlock.'''
  39. expansion = 1
  40. def __init__(self, in_planes, planes, stride=1):
  41. super(PreActBlock, self).__init__()
  42. self.conv1 = conv3x3(in_planes, planes, stride)
  43. self.conv2 = conv3x3(planes, planes)
  44. self.bn1 = nn.BatchNorm2d(in_planes)
  45. self.bn2 = nn.BatchNorm2d(planes)
  46. self.shortcut = nn.Sequential()
  47. if stride != 1 or in_planes != self.expansion*planes:
  48. self.shortcut = nn.Sequential(
  49. nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False)
  50. )
  51. def forward(self, x):
  52. out = F.relu(self.bn1(x))
  53. shortcut = self.shortcut(out)
  54. out = self.conv1(out)
  55. out = self.conv2(F.relu(self.bn2(out)))
  56. out += shortcut
  57. return out
  58. class Bottleneck(nn.Module):
  59. expansion = 4
  60. def __init__(self, in_planes, planes, stride=1):
  61. super(Bottleneck, self).__init__()
  62. self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False)
  63. self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
  64. self.conv3 = nn.Conv2d(planes, self.expansion*planes, kernel_size=1, bias=False)
  65. self.bn1 = nn.BatchNorm2d(planes)
  66. self.bn2 = nn.BatchNorm2d(planes)
  67. self.bn3 = nn.BatchNorm2d(self.expansion * planes)
  68. self.shortcut = nn.Sequential()
  69. if stride != 1 or in_planes != self.expansion*planes:
  70. self.shortcut = nn.Sequential(
  71. nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False),
  72. nn.BatchNorm2d(self.expansion*planes)
  73. )
  74. def forward(self, x):
  75. out = F.relu(self.bn1(self.conv1(x)))
  76. out = F.relu(self.bn2(self.conv2(out)))
  77. out = self.bn3(self.conv3(out))
  78. out += self.shortcut(x)
  79. out = F.relu(out)
  80. return out
  81. class PreActBottleneck(nn.Module):
  82. '''Pre-activation version of the original Bottleneck module.'''
  83. expansion = 4
  84. def __init__(self, in_planes, planes, stride=1):
  85. super(PreActBottleneck, self).__init__()
  86. self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False)
  87. self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
  88. self.conv3 = nn.Conv2d(planes, self.expansion*planes, kernel_size=1, bias=False)
  89. self.bn1 = nn.BatchNorm2d(in_planes)
  90. self.bn2 = nn.BatchNorm2d(planes)
  91. self.bn3 = nn.BatchNorm2d(planes)
  92. self.shortcut = nn.Sequential()
  93. if stride != 1 or in_planes != self.expansion*planes:
  94. self.shortcut = nn.Sequential(
  95. nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False)
  96. )
  97. def forward(self, x):
  98. out = F.relu(self.bn1(x))
  99. shortcut = self.shortcut(out)
  100. out = self.conv1(out)
  101. out = self.conv2(F.relu(self.bn2(out)))
  102. out = self.conv3(F.relu(self.bn3(out)))
  103. out += shortcut
  104. return out
  105. class ResNet(BaseModel):
  106. def __init__(self, block, num_blocks, num_classes=10):
  107. last_dim = 512 * block.expansion
  108. super(ResNet, self).__init__(last_dim, num_classes)
  109. self.in_planes = 64
  110. self.last_dim = last_dim
  111. self.normalize = NormalizeLayer()
  112. self.conv1 = conv3x3(3, 64)
  113. self.bn1 = nn.BatchNorm2d(64)
  114. self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
  115. self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
  116. self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
  117. self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
  118. def _make_layer(self, block, planes, num_blocks, stride):
  119. strides = [stride] + [1]*(num_blocks-1)
  120. layers = []
  121. for stride in strides:
  122. layers.append(block(self.in_planes, planes, stride))
  123. self.in_planes = planes * block.expansion
  124. return nn.Sequential(*layers)
  125. def penultimate(self, x, all_features=False):
  126. out_list = []
  127. out = self.normalize(x)
  128. out = self.conv1(out)
  129. out = self.bn1(out)
  130. out = F.relu(out)
  131. out_list.append(out)
  132. out = self.layer1(out)
  133. out_list.append(out)
  134. out = self.layer2(out)
  135. out_list.append(out)
  136. out = self.layer3(out)
  137. out_list.append(out)
  138. out = self.layer4(out)
  139. out_list.append(out)
  140. out = F.avg_pool2d(out, 4)
  141. out = out.view(out.size(0), -1)
  142. if all_features:
  143. return out, out_list
  144. else:
  145. return out
  146. def ResNet18(num_classes):
  147. return ResNet(BasicBlock, [2,2,2,2], num_classes=num_classes)
  148. def ResNet34(num_classes):
  149. return ResNet(BasicBlock, [3,4,6,3], num_classes=num_classes)
  150. def ResNet50(num_classes):
  151. return ResNet(Bottleneck, [3,4,6,3], num_classes=num_classes)