In der Masterarbeit „Anomalie-Detektion in Zellbildern zur Anwendung der Leukämieerkennung“ verwendete CSI-Methode.
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

sup_CSI_linear.py 4.7KB

2 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130
  1. import time
  2. import torch.optim
  3. import torch.optim.lr_scheduler as lr_scheduler
  4. import models.transform_layers as TL
  5. from utils.utils import AverageMeter, normalize
  6. device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
  7. hflip = TL.HorizontalFlipLayer().to(device)
  8. def train(P, epoch, model, criterion, optimizer, scheduler, loader, logger=None,
  9. simclr_aug=None, linear=None, linear_optim=None):
  10. if P.multi_gpu:
  11. rotation_linear = model.module.shift_cls_layer
  12. joint_linear = model.module.joint_distribution_layer
  13. else:
  14. rotation_linear = model.shift_cls_layer
  15. joint_linear = model.joint_distribution_layer
  16. if epoch == 1:
  17. # define optimizer and save in P (argument)
  18. milestones = [int(0.6 * P.epochs), int(0.75 * P.epochs), int(0.9 * P.epochs)]
  19. linear_optim = torch.optim.SGD(linear.parameters(),
  20. lr=1e-1, weight_decay=P.weight_decay)
  21. P.linear_optim = linear_optim
  22. P.linear_scheduler = lr_scheduler.MultiStepLR(P.linear_optim, gamma=0.1, milestones=milestones)
  23. rotation_linear_optim = torch.optim.SGD(rotation_linear.parameters(),
  24. lr=1e-1, weight_decay=P.weight_decay)
  25. P.rotation_linear_optim = rotation_linear_optim
  26. P.rot_scheduler = lr_scheduler.MultiStepLR(P.rotation_linear_optim, gamma=0.1, milestones=milestones)
  27. joint_linear_optim = torch.optim.SGD(joint_linear.parameters(),
  28. lr=1e-1, weight_decay=P.weight_decay)
  29. P.joint_linear_optim = joint_linear_optim
  30. P.joint_scheduler = lr_scheduler.MultiStepLR(P.joint_linear_optim, gamma=0.1, milestones=milestones)
  31. if logger is None:
  32. log_ = print
  33. else:
  34. log_ = logger.log
  35. batch_time = AverageMeter()
  36. data_time = AverageMeter()
  37. losses = dict()
  38. losses['cls'] = AverageMeter()
  39. losses['rot'] = AverageMeter()
  40. check = time.time()
  41. for n, (images, labels) in enumerate(loader):
  42. model.eval()
  43. count = n * P.n_gpus # number of trained samples
  44. data_time.update(time.time() - check)
  45. check = time.time()
  46. ### SimCLR loss ###
  47. if P.dataset != 'imagenet':
  48. batch_size = images.size(0)
  49. images = images.to(device)
  50. images = hflip(images) # 2B with hflip
  51. else:
  52. batch_size = images[0].size(0)
  53. images = images[0].to(device)
  54. labels = labels.to(device)
  55. images = torch.cat([torch.rot90(images, rot, (2, 3)) for rot in range(4)]) # 4B
  56. rot_labels = torch.cat([torch.ones_like(labels) * k for k in range(4)], 0) # B -> 4B
  57. joint_labels = torch.cat([labels + P.n_classes * i for i in range(4)], dim=0)
  58. images = simclr_aug(images) # simclr augmentation
  59. _, outputs_aux = model(images, penultimate=True)
  60. penultimate = outputs_aux['penultimate'].detach()
  61. outputs = linear(penultimate[0:batch_size]) # only use 0 degree samples for linear eval
  62. outputs_rot = rotation_linear(penultimate)
  63. outputs_joint = joint_linear(penultimate)
  64. loss_ce = criterion(outputs, labels)
  65. loss_rot = criterion(outputs_rot, rot_labels)
  66. loss_joint = criterion(outputs_joint, joint_labels)
  67. ### CE loss ###
  68. P.linear_optim.zero_grad()
  69. loss_ce.backward()
  70. P.linear_optim.step()
  71. ### Rot loss ###
  72. P.rotation_linear_optim.zero_grad()
  73. loss_rot.backward()
  74. P.rotation_linear_optim.step()
  75. ### Joint loss ###
  76. P.joint_linear_optim.zero_grad()
  77. loss_joint.backward()
  78. P.joint_linear_optim.step()
  79. ### optimizer learning rate ###
  80. lr = P.linear_optim.param_groups[0]['lr']
  81. batch_time.update(time.time() - check)
  82. ### Log losses ###
  83. losses['cls'].update(loss_ce.item(), batch_size)
  84. losses['rot'].update(loss_rot.item(), batch_size)
  85. if count % 50 == 0:
  86. log_('[Epoch %3d; %3d] [Time %.3f] [Data %.3f] [LR %.5f]\n'
  87. '[LossC %f] [LossR %f]' %
  88. (epoch, count, batch_time.value, data_time.value, lr,
  89. losses['cls'].value, losses['rot'].value))
  90. check = time.time()
  91. P.linear_scheduler.step()
  92. P.rot_scheduler.step()
  93. P.joint_scheduler.step()
  94. log_('[DONE] [Time %.3f] [Data %.3f] [LossC %f] [LossR %f]' %
  95. (batch_time.average, data_time.average,
  96. losses['cls'].average, losses['rot'].average))
  97. if logger is not None:
  98. logger.scalar_summary('train/loss_cls', losses['cls'].average, epoch)
  99. logger.scalar_summary('train/loss_rot', losses['rot'].average, epoch)
  100. logger.scalar_summary('train/batch_time', batch_time.average, epoch)