```python
from torch.optim.lr_scheduler import _LRScheduler
from torch.optim.lr_scheduler import ReduceLROnPlateau


class GradualWarmupScheduler(_LRScheduler):
    """Gradually warm up (increase) the learning rate in the optimizer.

    Proposed in 'Accurate, Large Minibatch SGD: Training ImageNet in 1 Hour'.

    Args:
        optimizer (Optimizer): Wrapped optimizer.
        multiplier: target learning rate = base lr * multiplier if multiplier > 1.0;
            if multiplier == 1.0, the lr starts from 0 and ends at base_lr.
        total_epoch: the target learning rate is reached at total_epoch, gradually.
        after_scheduler: scheduler to use after total_epoch (e.g. ReduceLROnPlateau).
    """

    def __init__(self, optimizer, multiplier, total_epoch, after_scheduler=None):
        self.multiplier = multiplier
        if self.multiplier < 1.:
            raise ValueError('multiplier should be greater than or equal to 1.')
        self.total_epoch = total_epoch
        self.after_scheduler = after_scheduler
        self.finished = False
        super(GradualWarmupScheduler, self).__init__(optimizer)

    def get_lr(self):
        # Warm-up is over: hand the learning rate off to the wrapped scheduler, if any.
        if self.last_epoch > self.total_epoch:
            if self.after_scheduler:
                if not self.finished:
                    # Rebase the wrapped scheduler on the fully warmed-up learning rates.
                    self.after_scheduler.base_lrs = [base_lr * self.multiplier for base_lr in self.base_lrs]
                    self.finished = True
                return self.after_scheduler.get_lr()
            return [base_lr * self.multiplier for base_lr in self.base_lrs]

        if self.multiplier == 1.0:
            # Linear warm-up from 0 to base_lr.
            return [base_lr * (float(self.last_epoch) / self.total_epoch) for base_lr in self.base_lrs]
        else:
            # Linear warm-up from base_lr to base_lr * multiplier.
            return [base_lr * ((self.multiplier - 1.) * self.last_epoch / self.total_epoch + 1.) for base_lr in self.base_lrs]

    def step_ReduceLROnPlateau(self, metrics, epoch=None):
        if epoch is None:
            epoch = self.last_epoch + 1
        # ReduceLROnPlateau is stepped at the end of an epoch, whereas other schedulers
        # are stepped at the beginning, so shift the epoch counter accordingly.
        self.last_epoch = epoch if epoch != 0 else 1
        if self.last_epoch <= self.total_epoch:
            if self.multiplier == 1.0:
                warmup_lr = [base_lr * (float(self.last_epoch) / self.total_epoch) for base_lr in self.base_lrs]
            else:
                warmup_lr = [base_lr * ((self.multiplier - 1.) * self.last_epoch / self.total_epoch + 1.) for base_lr in self.base_lrs]
            for param_group, lr in zip(self.optimizer.param_groups, warmup_lr):
                param_group['lr'] = lr
        else:
            if epoch is None:
                self.after_scheduler.step(metrics, None)
            else:
                self.after_scheduler.step(metrics, epoch - self.total_epoch)

    def step(self, epoch=None, metrics=None):
        if type(self.after_scheduler) != ReduceLROnPlateau:
            if self.finished and self.after_scheduler:
                if epoch is None:
                    self.after_scheduler.step(None)
                else:
                    self.after_scheduler.step(epoch - self.total_epoch)
            else:
                return super(GradualWarmupScheduler, self).step(epoch)
        else:
            self.step_ReduceLROnPlateau(metrics, epoch)
```
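
A minimal sketch of how the scheduler might be wired up. The model, the SGD optimizer, the `StepLR` after-scheduler, and all hyperparameters below are illustrative choices, not part of the class above.

```python
import torch
from torch.optim.lr_scheduler import StepLR

# Illustrative setup: any model/optimizer pair would do.
model = torch.nn.Linear(10, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

# After warm-up, decay the learning rate by 10x every 30 epochs (assumed choice).
scheduler_steplr = StepLR(optimizer, step_size=30, gamma=0.1)

# Ramp the lr from 0.1 up to 0.1 * 8 over the first 5 epochs, then hand over to StepLR.
scheduler_warmup = GradualWarmupScheduler(
    optimizer, multiplier=8, total_epoch=5, after_scheduler=scheduler_steplr)

for epoch in range(100):
    # ... forward pass, loss.backward() would go here ...
    optimizer.step()
    scheduler_warmup.step()
    print(epoch, optimizer.param_groups[0]['lr'])
```

With `ReduceLROnPlateau` as the `after_scheduler`, the validation metric is passed through the warm-up wrapper instead, e.g. `scheduler_warmup.step(metrics=val_loss)` for some hypothetical `val_loss`.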