- """
- References:
- - https://github.com/PyTorchLightning/PyTorch-Lightning-Bolts/blob/master/pl_bolts/optimizers/lars_scheduling.py
- - https://github.com/NVIDIA/apex/blob/master/apex/parallel/LARC.py
- - https://arxiv.org/pdf/1708.03888.pdf
- - https://github.com/noahgolmant/pytorch-lars/blob/master/lars.py
- """

import torch

from .wrapper import OptimWrapper

# from torchlars._adaptive_lr import compute_adaptive_lr  # Impossible to build extensions

__all__ = ["LARS", "SGDLARS"]


class LARS(OptimWrapper):
    """Implements `LARS (Layer-wise Adaptive Rate Scaling)`__ as a
    :class:`~torch.optim.Optimizer` wrapper.

    __ https://arxiv.org/abs/1708.03888

    Wraps an arbitrary optimizer like :class:`torch.optim.SGD` so that it uses
    LARS. If you want large-batch training to reach the same performance you
    get with small-batch training, LARS can help.

    Args:
        optimizer (Optimizer):
            optimizer to wrap
        trust_coef (float, optional):
            trust coefficient for computing the adaptive learning rate
        clip (bool, optional):
            whether to clip the adaptive learning rate by the group's base
            learning rate
        eps (float, optional):
            epsilon for numerical stability when computing the adaptive
            learning rate

    Example::

        base_optimizer = optim.SGD(model.parameters(), lr=0.1)
        optimizer = LARS(optimizer=base_optimizer)

        output = model(input)
        loss = loss_fn(output, target)
        loss.backward()

        optimizer.step()
    """

    def __init__(self, optimizer, trust_coef=0.02, clip=True, eps=1e-8):
        if eps < 0.0:
            raise ValueError("invalid epsilon value: %f" % eps)
        if trust_coef < 0.0:
            raise ValueError("invalid trust coefficient: %f" % trust_coef)

        self.optim = optimizer
        self.eps = eps
        self.trust_coef = trust_coef
        self.clip = clip

    def __getstate__(self):
        lars_dict = {}
        lars_dict["trust_coef"] = self.trust_coef
        lars_dict["clip"] = self.clip
        lars_dict["eps"] = self.eps
        return (self.optim, lars_dict)

    def __setstate__(self, state):
        self.optim, lars_dict = state
        self.trust_coef = lars_dict["trust_coef"]
        self.clip = lars_dict["clip"]
        self.eps = lars_dict["eps"]

    @torch.no_grad()
    def step(self, closure=None):
        weight_decays = []

        for group in self.optim.param_groups:
            weight_decay = group.get("weight_decay", 0)
            weight_decays.append(weight_decay)

            # zero the group's weight decay; it is folded into the gradient
            # by update_p instead
            group["weight_decay"] = 0

            # rescale each gradient with the LARS adaptive learning rate
            for p in group["params"]:
                if p.grad is not None:
                    self.update_p(p, group, weight_decay)

        # let the wrapped optimizer apply its usual update
        self.optim.step(closure=closure)

        # return weight decay control to the wrapped optimizer
        for group_idx, group in enumerate(self.optim.param_groups):
            group["weight_decay"] = weight_decays[group_idx]

    def update_p(self, p, group, weight_decay):
        # compute parameter and gradient norms
        p_norm = torch.norm(p.data)
        g_norm = torch.norm(p.grad.data)

        if p_norm != 0 and g_norm != 0:
            # layer-wise lr: trust_coef * ||p|| / (||g|| + weight_decay * ||p|| + eps)
            divisor = g_norm + p_norm * weight_decay + self.eps
            adaptive_lr = (self.trust_coef * p_norm) / divisor

            # optionally clip the adaptive lr by the group's base lr
            if self.clip:
                adaptive_lr = min(adaptive_lr / group["lr"], 1)

            # fold weight decay and the adaptive lr into the gradient in place
            p.grad.data += weight_decay * p.data
            p.grad.data *= adaptive_lr
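

# Note: LARS.update_p folds both weight decay and the layer-wise adaptive factor
# directly into each parameter's gradient, so the wrapped optimizer then runs
# unchanged on pre-scaled gradients. This is what lets LARS wrap arbitrary
# optimizers (including momentum-based ones such as SGD with momentum) rather
# than reimplementing their update rules.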


from torch.optim import SGD
from pylot.util import delegates, separate_kwargs


class SGDLARS(LARS):
    """:class:`LARS`-wrapped :class:`torch.optim.SGD`; kwargs are split between SGD and LARS."""

    @delegates(to=LARS.__init__)
    @delegates(to=SGD.__init__, keep=True, but=["eps", "trust_coef"])
    def __init__(self, params, **kwargs):
        # route SGD-specific kwargs to the base optimizer, the rest to the LARS wrapper
        sgd_kwargs, lars_kwargs = separate_kwargs(kwargs, SGD.__init__)
        optim = SGD(params, **sgd_kwargs)
        super().__init__(optim, **lars_kwargs)
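

if __name__ == "__main__":
    # Minimal smoke test (an illustrative sketch, not part of the library API):
    # wrap a plain torch.optim.SGD with LARS and take a single step on a toy
    # regression problem. Assumes only that torch is installed.
    import torch.nn as nn

    model = nn.Linear(10, 1)
    base_optimizer = torch.optim.SGD(model.parameters(), lr=0.1, weight_decay=1e-4)
    optimizer = LARS(base_optimizer, trust_coef=0.02)

    inputs = torch.randn(4, 10)
    targets = torch.randn(4, 1)

    loss = nn.functional.mse_loss(model(inputs), targets)
    loss.backward()
    optimizer.step()
    print(f"toy loss before the LARS step: {loss.item():.4f}")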