In der Masterarbeit "Anomalie-Detektion in Zellbildern zur Anwendung der Leukämieerkennung" verwendete CSI-Methode.
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

sup_linear.py 2.8KB

2 years ago
12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091
  1. import time
  2. import torch.optim
  3. import torch.optim.lr_scheduler as lr_scheduler
  4. import models.transform_layers as TL
  5. from utils.utils import AverageMeter, normalize
  6. device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
  7. hflip = TL.HorizontalFlipLayer().to(device)
  8. def train(P, epoch, model, criterion, optimizer, scheduler, loader, logger=None,
  9. simclr_aug=None, linear=None, linear_optim=None):
  10. if epoch == 1:
  11. # define optimizer and save in P (argument)
  12. milestones = [int(0.6 * P.epochs), int(0.75 * P.epochs), int(0.9 * P.epochs)]
  13. linear_optim = torch.optim.SGD(linear.parameters(),
  14. lr=1e-1, weight_decay=P.weight_decay)
  15. P.linear_optim = linear_optim
  16. P.linear_scheduler = lr_scheduler.MultiStepLR(P.linear_optim, gamma=0.1, milestones=milestones)
  17. if logger is None:
  18. log_ = print
  19. else:
  20. log_ = logger.log
  21. batch_time = AverageMeter()
  22. data_time = AverageMeter()
  23. losses = dict()
  24. losses['cls'] = AverageMeter()
  25. check = time.time()
  26. for n, (images, labels) in enumerate(loader):
  27. model.eval()
  28. count = n * P.n_gpus # number of trained samples
  29. data_time.update(time.time() - check)
  30. check = time.time()
  31. ### SimCLR loss ###
  32. if P.dataset != 'imagenet':
  33. batch_size = images.size(0)
  34. images = images.to(device)
  35. images = hflip(images) # 2B with hflip
  36. else:
  37. batch_size = images[0].size(0)
  38. images = images[0].to(device)
  39. labels = labels.to(device)
  40. images = simclr_aug(images) # simclr augmentation
  41. _, outputs_aux = model(images, penultimate=True)
  42. penultimate = outputs_aux['penultimate'].detach()
  43. outputs = linear(penultimate[0:batch_size]) # only use 0 degree samples for linear eval
  44. loss_ce = criterion(outputs, labels)
  45. ### CE loss ###
  46. P.linear_optim.zero_grad()
  47. loss_ce.backward()
  48. P.linear_optim.step()
  49. ### optimizer learning rate ###
  50. lr = P.linear_optim.param_groups[0]['lr']
  51. batch_time.update(time.time() - check)
  52. ### Log losses ###
  53. losses['cls'].update(loss_ce.item(), batch_size)
  54. if count % 50 == 0:
  55. log_('[Epoch %3d; %3d] [Time %.3f] [Data %.3f] [LR %.5f]\n'
  56. '[LossC %f]' %
  57. (epoch, count, batch_time.value, data_time.value, lr,
  58. losses['cls'].value, ))
  59. check = time.time()
  60. P.linear_scheduler.step()
  61. log_('[DONE] [Time %.3f] [Data %.3f] [LossC %f]' %
  62. (batch_time.average, data_time.average,
  63. losses['cls'].average))
  64. if logger is not None:
  65. logger.scalar_summary('train/loss_cls', losses['cls'].average, epoch)
  66. logger.scalar_summary('train/batch_time', batch_time.average, epoch)