Die in der Masterarbeit "Anomalie-Detektion in Zellbildern zur Anwendung der Leukämieerkennung" verwendete CSI-Methode.
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

classifier.py 4.6KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135
  1. import torch.nn as nn
  2. from models.resnet import ResNet18, ResNet34, ResNet50
  3. from models.resnet_imagenet import resnet18, resnet50
  4. import models.transform_layers as TL
  5. from torchvision import transforms
  def get_simclr_augmentation(P, image_size):
      """Create the augmentation that produces positive views for training.

      The pipeline is color jitter, then random grayscale, then (for every
      dataset except ``'imagenet'``) a random resized crop.

      :param P: parsed arguments; reads ``resize_factor``, ``resize_fix``,
          ``color_distort`` and ``dataset``
      :param image_size: size of image; only ``image_size[0]`` and
          ``image_size[1]`` are used, as the crop output size
      :return: ``nn.Sequential`` transformation
      """
      # Scale range for the random resized crop; with resize_fix both bounds
      # are resize_factor, i.e. a fixed scale.
      if P.resize_fix:
          resize_scale = (P.resize_factor, P.resize_factor)
      else:
          resize_scale = (P.resize_factor, 1.0)

      # Color-distortion strength. The 0.8/0.2 factors and the 0.8/0.2
      # probabilities match the SimCLR color-distortion recipe.
      s = P.color_distort
      color_jitter = TL.ColorJitterLayer(
          brightness=s * 0.8, contrast=s * 0.8, saturation=s * 0.8, hue=s * 0.2, p=0.8
      )
      color_gray = TL.RandomColorGrayLayer(p=0.2)
      resize_crop = TL.RandomResizedCropLayer(
          scale=resize_scale, size=(image_size[0], image_size[1])
      )

      layers = [color_jitter, color_gray]
      # For imagenet the RandomResizedCrop is applied in the PIL-level
      # transform, so it is not repeated here. Every other dataset (including
      # 'CNMC') crops here.
      if P.dataset != 'imagenet':
          layers.append(resize_crop)
      return nn.Sequential(*layers)
  def _odd_blur_kernel_size(res):
      """Return a blur kernel size of ~10% of the resolution, bumped to an odd int.

      :param res: resolution string such as ``'64px'``; the ``'px'`` suffix is removed
      :return: odd ``int`` (at least 1 for non-negative resolutions)
      """
      kernel_size = int(int(res.replace('px', '')) * 0.1)
      if kernel_size % 2 == 0:
          kernel_size += 1
      return kernel_size


  def get_shift_module(P, eval=False):
      """Create the shift transformation (negative samples) and its class count.

      :param P: parsed arguments; reads ``shift_trans_type`` plus the options of
          the selected shift (``noise_mean``/``noise_std``, ``distortion_scale``,
          ``sharpness_factor``, ``res``, ``blur_sigma``), and ``mode`` /
          ``batch_size`` when not evaluating
      :param eval: whether it is an evaluation step; if True the batch-size
          check is skipped
      :return: tuple ``(shift_transform, K_shift)``; ``K_shift`` is 4 for every
          known ``shift_trans_type`` and 1 (with ``nn.Identity``) otherwise
      :raises AssertionError: when not evaluating, ``P.mode`` does not contain
          ``'sup'`` and ``P.batch_size != 128 // K_shift``
      """
      trans_type = P.shift_trans_type
      K_shift = 4

      # Gaussian blur kernels must be odd integers (torchvision's GaussianBlur
      # raises otherwise), so every blur variant uses the same odd kernel size.
      # NOTE(review): assumes TL's blur layers wrap such a blur — TODO confirm.
      if trans_type == 'rotation':
          shift_transform = TL.Rotation()
      elif trans_type == 'cutperm':
          shift_transform = TL.CutPerm()
      elif trans_type == 'noise':
          shift_transform = TL.GaussNoise(mean=P.noise_mean, std=P.noise_std)
      elif trans_type == 'randpers':
          shift_transform = TL.RandPers(distortion_scale=P.distortion_scale, p=1)
      elif trans_type == 'sharp':
          shift_transform = TL.RandomAdjustSharpness(sharpness_factor=P.sharpness_factor, p=1)
      elif trans_type == 'blur':
          kernel_size = _odd_blur_kernel_size(P.res)
          sigma = (0.1, float(P.blur_sigma))
          shift_transform = TL.GaussBlur(kernel_size=kernel_size, sigma=sigma)
      elif trans_type == 'blur_randpers':
          kernel_size = _odd_blur_kernel_size(P.res)
          sigma = (0.1, float(P.blur_sigma))
          shift_transform = TL.BlurRandpers(
              kernel_size=kernel_size, sigma=sigma,
              distortion_scale=P.distortion_scale, p=1,
          )
      elif trans_type == 'blur_sharp':
          kernel_size = _odd_blur_kernel_size(P.res)
          sigma = (0.1, float(P.blur_sigma))
          shift_transform = TL.BlurSharpness(
              kernel_size=kernel_size, sigma=sigma,
              sharpness_factor=P.sharpness_factor, p=1,
          )
      elif trans_type == 'randpers_sharp':
          shift_transform = TL.RandpersSharpness(
              distortion_scale=P.distortion_scale, p=1,
              sharpness_factor=P.sharpness_factor,
          )
      elif trans_type == 'blur_randpers_sharp':
          kernel_size = _odd_blur_kernel_size(P.res)
          sigma = (0.1, float(P.blur_sigma))
          shift_transform = TL.BlurRandpersSharpness(
              kernel_size=kernel_size, sigma=sigma,
              distortion_scale=P.distortion_scale, p=1,
              sharpness_factor=P.sharpness_factor,
          )
      else:
          shift_transform = nn.Identity()
          K_shift = 1

      # Outside evaluation and supervised modes, batch_size * K_shift must be 128
      # (presumably the shifted copies are processed together in one batch).
      if not eval and 'sup' not in P.mode:
          expected = 128 // K_shift
          assert P.batch_size == expected, (
              f"batch_size must be {expected} for shift_trans_type={trans_type!r} "
              f"(K_shift={K_shift}), got {P.batch_size}"
          )
      return shift_transform, K_shift
  def get_shift_classifer(model, K_shift):
      """Attach a linear shift-prediction head to ``model`` and return it.

      The head is an ``nn.Linear`` from ``model.last_dim`` features to
      ``K_shift`` outputs, stored as ``model.shift_cls_layer``; ``model`` is
      modified in place.

      :param model: network exposing a ``last_dim`` attribute
      :param K_shift: number of shift classes (output units of the head)
      :return: the same ``model`` object
      """
      head = nn.Linear(model.last_dim, K_shift)
      model.shift_cls_layer = head
      return model
  def get_classifier(mode, n_classes=10):
      """Instantiate the backbone network selected by ``mode``.

      :param mode: one of ``'resnet18'``, ``'resnet34'``, ``'resnet50'``,
          ``'resnet18_imagenet'`` or ``'resnet50_imagenet'``
      :param n_classes: number of output classes, passed as ``num_classes``
      :return: the constructed model
      :raises NotImplementedError: for any other ``mode``
      """
      # Builders are lambdas so that only the selected constructor is called;
      # candidates are compared with ``==`` in the same order as before.
      candidates = (
          ('resnet18', lambda: ResNet18(num_classes=n_classes)),
          ('resnet34', lambda: ResNet34(num_classes=n_classes)),
          ('resnet50', lambda: ResNet50(num_classes=n_classes)),
          ('resnet18_imagenet', lambda: resnet18(num_classes=n_classes)),
          ('resnet50_imagenet', lambda: resnet50(num_classes=n_classes)),
      )
      for name, build in candidates:
          if mode == name:
              return build()
      raise NotImplementedError()