In Masterarbeit:"Anomalie-Detektion in Zellbildern zur Anwendung der Leukämieerkennung" verwendete CSI Methode.

LARS.py 3.9KB

  1. """
  2. References:
  3. - https://github.com/PyTorchLightning/PyTorch-Lightning-Bolts/blob/master/pl_bolts/optimizers/lars_scheduling.py
  4. - https://github.com/NVIDIA/apex/blob/master/apex/parallel/LARC.py
  5. - https://arxiv.org/pdf/1708.03888.pdf
  6. - https://github.com/noahgolmant/pytorch-lars/blob/master/lars.py
  7. """
  8. import torch
  9. from .wrapper import OptimWrapper
  10. # from torchlars._adaptive_lr import compute_adaptive_lr # Impossible to build extensions
  11. __all__ = ["LARS"]
class LARS(OptimWrapper):
    """Implements 'LARS (Layer-wise Adaptive Rate Scaling)'__ as a
    :class:`~torch.optim.Optimizer` wrapper.

    __ https://arxiv.org/abs/1708.03888

    Wraps an arbitrary optimizer like :class:`torch.optim.SGD` to use LARS. If
    you want to obtain the same performance as small-batch training while
    using large-batch training, LARS will be helpful.

    Args:
        optimizer (Optimizer):
            optimizer to wrap
        eps (float, optional):
            epsilon to help with numerical stability while calculating the
            adaptive learning rate
        trust_coef (float, optional):
            trust coefficient for calculating the adaptive learning rate
        clip (bool, optional):
            if True, cap the adaptive learning rate at the group's base
            learning rate

    Example::

        base_optimizer = optim.SGD(model.parameters(), lr=0.1)
        optimizer = LARS(optimizer=base_optimizer)

        output = model(input)
        loss = loss_fn(output, target)
        loss.backward()

        optimizer.step()
    """

    def __init__(self, optimizer, trust_coef=0.02, clip=True, eps=1e-8):
        if eps < 0.0:
            raise ValueError("invalid epsilon value: %f" % eps)
        if trust_coef < 0.0:
            raise ValueError("invalid trust coefficient: %f" % trust_coef)

        self.optim = optimizer
        self.eps = eps
        self.trust_coef = trust_coef
        self.clip = clip

    def __getstate__(self):
        lars_dict = {}
        lars_dict["trust_coef"] = self.trust_coef
        lars_dict["clip"] = self.clip
        lars_dict["eps"] = self.eps
        return (self.optim, lars_dict)

    def __setstate__(self, state):
        self.optim, lars_dict = state
        self.trust_coef = lars_dict["trust_coef"]
        self.clip = lars_dict["clip"]
        self.eps = lars_dict["eps"]

    @torch.no_grad()
    def step(self, closure=None):
        weight_decays = []

        for group in self.optim.param_groups:
            weight_decay = group.get("weight_decay", 0)
            weight_decays.append(weight_decay)

            # reset weight decay
            group["weight_decay"] = 0

            # update the parameters
            for p in group["params"]:
                if p.grad is not None:
                    self.update_p(p, group, weight_decay)

        # update the optimizer
        self.optim.step(closure=closure)

        # return weight decay control to optimizer
        for group_idx, group in enumerate(self.optim.param_groups):
            group["weight_decay"] = weight_decays[group_idx]

    def update_p(self, p, group, weight_decay):
        # calculate new norms
        p_norm = torch.norm(p.data)
        g_norm = torch.norm(p.grad.data)

        if p_norm != 0 and g_norm != 0:
            # calculate new lr
            divisor = g_norm + p_norm * weight_decay + self.eps
            adaptive_lr = (self.trust_coef * p_norm) / divisor

            # clip lr
            if self.clip:
                adaptive_lr = min(adaptive_lr / group["lr"], 1)

            # update params with clipped lr
            p.grad.data += weight_decay * p.data
            p.grad.data *= adaptive_lr
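
# Note: update_p applies the layer-wise trust ratio from the LARS paper
# referenced in the module docstring,
#     local_lr = trust_coef * ||w|| / (||grad w|| + weight_decay * ||w|| + eps)
# and folds it into the gradient, so the wrapped optimizer's own lr is applied
# on top. With clip=True the ratio is divided by the group's lr and capped at 1,
# so the effective layer-wise rate never exceeds the base lr (mirroring the
# clipping in the Apex LARC implementation linked above).
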
from torch.optim import SGD
from pylot.util import delegates, separate_kwargs


class SGDLARS(LARS):
    @delegates(to=LARS.__init__)
    @delegates(to=SGD.__init__, keep=True, but=["eps", "trust_coef"])
    def __init__(self, params, **kwargs):
        sgd_kwargs, lars_kwargs = separate_kwargs(kwargs, SGD.__init__)
        optim = SGD(params, **sgd_kwargs)
        super().__init__(optim, **lars_kwargs)
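
A minimal usage sketch (not part of LARS.py), mirroring the Example block in the class docstring. It assumes this module is importable as part of its package (so the relative .wrapper import and, for SGDLARS, the pylot.util helpers resolve); the model, inputs, and labels below are dummies chosen for illustration.

    import torch
    import torch.nn as nn
    import torch.optim as optim

    model = nn.Linear(128, 10)
    loss_fn = nn.CrossEntropyLoss()
    batch = torch.randn(32, 128)          # dummy inputs
    target = torch.randint(0, 10, (32,))  # dummy labels

    # Wrap a plain SGD optimizer with LARS ...
    base_optimizer = optim.SGD(model.parameters(), lr=0.1, weight_decay=1e-4)
    optimizer = LARS(optimizer=base_optimizer, trust_coef=0.02)

    # ... or let the convenience subclass build the SGD internally
    # (requires pylot.util's delegates/separate_kwargs helpers):
    # optimizer = SGDLARS(model.parameters(), lr=0.1, weight_decay=1e-4)

    loss = loss_fn(model(batch), target)
    loss.backward()
    optimizer.step()  # rescales each layer's gradient, then delegates to SGD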