In der Masterarbeit „Anomalie-Detektion in Zellbildern zur Anwendung der Leukämieerkennung“ verwendete CSI-Methode.
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

sup_simclr.py 3.6KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104
  1. import time
  2. import torch.optim
  3. import models.transform_layers as TL
  4. from training.contrastive_loss import get_similarity_matrix, Supervised_NT_xent
  5. from utils.utils import AverageMeter, normalize
  6. device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
  7. hflip = TL.HorizontalFlipLayer().to(device)
  8. def train(P, epoch, model, criterion, optimizer, scheduler, loader, logger=None,
  9. simclr_aug=None, linear=None, linear_optim=None):
  10. assert simclr_aug is not None
  11. assert P.sim_lambda == 1.0
  12. if logger is None:
  13. log_ = print
  14. else:
  15. log_ = logger.log
  16. batch_time = AverageMeter()
  17. data_time = AverageMeter()
  18. losses = dict()
  19. losses['cls'] = AverageMeter()
  20. losses['sim'] = AverageMeter()
  21. losses['simnorm'] = AverageMeter()
  22. check = time.time()
  23. for n, (images, labels) in enumerate(loader):
  24. model.train()
  25. count = n * P.n_gpus # number of trained samples
  26. data_time.update(time.time() - check)
  27. check = time.time()
  28. ### SimCLR loss ###
  29. if P.dataset != 'imagenet' and P.dataset != 'CNMC' and P.dataset != 'CNMC_grayscale':
  30. batch_size = images.size(0)
  31. images = images.to(device)
  32. images_pair = hflip(images.repeat(2, 1, 1, 1)) # 2B with hflip
  33. else:
  34. batch_size = images[0].size(0)
  35. images1, images2 = images[0].to(device), images[1].to(device)
  36. images_pair = torch.cat([images1, images2], dim=0) # 2B
  37. labels = labels.to(device)
  38. images_pair = simclr_aug(images_pair) # simclr augmentation
  39. _, outputs_aux = model(images_pair, simclr=True, penultimate=True)
  40. simclr = normalize(outputs_aux['simclr']) # normalize
  41. sim_matrix = get_similarity_matrix(simclr, multi_gpu=P.multi_gpu)
  42. loss_sim = Supervised_NT_xent(sim_matrix, labels=labels, temperature=0.07, multi_gpu=P.multi_gpu) * P.sim_lambda
  43. ### total loss ###
  44. loss = loss_sim
  45. optimizer.zero_grad()
  46. loss.backward()
  47. optimizer.step()
  48. scheduler.step(epoch - 1 + n / len(loader))
  49. lr = optimizer.param_groups[0]['lr']
  50. batch_time.update(time.time() - check)
  51. ### Post-processing stuffs ###
  52. simclr_norm = outputs_aux['simclr'].norm(dim=1).mean()
  53. ### Linear evaluation ###
  54. outputs_linear_eval = linear(outputs_aux['penultimate'].detach())
  55. loss_linear = criterion(outputs_linear_eval, labels.repeat(2))
  56. linear_optim.zero_grad()
  57. loss_linear.backward()
  58. linear_optim.step()
  59. ### Log losses ###
  60. losses['cls'].update(0, batch_size)
  61. losses['sim'].update(loss_sim.item(), batch_size)
  62. losses['simnorm'].update(simclr_norm.item(), batch_size)
  63. if count % 50 == 0:
  64. log_('[Epoch %3d; %3d] [Time %.3f] [Data %.3f] [LR %.5f]\n'
  65. '[LossC %f] [LossSim %f] [SimNorm %f]' %
  66. (epoch, count, batch_time.value, data_time.value, lr,
  67. losses['cls'].value, losses['sim'].value, losses['simnorm'].value))
  68. check = time.time()
  69. log_('[DONE] [Time %.3f] [Data %.3f] [LossC %f] [LossSim %f] [SimNorm %f]' %
  70. (batch_time.average, data_time.average,
  71. losses['cls'].average, losses['sim'].average, losses['simnorm'].average))
  72. if logger is not None:
  73. logger.scalar_summary('train/loss_cls', losses['cls'].average, epoch)
  74. logger.scalar_summary('train/loss_sim', losses['sim'].average, epoch)
  75. logger.scalar_summary('train/batch_time', batch_time.average, epoch)
  76. logger.scalar_summary('train/simclr_norm', losses['simnorm'].average, epoch)