Die in der Masterarbeit "Anomalie-Detektion in Zellbildern zur Anwendung der Leukämieerkennung" verwendete CSI-Methode.
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

resnet_imagenet.py 8.4KB

2 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231
  1. import torch
  2. import torch.nn as nn
  3. from models.base_model import BaseModel
  4. from models.transform_layers import NormalizeLayer
  5. def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
  6. """3x3 convolution with padding"""
  7. return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
  8. padding=dilation, groups=groups, bias=False, dilation=dilation)
  9. def conv1x1(in_planes, out_planes, stride=1):
  10. """1x1 convolution"""
  11. return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
  12. class BasicBlock(nn.Module):
  13. expansion = 1
  14. def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
  15. base_width=64, dilation=1, norm_layer=None):
  16. super(BasicBlock, self).__init__()
  17. if norm_layer is None:
  18. norm_layer = nn.BatchNorm2d
  19. if groups != 1 or base_width != 64:
  20. raise ValueError('BasicBlock only supports groups=1 and base_width=64')
  21. if dilation > 1:
  22. raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
  23. # Both self.conv1 and self.downsample layers downsample the input when stride != 1
  24. self.conv1 = conv3x3(inplanes, planes, stride)
  25. self.bn1 = norm_layer(planes)
  26. self.relu = nn.ReLU(inplace=True)
  27. self.conv2 = conv3x3(planes, planes)
  28. self.bn2 = norm_layer(planes)
  29. self.downsample = downsample
  30. self.stride = stride
  31. def forward(self, x):
  32. identity = x
  33. out = self.conv1(x)
  34. out = self.bn1(out)
  35. out = self.relu(out)
  36. out = self.conv2(out)
  37. out = self.bn2(out)
  38. if self.downsample is not None:
  39. identity = self.downsample(x)
  40. out += identity
  41. out = self.relu(out)
  42. return out
  43. class Bottleneck(nn.Module):
  44. # Bottleneck in torchvision places the stride for downsampling at 3x3 convolution(self.conv2)
  45. # while original implementation places the stride at the first 1x1 convolution(self.conv1)
  46. # according to "Deep residual learning for image recognition"https://arxiv.org/abs/1512.03385.
  47. # This variant is also known as ResNet V1.5 and improves accuracy according to
  48. # https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch.
  49. expansion = 4
  50. def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
  51. base_width=64, dilation=1, norm_layer=None):
  52. super(Bottleneck, self).__init__()
  53. if norm_layer is None:
  54. norm_layer = nn.BatchNorm2d
  55. width = int(planes * (base_width / 64.)) * groups
  56. # Both self.conv2 and self.downsample layers downsample the input when stride != 1
  57. self.conv1 = conv1x1(inplanes, width)
  58. self.bn1 = norm_layer(width)
  59. self.conv2 = conv3x3(width, width, stride, groups, dilation)
  60. self.bn2 = norm_layer(width)
  61. self.conv3 = conv1x1(width, planes * self.expansion)
  62. self.bn3 = norm_layer(planes * self.expansion)
  63. self.relu = nn.ReLU(inplace=True)
  64. self.downsample = downsample
  65. self.stride = stride
  66. def forward(self, x):
  67. identity = x
  68. out = self.conv1(x)
  69. out = self.bn1(out)
  70. out = self.relu(out)
  71. out = self.conv2(out)
  72. out = self.bn2(out)
  73. out = self.relu(out)
  74. out = self.conv3(out)
  75. out = self.bn3(out)
  76. if self.downsample is not None:
  77. identity = self.downsample(x)
  78. out += identity
  79. out = self.relu(out)
  80. return out
class ResNet(BaseModel):
    """ResNet backbone (torchvision-style layout) built on the project's ``BaseModel``.

    Layout of the feature extractor used by ``penultimate``:
    ``self.normalize`` -> 7x7 stride-2 stem conv (3 input channels) -> norm ->
    ReLU -> 3x3 stride-2 max-pool -> layer1..layer4 (64, 128, 256, 512 base
    planes) -> global average pooling -> flatten. The resulting feature size is
    ``512 * block.expansion``, which is also passed to ``BaseModel.__init__``
    together with ``num_classes``. What ``BaseModel`` builds from those values
    (e.g. classifier heads) is not visible here -- TODO confirm in
    models/base_model.py.

    Args:
        block: Residual block class (``BasicBlock`` or ``Bottleneck``); its
            ``expansion`` attribute scales every stage's output channels.
        layers: Four ints giving the number of blocks in layer1..layer4.
        num_classes: Forwarded to ``BaseModel.__init__``.
        zero_init_residual: If True, zero the weight of the last norm layer of
            every residual branch so each block starts as an identity mapping.
        groups: Convolution groups passed to every block.
        width_per_group: Passed to every block as ``base_width``.
        replace_stride_with_dilation: None or a 3-element sequence of bools; a
            True entry makes layer2 / layer3 / layer4 (respectively) use
            dilation instead of a stride of 2.
        norm_layer: Normalization layer factory; defaults to ``nn.BatchNorm2d``.

    Raises:
        ValueError: If ``replace_stride_with_dilation`` does not have exactly
            three elements. Block constructors may raise as well (e.g.
            ``BasicBlock`` rejects ``groups != 1`` or ``width_per_group != 64``).
    """
    def __init__(self, block, layers, num_classes=10,
                 zero_init_residual=False, groups=1, width_per_group=64, replace_stride_with_dilation=None,
                 norm_layer=None):
        # Channel count of the pooled feature vector produced by penultimate().
        last_dim = 512 * block.expansion
        super(ResNet, self).__init__(last_dim, num_classes)
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        self._norm_layer = norm_layer

        # Running state consumed and updated by _make_layer: the input channel
        # count of the next stage and the cumulative dilation so far.
        self.inplanes = 64
        self.dilation = 1
        if replace_stride_with_dilation is None:
            # each element indicates whether the stride-2 downsampling of
            # layer2 / layer3 / layer4 is replaced with a dilated convolution
            replace_stride_with_dilation = [False, False, False]
        if len(replace_stride_with_dilation) != 3:
            raise ValueError("replace_stride_with_dilation should be None "
                             "or a 3-element tuple, got {}".format(replace_stride_with_dilation))
        self.groups = groups
        self.base_width = width_per_group

        # Stem: 3-channel input, 7x7 stride-2 conv, norm, ReLU, stride-2 max-pool
        # (4x spatial reduction overall).
        self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3,
                               bias=False)
        self.bn1 = norm_layer(self.inplanes)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        # Residual stages. They must be built in this order because
        # _make_layer mutates self.inplanes / self.dilation.
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2,
                                       dilate=replace_stride_with_dilation[0])
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2,
                                       dilate=replace_stride_with_dilation[1])
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2,
                                       dilate=replace_stride_with_dilation[2])
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        # Applied to the raw input at the start of penultimate(); the exact
        # transform lives in models.transform_layers (not visible here).
        self.normalize = NormalizeLayer()
        # NOTE(review): same value as passed to BaseModel.__init__ above;
        # BaseModel may already store it -- confirm before removing.
        self.last_dim = 512 * block.expansion

        # Kaiming-normal (fan_out, ReLU gain) init for every conv; unit weight /
        # zero bias for BatchNorm and GroupNorm. self.modules() also visits any
        # submodules BaseModel registered, but only Conv2d / norm layers are touched.
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

        # Zero-initialize the last BN in each residual branch,
        # so that the residual branch starts with zeros, and each residual block behaves like an identity.
        # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
        if zero_init_residual:
            for m in self.modules():
                if isinstance(m, Bottleneck):
                    nn.init.constant_(m.bn3.weight, 0)
                elif isinstance(m, BasicBlock):
                    nn.init.constant_(m.bn2.weight, 0)

    def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
        """Build one residual stage of ``blocks`` blocks with ``planes`` base channels.

        Only the first block receives ``stride`` and the optional projection
        shortcut; the remaining blocks use stride 1. When ``dilate`` is True,
        the stride is folded into ``self.dilation`` instead, so the stage keeps
        its input resolution.

        Side effects: sets ``self.inplanes`` to ``planes * block.expansion`` and,
        when dilating, multiplies ``self.dilation`` by ``stride``.

        Returns:
            ``nn.Sequential`` containing the stage's blocks.
        """
        norm_layer = self._norm_layer
        downsample = None
        # The first block still uses the dilation that was in effect before
        # this stage; later blocks use the updated self.dilation.
        previous_dilation = self.dilation
        if dilate:
            self.dilation *= stride
            stride = 1
        # A projection shortcut is needed whenever the first block changes the
        # spatial size or the channel count of its input.
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                conv1x1(self.inplanes, planes * block.expansion, stride),
                norm_layer(planes * block.expansion),
            )
        layers = []
        layers.append(block(self.inplanes, planes, stride, downsample, self.groups,
                            self.base_width, previous_dilation, norm_layer))
        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(block(self.inplanes, planes, groups=self.groups,
                                base_width=self.base_width, dilation=self.dilation,
                                norm_layer=norm_layer))
        return nn.Sequential(*layers)

    def penultimate(self, x, all_features=False):
        """Compute the pooled backbone features of ``x``.

        Args:
            x: Input image batch; the stem conv expects 3 channels (presumably
                shape (N, 3, H, W) -- TODO confirm against the data pipeline).
            all_features: If True, also return the intermediate feature maps.

        Returns:
            The output of global average pooling, flattened from dim 1 (so
            ``512 * block.expansion`` features per sample). If ``all_features``
            is True, returns ``(features, out_list)``, where ``out_list`` holds
            five maps: the stem output (after max-pool) followed by the outputs
            of layer1, layer2, layer3 and layer4.
        """
        # Collects the stem output and each stage output; only returned when
        # all_features is True.
        out_list = []
        x = self.normalize(x)
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)
        out_list.append(x)
        x = self.layer1(x)
        out_list.append(x)
        x = self.layer2(x)
        out_list.append(x)
        x = self.layer3(x)
        out_list.append(x)
        x = self.layer4(x)
        out_list.append(x)
        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        if all_features:
            return x, out_list
        else:
            return x
  175. def _resnet(arch, block, layers, **kwargs):
  176. model = ResNet(block, layers, **kwargs)
  177. return model
  178. def resnet18(**kwargs):
  179. r"""ResNet-18 model from
  180. `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_
  181. """
  182. return _resnet('resnet18', BasicBlock, [2, 2, 2, 2], **kwargs)
  183. def resnet50(**kwargs):
  184. r"""ResNet-50 model from
  185. `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_
  186. """
  187. return _resnet('resnet50', Bottleneck, [3, 4, 6, 3], **kwargs)