In der Masterarbeit „Anomalie-Detektion in Zellbildern zur Anwendung der Leukämieerkennung" verwendete CSI-Methode.
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

transform_layers.py 21KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643
  1. import math
  2. import numbers
  3. import numpy as np
  4. import torch
  5. import torch.nn as nn
  6. import torch.nn.functional as F
  7. from torch.autograd import Function
  8. from torchvision import transforms
  9. if torch.__version__ >= '1.4.0':
  10. kwargs = {'align_corners': False}
  11. else:
  12. kwargs = {}
  13. def rgb2hsv(rgb):
  14. """Convert a 4-d RGB tensor to the HSV counterpart.
  15. Here, we compute hue using atan2() based on the definition in [1],
  16. instead of using the common lookup table approach as in [2, 3].
  17. Those values agree when the angle is a multiple of 30°,
  18. otherwise they may differ at most ~1.2°.
  19. References
  20. [1] https://en.wikipedia.org/wiki/Hue
  21. [2] https://www.rapidtables.com/convert/color/rgb-to-hsv.html
  22. [3] https://github.com/scikit-image/scikit-image/blob/master/skimage/color/colorconv.py#L212
  23. """
  24. r, g, b = rgb[:, 0, :, :], rgb[:, 1, :, :], rgb[:, 2, :, :]
  25. Cmax = rgb.max(1)[0]
  26. Cmin = rgb.min(1)[0]
  27. delta = Cmax - Cmin
  28. hue = torch.atan2(math.sqrt(3) * (g - b), 2 * r - g - b)
  29. hue = (hue % (2 * math.pi)) / (2 * math.pi)
  30. saturate = delta / Cmax
  31. value = Cmax
  32. hsv = torch.stack([hue, saturate, value], dim=1)
  33. hsv[~torch.isfinite(hsv)] = 0.
  34. return hsv
  35. def hsv2rgb(hsv):
  36. """Convert a 4-d HSV tensor to the RGB counterpart.
  37. >>> %timeit hsv2rgb(hsv)
  38. 2.37 ms ± 13.4 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
  39. >>> %timeit rgb2hsv_fast(rgb)
  40. 298 µs ± 542 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each)
  41. >>> torch.allclose(hsv2rgb(hsv), hsv2rgb_fast(hsv), atol=1e-6)
  42. True
  43. References
  44. [1] https://en.wikipedia.org/wiki/HSL_and_HSV#HSV_to_RGB_alternative
  45. """
  46. h, s, v = hsv[:, [0]], hsv[:, [1]], hsv[:, [2]]
  47. c = v * s
  48. n = hsv.new_tensor([5, 3, 1]).view(3, 1, 1)
  49. k = (n + h * 6) % 6
  50. t = torch.min(k, 4 - k)
  51. t = torch.clamp(t, 0, 1)
  52. return v - c * t
  53. class RandomResizedCropLayer(nn.Module):
  54. def __init__(self, size=None, scale=(0.08, 1.0), ratio=(3. / 4., 4. / 3.)):
  55. '''
  56. Inception Crop
  57. size (tuple): size of fowarding image (C, W, H)
  58. scale (tuple): range of size of the origin size cropped
  59. ratio (tuple): range of aspect ratio of the origin aspect ratio cropped
  60. '''
  61. super(RandomResizedCropLayer, self).__init__()
  62. _eye = torch.eye(2, 3)
  63. self.size = size
  64. self.register_buffer('_eye', _eye)
  65. self.scale = scale
  66. self.ratio = ratio
  67. def forward(self, inputs, whbias=None):
  68. _device = inputs.device
  69. N = inputs.size(0)
  70. _theta = self._eye.repeat(N, 1, 1)
  71. if whbias is None:
  72. whbias = self._sample_latent(inputs)
  73. _theta[:, 0, 0] = whbias[:, 0]
  74. _theta[:, 1, 1] = whbias[:, 1]
  75. _theta[:, 0, 2] = whbias[:, 2]
  76. _theta[:, 1, 2] = whbias[:, 3]
  77. grid = F.affine_grid(_theta, inputs.size(), **kwargs).to(_device)
  78. output = F.grid_sample(inputs, grid, padding_mode='reflection', **kwargs)
  79. if self.size is not None:
  80. output = F.adaptive_avg_pool2d(output, self.size)
  81. # output = F.adaptive_avg_pool2d(output, self.size)
  82. # output = F.adaptive_avg_pool2d(output, (self.size[0], self.size[1]))
  83. return output
  84. def _clamp(self, whbias):
  85. w = whbias[:, 0]
  86. h = whbias[:, 1]
  87. w_bias = whbias[:, 2]
  88. h_bias = whbias[:, 3]
  89. # Clamp with scale
  90. w = torch.clamp(w, *self.scale)
  91. h = torch.clamp(h, *self.scale)
  92. # Clamp with ratio
  93. w = self.ratio[0] * h + torch.relu(w - self.ratio[0] * h)
  94. w = self.ratio[1] * h - torch.relu(self.ratio[1] * h - w)
  95. # Clamp with bias range: w_bias \in (w - 1, 1 - w), h_bias \in (h - 1, 1 - h)
  96. w_bias = w - 1 + torch.relu(w_bias - w + 1)
  97. w_bias = 1 - w - torch.relu(1 - w - w_bias)
  98. h_bias = h - 1 + torch.relu(h_bias - h + 1)
  99. h_bias = 1 - h - torch.relu(1 - h - h_bias)
  100. whbias = torch.stack([w, h, w_bias, h_bias], dim=0).t()
  101. return whbias
  102. def _sample_latent(self, inputs):
  103. _device = inputs.device
  104. N, _, width, height = inputs.shape
  105. # N * 10 trial
  106. area = width * height
  107. target_area = np.random.uniform(*self.scale, N * 10) * area
  108. log_ratio = (math.log(self.ratio[0]), math.log(self.ratio[1]))
  109. aspect_ratio = np.exp(np.random.uniform(*log_ratio, N * 10))
  110. # If doesn't satisfy ratio condition, then do central crop
  111. w = np.round(np.sqrt(target_area * aspect_ratio))
  112. h = np.round(np.sqrt(target_area / aspect_ratio))
  113. cond = (0 < w) * (w <= width) * (0 < h) * (h <= height)
  114. w = w[cond]
  115. h = h[cond]
  116. cond_len = w.shape[0]
  117. if cond_len >= N:
  118. w = w[:N]
  119. h = h[:N]
  120. else:
  121. w = np.concatenate([w, np.ones(N - cond_len) * width])
  122. h = np.concatenate([h, np.ones(N - cond_len) * height])
  123. w_bias = np.random.randint(w - width, width - w + 1) / width
  124. h_bias = np.random.randint(h - height, height - h + 1) / height
  125. w = w / width
  126. h = h / height
  127. whbias = np.column_stack([w, h, w_bias, h_bias])
  128. whbias = torch.tensor(whbias, device=_device)
  129. return whbias
  130. class HorizontalFlipRandomCrop(nn.Module):
  131. def __init__(self, max_range):
  132. super(HorizontalFlipRandomCrop, self).__init__()
  133. self.max_range = max_range
  134. _eye = torch.eye(2, 3)
  135. self.register_buffer('_eye', _eye)
  136. def forward(self, input, sign=None, bias=None, rotation=None):
  137. _device = input.device
  138. N = input.size(0)
  139. _theta = self._eye.repeat(N, 1, 1)
  140. if sign is None:
  141. sign = torch.bernoulli(torch.ones(N, device=_device) * 0.5) * 2 - 1
  142. if bias is None:
  143. bias = torch.empty((N, 2), device=_device).uniform_(-self.max_range, self.max_range)
  144. _theta[:, 0, 0] = sign
  145. _theta[:, :, 2] = bias
  146. if rotation is not None:
  147. _theta[:, 0:2, 0:2] = rotation
  148. grid = F.affine_grid(_theta, input.size(), **kwargs).to(_device)
  149. output = F.grid_sample(input, grid, padding_mode='reflection', **kwargs)
  150. return output
  151. def _sample_latent(self, N, device=None):
  152. sign = torch.bernoulli(torch.ones(N, device=device) * 0.5) * 2 - 1
  153. bias = torch.empty((N, 2), device=device).uniform_(-self.max_range, self.max_range)
  154. return sign, bias
  155. class Rotation(nn.Module):
  156. def __init__(self, max_range = 4):
  157. super(Rotation, self).__init__()
  158. self.max_range = max_range
  159. self.prob = 0.5
  160. def forward(self, input, aug_index=None):
  161. _device = input.device
  162. _, _, H, W = input.size()
  163. if aug_index is None:
  164. aug_index = np.random.randint(4)
  165. output = torch.rot90(input, aug_index, (2, 3))
  166. _prob = input.new_full((input.size(0),), self.prob)
  167. _mask = torch.bernoulli(_prob).view(-1, 1, 1, 1)
  168. output = _mask * input + (1-_mask) * output
  169. else:
  170. aug_index = aug_index % self.max_range
  171. output = torch.rot90(input, aug_index, (2, 3))
  172. return output
  173. class RandomAdjustSharpness(nn.Module):
  174. def __init__(self, sharpness_factor=0.5, p=0.5):
  175. super(RandomAdjustSharpness, self).__init__()
  176. self.sharpness_factor = sharpness_factor
  177. self.prob = p
  178. def forward(self, input, aug_index=None):
  179. _device = input.device
  180. _, _, H, W = input.size()
  181. if aug_index == 0:
  182. output = input
  183. else:
  184. output = transforms.RandomAdjustSharpness(sharpness_factor=self.sharpness_factor, p=self.prob)(input)
  185. return output
  186. class RandPers(nn.Module):
  187. def __init__(self, distortion_scale=0.5, p=0.5):
  188. super(RandPers, self).__init__()
  189. self.distortion_scale = distortion_scale
  190. self.prob = p
  191. def forward(self, input, aug_index=None):
  192. _device = input.device
  193. _, _, H, W = input.size()
  194. if aug_index == 0:
  195. output = input
  196. else:
  197. output = transforms.RandomPerspective(distortion_scale=self.distortion_scale, p=self.prob)(input)
  198. return output
  199. class GaussBlur(nn.Module):
  200. def __init__(self, max_range = 4, kernel_size=3, sigma=(0.1, 2.0)):
  201. super(GaussBlur, self).__init__()
  202. self.max_range = max_range
  203. self.prob = 0.5
  204. self.sigma = sigma
  205. self.kernel_size = kernel_size
  206. def forward(self, input, aug_index=None):
  207. _device = input.device
  208. _, _, H, W = input.size()
  209. if aug_index is None:
  210. aug_index = np.random.randint(4)
  211. output = transforms.GaussianBlur(kernel_size=13, sigma=abs(aug_index)+1)(input)
  212. _prob = input.new_full((input.size(0),), self.prob)
  213. _mask = torch.bernoulli(_prob).view(-1, 1, 1, 1)
  214. output = _mask * input + (1-_mask) * output
  215. else:
  216. if aug_index == 0:
  217. output = input
  218. else:
  219. output = transforms.GaussianBlur(kernel_size=self.kernel_size, sigma=self.sigma)(input)
  220. return output
  221. class GaussNoise(nn.Module):
  222. def __init__(self, mean = 0, std = 1):
  223. super(GaussNoise, self).__init__()
  224. self.mean = mean
  225. self.std = std
  226. def forward(self, input, aug_index=None):
  227. _device = input.device
  228. _, _, H, W = input.size()
  229. if aug_index == 0:
  230. output = input
  231. else:
  232. output = input + (torch.randn(input.size()) * self.std + self.mean).to(_device)
  233. return output
  234. class BlurRandpers(nn.Module):
  235. def __init__(self, max_range=2, kernel_size=3, sigma=(10, 20), distortion_scale=0.6, p=1):
  236. super(BlurRandpers, self).__init__()
  237. self.max_range = max_range
  238. self.sigma = sigma
  239. self.kernel_size = kernel_size
  240. self.distortion_scale = distortion_scale
  241. self.p = p
  242. self.gauss = GaussBlur(kernel_size=self.kernel_size, sigma=self.sigma)
  243. self.randpers = RandPers(distortion_scale=self.distortion_scale, p=self.p)
  244. def forward(self, input, aug_index=None):
  245. output = self.gauss.forward(input=input, aug_index=aug_index)
  246. output = self.randpers.forward(input=output, aug_index=aug_index)
  247. return output
  248. class BlurSharpness(nn.Module):
  249. def __init__(self, max_range=2, kernel_size=3, sigma=(10, 20), sharpness_factor=0.6, p=1):
  250. super(BlurSharpness, self).__init__()
  251. self.max_range = max_range
  252. self.sigma = sigma
  253. self.kernel_size = kernel_size
  254. self.sharpness_factor = sharpness_factor
  255. self.p = p
  256. self.gauss = GaussBlur(kernel_size=self.kernel_size, sigma=self.sigma)
  257. self.sharp = RandomAdjustSharpness(sharpness_factor=self.sharpness_factor, p=self.p)
  258. def forward(self, input, aug_index=None):
  259. output = self.gauss.forward(input=input, aug_index=aug_index)
  260. output = self.sharp.forward(input=output, aug_index=aug_index)
  261. return output
  262. class RandpersSharpness(nn.Module):
  263. def __init__(self, max_range=2, distortion_scale=0.6, p=1, sharpness_factor=0.6):
  264. super(RandpersSharpness, self).__init__()
  265. self.max_range = max_range
  266. self.distortion_scale = distortion_scale
  267. self.p = p
  268. self.sharpness_factor = sharpness_factor
  269. self.randpers = RandPers(distortion_scale=self.distortion_scale, p=self.p)
  270. self.sharp = RandomAdjustSharpness(sharpness_factor=self.sharpness_factor, p=self.p)
  271. def forward(self, input, aug_index=None):
  272. output = self.randpers.forward(input=input, aug_index=aug_index)
  273. output = self.sharp.forward(input=output, aug_index=aug_index)
  274. return output
  275. class BlurRandpersSharpness(nn.Module):
  276. def __init__(self, max_range=2, kernel_size=3, sigma=(10, 20), distortion_scale=0.6, p=1, sharpness_factor=0.6):
  277. super(BlurRandpersSharpness, self).__init__()
  278. self.max_range = max_range
  279. self.sigma = sigma
  280. self.kernel_size = kernel_size
  281. self.distortion_scale = distortion_scale
  282. self.p = p
  283. self.sharpness_factor = sharpness_factor
  284. self.gauss = GaussBlur(kernel_size=self.kernel_size, sigma=self.sigma)
  285. self.randpers = RandPers(distortion_scale=self.distortion_scale, p=self.p)
  286. self.sharp = RandomAdjustSharpness(sharpness_factor=self.sharpness_factor, p=self.p)
  287. def forward(self, input, aug_index=None):
  288. output = self.gauss.forward(input=input, aug_index=aug_index)
  289. output = self.randpers.forward(input=output, aug_index=aug_index)
  290. output = self.sharp.forward(input=output, aug_index=aug_index)
  291. return output
  292. class FourCrop(nn.Module):
  293. def __init__(self, max_range = 4):
  294. super(FourCrop, self).__init__()
  295. self.max_range = max_range
  296. self.prob = 0.5
  297. def forward(self, inputs):
  298. outputs = inputs
  299. for i in range(8):
  300. outputs[i] = self._crop(inputs.size(), inputs[i], i)
  301. return outputs
  302. def _crop(self, size, input, i):
  303. _, _, H, W = size
  304. h_mid = int(H / 2)
  305. w_mid = int(W / 2)
  306. if i == 0 or i == 4:
  307. corner = input[:, 0:h_mid, 0:w_mid]
  308. elif i == 1 or i == 5:
  309. corner = input[:, 0:h_mid, w_mid:]
  310. elif i == 2 or i == 6:
  311. corner = input[:, h_mid:, 0:w_mid]
  312. elif i == 3 or i == 7:
  313. corner = input[:, h_mid:, w_mid:]
  314. else:
  315. corner = input
  316. corner = transforms.Resize(size=2*h_mid)(corner)
  317. return corner
  318. class CutPerm(nn.Module):
  319. def __init__(self, max_range = 4):
  320. super(CutPerm, self).__init__()
  321. self.max_range = max_range
  322. self.prob = 0.5
  323. def forward(self, input, aug_index=None):
  324. _device = input.device
  325. _, _, H, W = input.size()
  326. if aug_index is None:
  327. aug_index = np.random.randint(4)
  328. output = self._cutperm(input, aug_index)
  329. _prob = input.new_full((input.size(0),), self.prob)
  330. _mask = torch.bernoulli(_prob).view(-1, 1, 1, 1)
  331. output = _mask * input + (1 - _mask) * output
  332. else:
  333. aug_index = aug_index % self.max_range
  334. output = self._cutperm(input, aug_index)
  335. return output
  336. def _cutperm(self, inputs, aug_index):
  337. _, _, H, W = inputs.size()
  338. h_mid = int(H / 2)
  339. w_mid = int(W / 2)
  340. jigsaw_h = aug_index // 2
  341. jigsaw_v = aug_index % 2
  342. if jigsaw_h == 1:
  343. inputs = torch.cat((inputs[:, :, h_mid:, :], inputs[:, :, 0:h_mid, :]), dim=2)
  344. if jigsaw_v == 1:
  345. inputs = torch.cat((inputs[:, :, :, w_mid:], inputs[:, :, :, 0:w_mid]), dim=3)
  346. return inputs
  347. def assemble(a, b, c, d):
  348. ab = torch.cat((a, b), dim=2)
  349. cd = torch.cat((c, d), dim=2)
  350. output = torch.cat((ab, cd), dim=3)
  351. return output
  352. def quarter(inputs):
  353. _, _, H, W = inputs.size()
  354. h_mid = int(H / 2)
  355. w_mid = int(W / 2)
  356. quarters = []
  357. quarters.append(inputs[:, :, 0:h_mid, 0:w_mid])
  358. quarters.append(inputs[:, :, 0:h_mid, w_mid:])
  359. quarters.append(inputs[:, :, h_mid:, 0:w_mid])
  360. quarters.append(inputs[:, :, h_mid:, w_mid:])
  361. return quarters
  362. class HorizontalFlipLayer(nn.Module):
  363. def __init__(self):
  364. """
  365. img_size : (int, int, int)
  366. Height and width must be powers of 2. E.g. (32, 32, 1) or
  367. (64, 128, 3). Last number indicates number of channels, e.g. 1 for
  368. grayscale or 3 for RGB
  369. """
  370. super(HorizontalFlipLayer, self).__init__()
  371. _eye = torch.eye(2, 3)
  372. self.register_buffer('_eye', _eye)
  373. def forward(self, inputs):
  374. _device = inputs.device
  375. N = inputs.size(0)
  376. _theta = self._eye.repeat(N, 1, 1)
  377. r_sign = torch.bernoulli(torch.ones(N, device=_device) * 0.5) * 2 - 1
  378. _theta[:, 0, 0] = r_sign
  379. grid = F.affine_grid(_theta, inputs.size(), **kwargs).to(_device)
  380. inputs = F.grid_sample(inputs, grid, padding_mode='reflection', **kwargs)
  381. return inputs
  382. class RandomColorGrayLayer(nn.Module):
  383. def __init__(self, p):
  384. super(RandomColorGrayLayer, self).__init__()
  385. self.prob = p
  386. _weight = torch.tensor([[0.299, 0.587, 0.114]])
  387. self.register_buffer('_weight', _weight.view(1, 3, 1, 1))
  388. def forward(self, inputs, aug_index=None):
  389. if aug_index == 0:
  390. return inputs
  391. l = F.conv2d(inputs, self._weight)
  392. gray = torch.cat([l, l, l], dim=1)
  393. if aug_index is None:
  394. _prob = inputs.new_full((inputs.size(0),), self.prob)
  395. _mask = torch.bernoulli(_prob).view(-1, 1, 1, 1)
  396. gray = inputs * (1 - _mask) + gray * _mask
  397. return gray
  398. class ColorJitterLayer(nn.Module):
  399. def __init__(self, p, brightness, contrast, saturation, hue):
  400. super(ColorJitterLayer, self).__init__()
  401. self.prob = p
  402. self.brightness = self._check_input(brightness, 'brightness')
  403. self.contrast = self._check_input(contrast, 'contrast')
  404. self.saturation = self._check_input(saturation, 'saturation')
  405. self.hue = self._check_input(hue, 'hue', center=0, bound=(-0.5, 0.5),
  406. clip_first_on_zero=False)
  407. def _check_input(self, value, name, center=1, bound=(0, float('inf')), clip_first_on_zero=True):
  408. if isinstance(value, numbers.Number):
  409. if value < 0:
  410. raise ValueError("If {} is a single number, it must be non negative.".format(name))
  411. value = [center - value, center + value]
  412. if clip_first_on_zero:
  413. value[0] = max(value[0], 0)
  414. elif isinstance(value, (tuple, list)) and len(value) == 2:
  415. if not bound[0] <= value[0] <= value[1] <= bound[1]:
  416. raise ValueError("{} values should be between {}".format(name, bound))
  417. else:
  418. raise TypeError("{} should be a single number or a list/tuple with lenght 2.".format(name))
  419. # if value is 0 or (1., 1.) for brightness/contrast/saturation
  420. # or (0., 0.) for hue, do nothing
  421. if value[0] == value[1] == center:
  422. value = None
  423. return value
  424. def adjust_contrast(self, x):
  425. if self.contrast:
  426. factor = x.new_empty(x.size(0), 1, 1, 1).uniform_(*self.contrast)
  427. means = torch.mean(x, dim=[2, 3], keepdim=True)
  428. x = (x - means) * factor + means
  429. return torch.clamp(x, 0, 1)
  430. def adjust_hsv(self, x):
  431. f_h = x.new_zeros(x.size(0), 1, 1)
  432. f_s = x.new_ones(x.size(0), 1, 1)
  433. f_v = x.new_ones(x.size(0), 1, 1)
  434. if self.hue:
  435. f_h.uniform_(*self.hue)
  436. if self.saturation:
  437. f_s = f_s.uniform_(*self.saturation)
  438. if self.brightness:
  439. f_v = f_v.uniform_(*self.brightness)
  440. return RandomHSVFunction.apply(x, f_h, f_s, f_v)
  441. def transform(self, inputs):
  442. # Shuffle transform
  443. if np.random.rand() > 0.5:
  444. transforms = [self.adjust_contrast, self.adjust_hsv]
  445. else:
  446. transforms = [self.adjust_hsv, self.adjust_contrast]
  447. for t in transforms:
  448. inputs = t(inputs)
  449. return inputs
  450. def forward(self, inputs):
  451. _prob = inputs.new_full((inputs.size(0),), self.prob)
  452. _mask = torch.bernoulli(_prob).view(-1, 1, 1, 1)
  453. return inputs * (1 - _mask) + self.transform(inputs) * _mask
  454. class RandomHSVFunction(Function):
  455. @staticmethod
  456. def forward(ctx, x, f_h, f_s, f_v):
  457. # ctx is a context object that can be used to stash information
  458. # for backward computation
  459. x = rgb2hsv(x)
  460. h = x[:, 0, :, :]
  461. h += (f_h * 255. / 360.)
  462. h = (h % 1)
  463. x[:, 0, :, :] = h
  464. x[:, 1, :, :] = x[:, 1, :, :] * f_s
  465. x[:, 2, :, :] = x[:, 2, :, :] * f_v
  466. x = torch.clamp(x, 0, 1)
  467. x = hsv2rgb(x)
  468. return x
  469. @staticmethod
  470. def backward(ctx, grad_output):
  471. # We return as many input gradients as there were arguments.
  472. # Gradients of non-Tensor arguments to forward must be None.
  473. grad_input = None
  474. if ctx.needs_input_grad[0]:
  475. grad_input = grad_output.clone()
  476. return grad_input, None, None, None
  477. class NormalizeLayer(nn.Module):
  478. """
  479. In order to certify radii in original coordinates rather than standardized coordinates, we
  480. add the Gaussian noise _before_ standardizing, which is why we have standardization be the first
  481. layer of the classifier rather than as a part of preprocessing as is typical.
  482. """
  483. def __init__(self):
  484. super(NormalizeLayer, self).__init__()
  485. def forward(self, inputs):
  486. return (inputs - 0.5) / 0.5