CSI-Methode, verwendet in der Masterarbeit „Anomalie-Detektion in Zellbildern zur Anwendung der Leukämieerkennung“.
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

train.py 1.7KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657
  1. from utils.utils import Logger
  2. from utils.utils import save_checkpoint
  3. from utils.utils import save_linear_checkpoint
  4. from common.train import *
  5. from evals import test_classifier
  6. if 'sup' in P.mode:
  7. from training.sup import setup
  8. else:
  9. from training.unsup import setup
  10. train, fname = setup(P.mode, P)
  11. logger = Logger(fname, ask=not resume, local_rank=P.local_rank)
  12. logger.log(P)
  13. logger.log(model)
  14. if P.multi_gpu:
  15. linear = model.module.linear
  16. else:
  17. linear = model.linear
  18. linear_optim = torch.optim.Adam(linear.parameters(), lr=1e-3, betas=(.9, .999), weight_decay=P.weight_decay)
  19. # Run experiments
  20. for epoch in range(start_epoch, P.epochs + 1):
  21. logger.log_dirname(f"Epoch {epoch}")
  22. model.train()
  23. if P.multi_gpu:
  24. train_sampler.set_epoch(epoch)
  25. kwargs = {}
  26. kwargs['linear'] = linear
  27. kwargs['linear_optim'] = linear_optim
  28. kwargs['simclr_aug'] = simclr_aug
  29. train(P, epoch, model, criterion, optimizer, scheduler_warmup, train_loader, logger=logger, **kwargs)
  30. model.eval()
  31. if epoch % P.save_step == 0 and P.local_rank == 0:
  32. if P.multi_gpu:
  33. save_states = model.module.state_dict()
  34. else:
  35. save_states = model.state_dict()
  36. save_checkpoint(epoch, save_states, optimizer.state_dict(), logger.logdir)
  37. save_linear_checkpoint(linear_optim.state_dict(), logger.logdir)
  38. if epoch % P.error_step == 0 and ('sup' in P.mode):
  39. error = test_classifier(P, model, test_loader, epoch, logger=logger)
  40. is_best = (best > error)
  41. if is_best:
  42. best = error
  43. logger.scalar_summary('eval/best_error', best, epoch)
  44. logger.log('[Epoch %3d] [Test %5.2f] [Best %5.2f]' % (epoch, error, best))