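"""Training utilities: learning-rate scheduling with warmup and step decay, plus
soft-target cross-entropy, entropy, KL-divergence, and target-class margin losses."""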
import torch
import torch.nn as nn
import torch.nn.functional as F


def update_learning_rate(P, optimizer, cur_epoch, n, n_total):
    """Step-decay learning-rate schedule with optional linear warmup.

    ``cur_epoch`` is 1-indexed; ``n`` is the iteration index within the epoch and
    ``n_total`` the number of iterations per epoch.
    """
    cur_epoch = cur_epoch - 1  # convert to 0-indexed

    lr = P.lr_init
    if P.optimizer in ('sgd', 'lars'):
        DECAY_RATIO = 0.1
    elif P.optimizer == 'adam':
        DECAY_RATIO = 0.3
    else:
        raise NotImplementedError()

    # Linear warmup over the first P.warmup iterations
    if P.warmup > 0:
        cur_iter = cur_epoch * n_total + n
        if cur_iter <= P.warmup:
            lr *= cur_iter / float(P.warmup)

    # Decay at 50% and 75% of training
    if cur_epoch >= 0.5 * P.epochs:
        lr *= DECAY_RATIO
    if cur_epoch >= 0.75 * P.epochs:
        lr *= DECAY_RATIO

    for param_group in optimizer.param_groups:
        param_group['lr'] = lr
    return lr


def _cross_entropy(input, targets, reduction='mean'):
    """Cross-entropy against soft targets given as logits (softmax is applied to ``targets``)."""
    targets_prob = F.softmax(targets, dim=1)
    xent = (-targets_prob * F.log_softmax(input, dim=1)).sum(1)
    if reduction == 'sum':
        return xent.sum()
    elif reduction == 'mean':
        return xent.mean()
    elif reduction == 'none':
        return xent
    else:
        raise NotImplementedError()


def _entropy(input, reduction='mean'):
    """Entropy of the softmax distribution over the ``input`` logits."""
    return _cross_entropy(input, input, reduction)


def cross_entropy_soft(input, targets, reduction='mean'):
    """Public soft-target cross-entropy; identical to ``_cross_entropy``."""
    return _cross_entropy(input, targets, reduction)


def kl_div(input, targets, reduction='batchmean'):
    """KL divergence KL(softmax(targets) || softmax(input))."""
    return F.kl_div(F.log_softmax(input, dim=1), F.softmax(targets, dim=1),
                    reduction=reduction)


def target_nll_loss(inputs, targets, reduction='none'):
    """Largest non-target logit minus the target logit (negative logit margin)."""
    inputs_t = -F.nll_loss(inputs, targets, reduction='none')  # logit of the target class
    logit_diff = inputs - inputs_t.view(-1, 1)
    logit_diff = logit_diff.scatter(1, targets.view(-1, 1), -1e8)  # mask out the target class
    diff_max = logit_diff.max(1)[0]

    if reduction == 'sum':
        return diff_max.sum()
    elif reduction == 'mean':
        return diff_max.mean()
    elif reduction == 'none':
        return diff_max
    else:
        raise NotImplementedError()


def target_nll_c(inputs, targets, reduction='none'):
    """Largest non-target softmax probability minus the target-class probability."""
    conf = torch.softmax(inputs, dim=1)
    conf_t = -F.nll_loss(conf, targets, reduction='none')  # probability of the target class
    conf_diff = conf - conf_t.view(-1, 1)
    conf_diff = conf_diff.scatter(1, targets.view(-1, 1), -1)  # mask out the target class
    diff_max = conf_diff.max(1)[0]

    if reduction == 'sum':
        return diff_max.sum()
    elif reduction == 'mean':
        return diff_max.mean()
    elif reduction == 'none':
        return diff_max
    else:
        raise NotImplementedError()
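

if __name__ == '__main__':
    # Illustrative usage sketch (not part of the original module): exercises the utilities
    # above on random tensors. ``P`` is a minimal stand-in config carrying only the
    # attributes that update_learning_rate reads (lr_init, optimizer, warmup, epochs);
    # the model and optimizer are arbitrary placeholders.
    from argparse import Namespace

    P = Namespace(lr_init=0.1, optimizer='sgd', warmup=10, epochs=100)
    model = nn.Linear(16, 4)
    optimizer = torch.optim.SGD(model.parameters(), lr=P.lr_init)

    logits = torch.randn(8, 4)
    soft_targets = torch.randn(8, 4)
    targets = torch.randint(0, 4, (8,))

    # Epoch 60 of 100 is past the 50% mark, so the returned lr reflects one decay step.
    lr = update_learning_rate(P, optimizer, cur_epoch=60, n=0, n_total=100)
    print('lr:', lr)
    print('soft cross-entropy:', cross_entropy_soft(logits, soft_targets).item())
    print('entropy:', _entropy(logits).item())
    print('kl divergence:', kl_div(logits, soft_targets).item())
    print('logit margin:', target_nll_loss(logits, targets).mean().item())
    print('confidence margin:', target_nll_c(logits, targets).mean().item())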