CSI method as used in the Master's thesis "Anomalie-Detektion in Zellbildern zur Anwendung der Leukämieerkennung" (anomaly detection in cell images applied to leukemia detection).

scheduler.py 3.0KB

from torch.optim.lr_scheduler import _LRScheduler
from torch.optim.lr_scheduler import ReduceLROnPlateau


class GradualWarmupScheduler(_LRScheduler):
    """Gradually warm up (increase) the learning rate in the optimizer.

    Proposed in 'Accurate, Large Minibatch SGD: Training ImageNet in 1 Hour'.

    Args:
        optimizer (Optimizer): Wrapped optimizer.
        multiplier: Target learning rate = base lr * multiplier if multiplier > 1.0.
            If multiplier = 1.0, the lr starts from 0 and ends up at the base lr.
        total_epoch: The target learning rate is reached gradually at total_epoch.
        after_scheduler: Scheduler used after total_epoch (e.g. ReduceLROnPlateau).
    """

    def __init__(self, optimizer, multiplier, total_epoch, after_scheduler=None):
        self.multiplier = multiplier
        if self.multiplier < 1.:
            raise ValueError('multiplier should be greater than or equal to 1.')
        self.total_epoch = total_epoch
        self.after_scheduler = after_scheduler
        self.finished = False
        super(GradualWarmupScheduler, self).__init__(optimizer)

    def get_lr(self):
        if self.last_epoch > self.total_epoch:
            if self.after_scheduler:
                if not self.finished:
                    # Hand over to the follow-up scheduler with the warmed-up base lrs.
                    self.after_scheduler.base_lrs = [base_lr * self.multiplier for base_lr in self.base_lrs]
                    self.finished = True
                return self.after_scheduler.get_lr()
            return [base_lr * self.multiplier for base_lr in self.base_lrs]

        if self.multiplier == 1.0:
            # Linear warm-up from 0 to base_lr.
            return [base_lr * (float(self.last_epoch) / self.total_epoch) for base_lr in self.base_lrs]
        else:
            # Linear warm-up from base_lr to base_lr * multiplier.
            return [base_lr * ((self.multiplier - 1.) * self.last_epoch / self.total_epoch + 1.) for base_lr in self.base_lrs]

    def step_ReduceLROnPlateau(self, metrics, epoch=None):
        if epoch is None:
            epoch = self.last_epoch + 1
        # ReduceLROnPlateau is called at the end of an epoch, whereas other schedulers are called at the beginning.
        self.last_epoch = epoch if epoch != 0 else 1
        if self.last_epoch <= self.total_epoch:
            warmup_lr = [base_lr * ((self.multiplier - 1.) * self.last_epoch / self.total_epoch + 1.) for base_lr in self.base_lrs]
            for param_group, lr in zip(self.optimizer.param_groups, warmup_lr):
                param_group['lr'] = lr
        else:
            if epoch is None:
                self.after_scheduler.step(metrics, None)
            else:
                self.after_scheduler.step(metrics, epoch - self.total_epoch)

    def step(self, epoch=None, metrics=None):
        if type(self.after_scheduler) != ReduceLROnPlateau:
            if self.finished and self.after_scheduler:
                if epoch is None:
                    self.after_scheduler.step(None)
                else:
                    self.after_scheduler.step(epoch - self.total_epoch)
            else:
                return super(GradualWarmupScheduler, self).step(epoch)
        else:
            self.step_ReduceLROnPlateau(metrics, epoch)
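
For reference, a minimal usage sketch (not part of scheduler.py; the toy model, the SGD settings, and the CosineAnnealingLR follow-up scheduler are illustrative assumptions): warm the learning rate up over the first 5 epochs from base_lr to base_lr * multiplier, then let the wrapped after_scheduler take over.

import torch
from torch.optim.lr_scheduler import CosineAnnealingLR

# Toy model and optimizer purely for illustration.
model = torch.nn.Linear(10, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

# After the 5 warm-up epochs, continue with cosine annealing over the remaining 95 epochs.
cosine = CosineAnnealingLR(optimizer, T_max=95)
scheduler = GradualWarmupScheduler(optimizer, multiplier=8, total_epoch=5, after_scheduler=cosine)

for epoch in range(1, 101):
    # ... run one training epoch here ...
    optimizer.step()      # stand-in for the real optimization step
    scheduler.step(epoch)  # passing the epoch keeps the warm-up aligned with 1-based epoch counting
    print(epoch, optimizer.param_groups[0]['lr'])

For a ReduceLROnPlateau follow-up scheduler, call scheduler.step(epoch, metrics=val_loss) instead, so the plateau logic receives the validation metric once the warm-up phase is finished.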