In der Masterarbeit „Anomalie-Detektion in Zellbildern zur Anwendung der Leukämieerkennung“ verwendete CSI-Methode.
  1. import math
  2. import numbers
  3. import numpy as np
  4. import torch
  5. import torch.nn as nn
  6. import torch.nn.functional as F
  7. from torch.autograd import Function
  8. from torchvision import transforms
  9. if torch.__version__ >= '1.4.0':
  10. kwargs = {'align_corners': False}
  11. else:
  12. kwargs = {}
  13. def rgb2hsv(rgb):
  14. """Convert a 4-d RGB tensor to the HSV counterpart.
  15. Here, we compute hue using atan2() based on the definition in [1],
  16. instead of using the common lookup table approach as in [2, 3].
  17. Those values agree when the angle is a multiple of 30°,
  18. otherwise they may differ at most ~1.2°.
  19. References
  20. [1]
  21. [2]
  22. [3]
  23. """
  24. r, g, b = rgb[:, 0, :, :], rgb[:, 1, :, :], rgb[:, 2, :, :]
  25. Cmax = rgb.max(1)[0]
  26. Cmin = rgb.min(1)[0]
  27. delta = Cmax - Cmin
  28. hue = torch.atan2(math.sqrt(3) * (g - b), 2 * r - g - b)
  29. hue = (hue % (2 * math.pi)) / (2 * math.pi)
  30. saturate = delta / Cmax
  31. value = Cmax
  32. hsv = torch.stack([hue, saturate, value], dim=1)
  33. hsv[~torch.isfinite(hsv)] = 0.
  34. return hsv
  35. def hsv2rgb(hsv):
  36. """Convert a 4-d HSV tensor to the RGB counterpart.
  37. >>> %timeit hsv2rgb(hsv)
  38. 2.37 ms ± 13.4 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
  39. >>> %timeit rgb2hsv_fast(rgb)
  40. 298 µs ± 542 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each)
  41. >>> torch.allclose(hsv2rgb(hsv), hsv2rgb_fast(hsv), atol=1e-6)
  42. True
  43. References
  44. [1]
  45. """
  46. h, s, v = hsv[:, [0]], hsv[:, [1]], hsv[:, [2]]
  47. c = v * s
  48. n = hsv.new_tensor([5, 3, 1]).view(3, 1, 1)
  49. k = (n + h * 6) % 6
  50. t = torch.min(k, 4 - k)
  51. t = torch.clamp(t, 0, 1)
  52. return v - c * t
  53. class RandomResizedCropLayer(nn.Module):
  54. def __init__(self, size=None, scale=(0.08, 1.0), ratio=(3. / 4., 4. / 3.)):
  55. '''
  56. Inception Crop
  57. size (tuple): size of fowarding image (C, W, H)
  58. scale (tuple): range of size of the origin size cropped
  59. ratio (tuple): range of aspect ratio of the origin aspect ratio cropped
  60. '''
  61. super(RandomResizedCropLayer, self).__init__()
  62. _eye = torch.eye(2, 3)
  63. self.size = size
  64. self.register_buffer('_eye', _eye)
  65. self.scale = scale
  66. self.ratio = ratio
  67. def forward(self, inputs, whbias=None):
  68. _device = inputs.device
  69. N = inputs.size(0)
  70. _theta = self._eye.repeat(N, 1, 1)
  71. if whbias is None:
  72. whbias = self._sample_latent(inputs)
  73. _theta[:, 0, 0] = whbias[:, 0]
  74. _theta[:, 1, 1] = whbias[:, 1]
  75. _theta[:, 0, 2] = whbias[:, 2]
  76. _theta[:, 1, 2] = whbias[:, 3]
  77. grid = F.affine_grid(_theta, inputs.size(), **kwargs).to(_device)
  78. output = F.grid_sample(inputs, grid, padding_mode='reflection', **kwargs)
  79. if self.size is not None:
  80. output = F.adaptive_avg_pool2d(output, self.size)
  81. # output = F.adaptive_avg_pool2d(output, self.size)
  82. # output = F.adaptive_avg_pool2d(output, (self.size[0], self.size[1]))
  83. return output
  84. def _clamp(self, whbias):
  85. w = whbias[:, 0]
  86. h = whbias[:, 1]
  87. w_bias = whbias[:, 2]
  88. h_bias = whbias[:, 3]
  89. # Clamp with scale
  90. w = torch.clamp(w, *self.scale)
  91. h = torch.clamp(h, *self.scale)
  92. # Clamp with ratio
  93. w = self.ratio[0] * h + torch.relu(w - self.ratio[0] * h)
  94. w = self.ratio[1] * h - torch.relu(self.ratio[1] * h - w)
  95. # Clamp with bias range: w_bias \in (w - 1, 1 - w), h_bias \in (h - 1, 1 - h)
  96. w_bias = w - 1 + torch.relu(w_bias - w + 1)
  97. w_bias = 1 - w - torch.relu(1 - w - w_bias)
  98. h_bias = h - 1 + torch.relu(h_bias - h + 1)
  99. h_bias = 1 - h - torch.relu(1 - h - h_bias)
  100. whbias = torch.stack([w, h, w_bias, h_bias], dim=0).t()
  101. return whbias
  102. def _sample_latent(self, inputs):
  103. _device = inputs.device
  104. N, _, width, height = inputs.shape
  105. # N * 10 trial
  106. area = width * height
  107. target_area = np.random.uniform(*self.scale, N * 10) * area
  108. log_ratio = (math.log(self.ratio[0]), math.log(self.ratio[1]))
  109. aspect_ratio = np.exp(np.random.uniform(*log_ratio, N * 10))
  110. # If doesn't satisfy ratio condition, then do central crop
  111. w = np.round(np.sqrt(target_area * aspect_ratio))
  112. h = np.round(np.sqrt(target_area / aspect_ratio))
  113. cond = (0 < w) * (w <= width) * (0 < h) * (h <= height)
  114. w = w[cond]
  115. h = h[cond]
  116. cond_len = w.shape[0]
  117. if cond_len >= N:
  118. w = w[:N]
  119. h = h[:N]
  120. else:
  121. w = np.concatenate([w, np.ones(N - cond_len) * width])
  122. h = np.concatenate([h, np.ones(N - cond_len) * height])
  123. w_bias = np.random.randint(w - width, width - w + 1) / width
  124. h_bias = np.random.randint(h - height, height - h + 1) / height
  125. w = w / width
  126. h = h / height
  127. whbias = np.column_stack([w, h, w_bias, h_bias])
  128. whbias = torch.tensor(whbias, device=_device)
  129. return whbias
  130. class HorizontalFlipRandomCrop(nn.Module):
  131. def __init__(self, max_range):
  132. super(HorizontalFlipRandomCrop, self).__init__()
  133. self.max_range = max_range
  134. _eye = torch.eye(2, 3)
  135. self.register_buffer('_eye', _eye)
  136. def forward(self, input, sign=None, bias=None, rotation=None):
  137. _device = input.device
  138. N = input.size(0)
  139. _theta = self._eye.repeat(N, 1, 1)
  140. if sign is None:
  141. sign = torch.bernoulli(torch.ones(N, device=_device) * 0.5) * 2 - 1
  142. if bias is None:
  143. bias = torch.empty((N, 2), device=_device).uniform_(-self.max_range, self.max_range)
  144. _theta[:, 0, 0] = sign
  145. _theta[:, :, 2] = bias
  146. if rotation is not None:
  147. _theta[:, 0:2, 0:2] = rotation
  148. grid = F.affine_grid(_theta, input.size(), **kwargs).to(_device)
  149. output = F.grid_sample(input, grid, padding_mode='reflection', **kwargs)
  150. return output
  151. def _sample_latent(self, N, device=None):
  152. sign = torch.bernoulli(torch.ones(N, device=device) * 0.5) * 2 - 1
  153. bias = torch.empty((N, 2), device=device).uniform_(-self.max_range, self.max_range)
  154. return sign, bias
  155. class Rotation(nn.Module):
  156. def __init__(self, max_range = 4):
  157. super(Rotation, self).__init__()
  158. self.max_range = max_range
  159. self.prob = 0.5
  160. def forward(self, input, aug_index=None):
  161. _device = input.device
  162. _, _, H, W = input.size()
  163. if aug_index is None:
  164. aug_index = np.random.randint(4)
  165. output = torch.rot90(input, aug_index, (2, 3))
  166. _prob = input.new_full((input.size(0),), self.prob)
  167. _mask = torch.bernoulli(_prob).view(-1, 1, 1, 1)
  168. output = _mask * input + (1-_mask) * output
  169. else:
  170. aug_index = aug_index % self.max_range
  171. output = torch.rot90(input, aug_index, (2, 3))
  172. return output
  173. class RandomAdjustSharpness(nn.Module):
  174. def __init__(self, sharpness_factor=0.5, p=0.5):
  175. super(RandomAdjustSharpness, self).__init__()
  176. self.sharpness_factor = sharpness_factor
  177. self.prob = p
  178. def forward(self, input, aug_index=None):
  179. _device = input.device
  180. _, _, H, W = input.size()
  181. if aug_index == 0:
  182. output = input
  183. else:
  184. output = transforms.RandomAdjustSharpness(sharpness_factor=self.sharpness_factor, p=self.prob)(input)
  185. return output
  186. class RandPers(nn.Module):
  187. def __init__(self, distortion_scale=0.5, p=0.5):
  188. super(RandPers, self).__init__()
  189. self.distortion_scale = distortion_scale
  190. self.prob = p
  191. def forward(self, input, aug_index=None):
  192. _device = input.device
  193. _, _, H, W = input.size()
  194. if aug_index == 0:
  195. output = input
  196. else:
  197. output = transforms.RandomPerspective(distortion_scale=self.distortion_scale, p=self.prob)(input)
  198. return output
  199. class GaussBlur(nn.Module):
  200. def __init__(self, max_range = 4, kernel_size=3, sigma=(0.1, 2.0)):
  201. super(GaussBlur, self).__init__()
  202. self.max_range = max_range
  203. self.prob = 0.5
  204. self.sigma = sigma
  205. self.kernel_size = kernel_size
  206. def forward(self, input, aug_index=None):
  207. _device = input.device
  208. _, _, H, W = input.size()
  209. if aug_index is None:
  210. aug_index = np.random.randint(4)
  211. output = transforms.GaussianBlur(kernel_size=13, sigma=abs(aug_index)+1)(input)
  212. _prob = input.new_full((input.size(0),), self.prob)
  213. _mask = torch.bernoulli(_prob).view(-1, 1, 1, 1)
  214. output = _mask * input + (1-_mask) * output
  215. else:
  216. if aug_index == 0:
  217. output = input
  218. else:
  219. output = transforms.GaussianBlur(kernel_size=self.kernel_size, sigma=self.sigma)(input)
  220. return output
  221. class GaussNoise(nn.Module):
  222. def __init__(self, mean = 0, std = 1):
  223. super(GaussNoise, self).__init__()
  224. self.mean = mean
  225. self.std = std
  226. def forward(self, input, aug_index=None):
  227. _device = input.device
  228. _, _, H, W = input.size()
  229. if aug_index == 0:
  230. output = input
  231. else:
  232. output = input + (torch.randn(input.size()) * self.std + self.mean).to(_device)
  233. return output
  234. class BlurRandpers(nn.Module):
  235. def __init__(self, max_range=2, kernel_size=3, sigma=(10, 20), distortion_scale=0.6, p=1):
  236. super(BlurRandpers, self).__init__()
  237. self.max_range = max_range
  238. self.sigma = sigma
  239. self.kernel_size = kernel_size
  240. self.distortion_scale = distortion_scale
  241. self.p = p
  242. self.gauss = GaussBlur(kernel_size=self.kernel_size, sigma=self.sigma)
  243. self.randpers = RandPers(distortion_scale=self.distortion_scale, p=self.p)
  244. def forward(self, input, aug_index=None):
  245. output = self.gauss.forward(input=input, aug_index=aug_index)
  246. output = self.randpers.forward(input=output, aug_index=aug_index)
  247. return output
  248. class BlurSharpness(nn.Module):
  249. def __init__(self, max_range=2, kernel_size=3, sigma=(10, 20), sharpness_factor=0.6, p=1):
  250. super(BlurSharpness, self).__init__()
  251. self.max_range = max_range
  252. self.sigma = sigma
  253. self.kernel_size = kernel_size
  254. self.sharpness_factor = sharpness_factor
  255. self.p = p
  256. self.gauss = GaussBlur(kernel_size=self.kernel_size, sigma=self.sigma)
  257. = RandomAdjustSharpness(sharpness_factor=self.sharpness_factor, p=self.p)
  258. def forward(self, input, aug_index=None):
  259. output = self.gauss.forward(input=input, aug_index=aug_index)
  260. output =, aug_index=aug_index)
  261. return output
  262. class RandpersSharpness(nn.Module):
  263. def __init__(self, max_range=2, distortion_scale=0.6, p=1, sharpness_factor=0.6):
  264. super(RandpersSharpness, self).__init__()
  265. self.max_range = max_range
  266. self.distortion_scale = distortion_scale
  267. self.p = p
  268. self.sharpness_factor = sharpness_factor
  269. self.randpers = RandPers(distortion_scale=self.distortion_scale, p=self.p)
  270. = RandomAdjustSharpness(sharpness_factor=self.sharpness_factor, p=self.p)
  271. def forward(self, input, aug_index=None):
  272. output = self.randpers.forward(input=input, aug_index=aug_index)
  273. output =, aug_index=aug_index)
  274. return output
  275. class BlurRandpersSharpness(nn.Module):
  276. def __init__(self, max_range=2, kernel_size=3, sigma=(10, 20), distortion_scale=0.6, p=1, sharpness_factor=0.6):
  277. super(BlurRandpersSharpness, self).__init__()
  278. self.max_range = max_range
  279. self.sigma = sigma
  280. self.kernel_size = kernel_size
  281. self.distortion_scale = distortion_scale
  282. self.p = p
  283. self.sharpness_factor = sharpness_factor
  284. self.gauss = GaussBlur(kernel_size=self.kernel_size, sigma=self.sigma)
  285. self.randpers = RandPers(distortion_scale=self.distortion_scale, p=self.p)
  286. = RandomAdjustSharpness(sharpness_factor=self.sharpness_factor, p=self.p)
  287. def forward(self, input, aug_index=None):
  288. output = self.gauss.forward(input=input, aug_index=aug_index)
  289. output = self.randpers.forward(input=output, aug_index=aug_index)
  290. output =, aug_index=aug_index)
  291. return output
  292. class FourCrop(nn.Module):
  293. def __init__(self, max_range = 4):
  294. super(FourCrop, self).__init__()
  295. self.max_range = max_range
  296. self.prob = 0.5
  297. def forward(self, inputs):
  298. outputs = inputs
  299. for i in range(8):
  300. outputs[i] = self._crop(inputs.size(), inputs[i], i)
  301. return outputs
  302. def _crop(self, size, input, i):
  303. _, _, H, W = size
  304. h_mid = int(H / 2)
  305. w_mid = int(W / 2)
  306. if i == 0 or i == 4:
  307. corner = input[:, 0:h_mid, 0:w_mid]
  308. elif i == 1 or i == 5:
  309. corner = input[:, 0:h_mid, w_mid:]
  310. elif i == 2 or i == 6:
  311. corner = input[:, h_mid:, 0:w_mid]
  312. elif i == 3 or i == 7:
  313. corner = input[:, h_mid:, w_mid:]
  314. else:
  315. corner = input
  316. corner = transforms.Resize(size=2*h_mid)(corner)
  317. return corner
  318. class CutPerm(nn.Module):
  319. def __init__(self, max_range = 4):
  320. super(CutPerm, self).__init__()
  321. self.max_range = max_range
  322. self.prob = 0.5
  323. def forward(self, input, aug_index=None):
  324. _device = input.device
  325. _, _, H, W = input.size()
  326. if aug_index is None:
  327. aug_index = np.random.randint(4)
  328. output = self._cutperm(input, aug_index)
  329. _prob = input.new_full((input.size(0),), self.prob)
  330. _mask = torch.bernoulli(_prob).view(-1, 1, 1, 1)
  331. output = _mask * input + (1 - _mask) * output
  332. else:
  333. aug_index = aug_index % self.max_range
  334. output = self._cutperm(input, aug_index)
  335. return output
  336. def _cutperm(self, inputs, aug_index):
  337. _, _, H, W = inputs.size()
  338. h_mid = int(H / 2)
  339. w_mid = int(W / 2)
  340. jigsaw_h = aug_index // 2
  341. jigsaw_v = aug_index % 2
  342. if jigsaw_h == 1:
  343. inputs =[:, :, h_mid:, :], inputs[:, :, 0:h_mid, :]), dim=2)
  344. if jigsaw_v == 1:
  345. inputs =[:, :, :, w_mid:], inputs[:, :, :, 0:w_mid]), dim=3)
  346. return inputs
  347. def assemble(a, b, c, d):
  348. ab =, b), dim=2)
  349. cd =, d), dim=2)
  350. output =, cd), dim=3)
  351. return output
  352. def quarter(inputs):
  353. _, _, H, W = inputs.size()
  354. h_mid = int(H / 2)
  355. w_mid = int(W / 2)
  356. quarters = []
  357. quarters.append(inputs[:, :, 0:h_mid, 0:w_mid])
  358. quarters.append(inputs[:, :, 0:h_mid, w_mid:])
  359. quarters.append(inputs[:, :, h_mid:, 0:w_mid])
  360. quarters.append(inputs[:, :, h_mid:, w_mid:])
  361. return quarters
  362. class HorizontalFlipLayer(nn.Module):
  363. def __init__(self):
  364. """
  365. img_size : (int, int, int)
  366. Height and width must be powers of 2. E.g. (32, 32, 1) or
  367. (64, 128, 3). Last number indicates number of channels, e.g. 1 for
  368. grayscale or 3 for RGB
  369. """
  370. super(HorizontalFlipLayer, self).__init__()
  371. _eye = torch.eye(2, 3)
  372. self.register_buffer('_eye', _eye)
  373. def forward(self, inputs):
  374. _device = inputs.device
  375. N = inputs.size(0)
  376. _theta = self._eye.repeat(N, 1, 1)
  377. r_sign = torch.bernoulli(torch.ones(N, device=_device) * 0.5) * 2 - 1
  378. _theta[:, 0, 0] = r_sign
  379. grid = F.affine_grid(_theta, inputs.size(), **kwargs).to(_device)
  380. inputs = F.grid_sample(inputs, grid, padding_mode='reflection', **kwargs)
  381. return inputs
  382. class RandomColorGrayLayer(nn.Module):
  383. def __init__(self, p):
  384. super(RandomColorGrayLayer, self).__init__()
  385. self.prob = p
  386. _weight = torch.tensor([[0.299, 0.587, 0.114]])
  387. self.register_buffer('_weight', _weight.view(1, 3, 1, 1))
  388. def forward(self, inputs, aug_index=None):
  389. if aug_index == 0:
  390. return inputs
  391. l = F.conv2d(inputs, self._weight)
  392. gray =[l, l, l], dim=1)
  393. if aug_index is None:
  394. _prob = inputs.new_full((inputs.size(0),), self.prob)
  395. _mask = torch.bernoulli(_prob).view(-1, 1, 1, 1)
  396. gray = inputs * (1 - _mask) + gray * _mask
  397. return gray
  398. class ColorJitterLayer(nn.Module):
  399. def __init__(self, p, brightness, contrast, saturation, hue):
  400. super(ColorJitterLayer, self).__init__()
  401. self.prob = p
  402. self.brightness = self._check_input(brightness, 'brightness')
  403. self.contrast = self._check_input(contrast, 'contrast')
  404. self.saturation = self._check_input(saturation, 'saturation')
  405. self.hue = self._check_input(hue, 'hue', center=0, bound=(-0.5, 0.5),
  406. clip_first_on_zero=False)
  407. def _check_input(self, value, name, center=1, bound=(0, float('inf')), clip_first_on_zero=True):
  408. if isinstance(value, numbers.Number):
  409. if value < 0:
  410. raise ValueError("If {} is a single number, it must be non negative.".format(name))
  411. value = [center - value, center + value]
  412. if clip_first_on_zero:
  413. value[0] = max(value[0], 0)
  414. elif isinstance(value, (tuple, list)) and len(value) == 2:
  415. if not bound[0] <= value[0] <= value[1] <= bound[1]:
  416. raise ValueError("{} values should be between {}".format(name, bound))
  417. else:
  418. raise TypeError("{} should be a single number or a list/tuple with lenght 2.".format(name))
  419. # if value is 0 or (1., 1.) for brightness/contrast/saturation
  420. # or (0., 0.) for hue, do nothing
  421. if value[0] == value[1] == center:
  422. value = None
  423. return value
  424. def adjust_contrast(self, x):
  425. if self.contrast:
  426. factor = x.new_empty(x.size(0), 1, 1, 1).uniform_(*self.contrast)
  427. means = torch.mean(x, dim=[2, 3], keepdim=True)
  428. x = (x - means) * factor + means
  429. return torch.clamp(x, 0, 1)
  430. def adjust_hsv(self, x):
  431. f_h = x.new_zeros(x.size(0), 1, 1)
  432. f_s = x.new_ones(x.size(0), 1, 1)
  433. f_v = x.new_ones(x.size(0), 1, 1)
  434. if self.hue:
  435. f_h.uniform_(*self.hue)
  436. if self.saturation:
  437. f_s = f_s.uniform_(*self.saturation)
  438. if self.brightness:
  439. f_v = f_v.uniform_(*self.brightness)
  440. return RandomHSVFunction.apply(x, f_h, f_s, f_v)
  441. def transform(self, inputs):
  442. # Shuffle transform
  443. if np.random.rand() > 0.5:
  444. transforms = [self.adjust_contrast, self.adjust_hsv]
  445. else:
  446. transforms = [self.adjust_hsv, self.adjust_contrast]
  447. for t in transforms:
  448. inputs = t(inputs)
  449. return inputs
  450. def forward(self, inputs):
  451. _prob = inputs.new_full((inputs.size(0),), self.prob)
  452. _mask = torch.bernoulli(_prob).view(-1, 1, 1, 1)
  453. return inputs * (1 - _mask) + self.transform(inputs) * _mask
  454. class RandomHSVFunction(Function):
  455. @staticmethod
  456. def forward(ctx, x, f_h, f_s, f_v):
  457. # ctx is a context object that can be used to stash information
  458. # for backward computation
  459. x = rgb2hsv(x)
  460. h = x[:, 0, :, :]
  461. h += (f_h * 255. / 360.)
  462. h = (h % 1)
  463. x[:, 0, :, :] = h
  464. x[:, 1, :, :] = x[:, 1, :, :] * f_s
  465. x[:, 2, :, :] = x[:, 2, :, :] * f_v
  466. x = torch.clamp(x, 0, 1)
  467. x = hsv2rgb(x)
  468. return x
  469. @staticmethod
  470. def backward(ctx, grad_output):
  471. # We return as many input gradients as there were arguments.
  472. # Gradients of non-Tensor arguments to forward must be None.
  473. grad_input = None
  474. if ctx.needs_input_grad[0]:
  475. grad_input = grad_output.clone()
  476. return grad_input, None, None, None
  477. class NormalizeLayer(nn.Module):
  478. """
  479. In order to certify radii in original coordinates rather than standardized coordinates, we
  480. add the Gaussian noise _before_ standardizing, which is why we have standardization be the first
  481. layer of the classifier rather than as a part of preprocessing as is typical.
  482. """
  483. def __init__(self):
  484. super(NormalizeLayer, self).__init__()
  485. def forward(self, inputs):
  486. return (inputs - 0.5) / 0.5