# Code adapted from: https://github.com/Cadene/pretrained-models.pytorch
import math
from collections import OrderedDict
from itertools import chain

import torch.nn as nn
from torch.utils import model_zoo

from utils import Flatten


class SEModule(nn.Module):
    """Squeeze-and-Excitation gate that rescales the channels of its input.

    The input is globally average-pooled to 1x1, passed through a 1x1 conv
    bottleneck (``channels -> channels // reduction -> channels``) with a
    ReLU in between, squashed with a sigmoid, and multiplied back onto the
    original input.

    Args:
        channels: Number of channels of the input feature map.
        reduction: Divisor applied to ``channels`` for the bottleneck width.
    """

    def __init__(self, channels, reduction):
        super(SEModule, self).__init__()
        squeezed = channels // reduction
        # NOTE: attribute names are part of the state_dict keys; renaming
        # them would break loading of existing checkpoints.
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc1 = nn.Conv2d(channels, squeezed, kernel_size=1, padding=0)
        self.relu = nn.ReLU(inplace=True)
        self.fc2 = nn.Conv2d(squeezed, channels, kernel_size=1, padding=0)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Return ``x`` multiplied by its learned per-channel gate."""
        module_input = x
        x = self.avg_pool(x)
        x = self.fc1(x)
        x = self.relu(x)
        x = self.fc2(x)
        x = self.sigmoid(x)
        return module_input * x


class SEResNeXtBottleneck(nn.Module):
    """
    ResNeXt bottleneck type C with a Squeeze-and-Excitation module.

    Layout: 1x1 conv -> grouped 3x3 conv (carries ``stride``) -> 1x1 conv
    to ``planes * expansion`` channels, each followed by batch norm. The SE
    module is applied to the result before the residual is added, and a
    final ReLU follows the addition.

    Args:
        inplanes: Number of input channels.
        planes: Base channel count; the block outputs
            ``planes * expansion`` channels.
        groups: Number of groups of the 3x3 convolution.
        reduction: Reduction ratio passed to the SE module.
        stride: Stride of the 3x3 convolution.
        downsample: Optional module applied to the block input to produce
            the residual; when ``None`` the input itself is the residual.
        base_width: Per-group width at ``planes == 64``; the inner width is
            ``floor(planes * base_width / 64) * groups``.
    """
    expansion = 4

    def __init__(self, inplanes, planes, groups, reduction, stride=1,
                 downsample=None, base_width=4):
        super(SEResNeXtBottleneck, self).__init__()
        width = math.floor(planes * (base_width / 64)) * groups
        # Use the class attribute rather than a literal 4 so the block stays
        # consistent with SENet._make_layer, which sizes the downsample path
        # and the next block's input from ``block.expansion``.
        out_channels = planes * self.expansion

        self.conv1 = nn.Conv2d(inplanes, width, kernel_size=1, bias=False,
                               stride=1)
        self.bn1 = nn.BatchNorm2d(width)
        self.conv2 = nn.Conv2d(width, width, kernel_size=3, stride=stride,
                               padding=1, groups=groups, bias=False)
        self.bn2 = nn.BatchNorm2d(width)
        self.conv3 = nn.Conv2d(width, out_channels, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)
        self.se_module = SEModule(out_channels, reduction=reduction)
        self.downsample = downsample
        # Stored for reference only; forward() does not read it.
        self.stride = stride

    def forward(self, x):
        """Run the bottleneck and return the activated residual sum."""
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out = self.se_module(out) + residual
        out = self.relu(out)
        return out


class SENet(nn.Module):
    """SE network backbone with a single-output classification head.

    The network is a stem (``layer0``), four stages of ``block`` modules
    (``layer1`` .. ``layer4``) and a head (``cls``) made of global average
    pooling, ``Flatten`` and a linear layer with one output unit.

    Args:
        block: Block class used to build the stages; must expose an
            ``expansion`` class attribute and accept
            ``(inplanes, planes, groups, reduction, stride, downsample)``.
        layers: Sequence of four ints, the number of blocks per stage.
        groups: Number of groups forwarded to every block.
        reduction: SE reduction ratio forwarded to every block.
        inplanes: Number of channels produced by the stem.
        downsample_kernel_size: Kernel size of the downsample convolution
            in ``layer2`` .. ``layer4``.
        downsample_padding: Padding of the downsample convolution in
            ``layer2`` .. ``layer4``.
    """

    def __init__(self, block, layers, groups, reduction, inplanes=128,
                 downsample_kernel_size=3, downsample_padding=1):
        super(SENet, self).__init__()
        self.inplanes = inplanes
        layer0_modules = [
            ('conv1', nn.Conv2d(3, inplanes, kernel_size=7, stride=2,
                                padding=3, bias=False)),
            ('bn1', nn.BatchNorm2d(inplanes)),
            ('relu1', nn.ReLU(inplace=True)),
            # To preserve compatibility with Caffe weights `ceil_mode=True`
            # is used instead of `padding=1`.
            ('pool', nn.MaxPool2d(3, stride=2, ceil_mode=True))
        ]
        self.layer0 = nn.Sequential(OrderedDict(layer0_modules))
        # layer1 always uses a 1x1 / padding-0 downsample convolution,
        # independent of the constructor's downsample_* arguments.
        self.layer1 = self._make_layer(
            block,
            planes=64,
            blocks=layers[0],
            groups=groups,
            reduction=reduction,
            downsample_kernel_size=1,
            downsample_padding=0
        )
        self.layer2 = self._make_layer(
            block,
            planes=128,
            blocks=layers[1],
            stride=2,
            groups=groups,
            reduction=reduction,
            downsample_kernel_size=downsample_kernel_size,
            downsample_padding=downsample_padding
        )
        self.layer3 = self._make_layer(
            block,
            planes=256,
            blocks=layers[2],
            stride=2,
            groups=groups,
            reduction=reduction,
            downsample_kernel_size=downsample_kernel_size,
            downsample_padding=downsample_padding
        )
        self.layer4 = self._make_layer(
            block,
            planes=512,
            blocks=layers[3],
            stride=2,
            groups=groups,
            reduction=reduction,
            downsample_kernel_size=downsample_kernel_size,
            downsample_padding=downsample_padding
        )
        # Flatten comes from the project's utils module; presumably it
        # reshapes the pooled map to (batch, features) -- confirm there.
        self.cls = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            Flatten(),
            nn.Linear(512 * block.expansion, 1)
        )

    def _make_layer(self, block, planes, blocks, groups, reduction, stride=1,
                    downsample_kernel_size=1, downsample_padding=0):
        """Build one stage of ``blocks`` consecutive ``block`` modules.

        Only the first block receives ``stride`` and, when the stride is
        not 1 or the channel count changes, a conv + batch-norm downsample
        path for the residual. ``self.inplanes`` is updated to the stage's
        output channel count as a side effect.

        Returns:
            An ``nn.Sequential`` holding the stage's blocks.
        """
        out_channels = planes * block.expansion

        downsample = None
        if stride != 1 or self.inplanes != out_channels:
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, out_channels,
                          kernel_size=downsample_kernel_size, stride=stride,
                          padding=downsample_padding, bias=False),
                nn.BatchNorm2d(out_channels),
            )

        layers = [block(self.inplanes, planes, groups, reduction, stride,
                        downsample)]
        self.inplanes = out_channels
        for _ in range(1, blocks):
            layers.append(block(self.inplanes, planes, groups, reduction))

        return nn.Sequential(*layers)

    def paramgroup01(self):
        """Return an iterator over the parameters of layer0 and layer1."""
        return chain(
            self.layer0.parameters(),
            self.layer1.parameters(),
        )

    def paramgroup234(self):
        """Return an iterator over the parameters of layer2 to layer4."""
        return chain(
            self.layer2.parameters(),
            self.layer3.parameters(),
            self.layer4.parameters(),
        )

    def parameters_classifier(self):
        """Return an iterator over the parameters of the ``cls`` head."""
        return self.cls.parameters()

    def forward(self, x):
        """Run the stem, the four stages and the head; return the head output."""
        x = self.layer0(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        c = self.cls(x)
        return c


def get_model(pretrained=True):
    """Build the SE-ResNeXt50 (32x4d) configuration of ``SENet``.

    Args:
        pretrained: When true (the default, matching the previous
            behavior), download the ``se_resnext50_32x4d`` checkpoint and
            load it into the model. Pass ``False`` to get a randomly
            initialized model without any network access.

    Returns:
        The constructed ``SENet`` instance.
    """
    model = SENet(SEResNeXtBottleneck, [3, 4, 6, 3], groups=32, reduction=16,
                  inplanes=64, downsample_kernel_size=1, downsample_padding=0)
    if pretrained:
        # NOTE(review): the checkpoint is fetched over plain HTTP and then
        # deserialized by torch; consider an HTTPS mirror and/or hash
        # verification if this runs in an untrusted network.
        checkpoint = model_zoo.load_url('http://data.lip6.fr/cadene/pretrainedmodels/se_resnext50_32x4d-a260b3a4.pth')
        # strict=False tolerates key mismatches between the checkpoint and
        # this model (presumably the `cls` head has no counterpart in the
        # checkpoint -- verify against the checkpoint's keys).
        model.load_state_dict(checkpoint, strict=False)
    return model