CSI method as used in the master's thesis "Anomalie-Detektion in Zellbildern zur Anwendung der Leukämieerkennung" (anomaly detection in cell images applied to leukemia detection).

__init__.py 2.7KB

import torch
import torch.nn as nn
import torch.nn.functional as F


def update_learning_rate(P, optimizer, cur_epoch, n, n_total):
    """Step-decay learning-rate schedule with optional linear warmup.

    P carries the hyperparameters (lr_init, optimizer, warmup, epochs);
    n is the batch index within the epoch, n_total the batches per epoch.
    """
    cur_epoch = cur_epoch - 1  # switch to 0-based epoch counting
    lr = P.lr_init
    if P.optimizer in ('sgd', 'lars'):
        DECAY_RATIO = 0.1
    elif P.optimizer == 'adam':
        DECAY_RATIO = 0.3
    else:
        raise NotImplementedError()

    # Linear warmup over the first P.warmup iterations.
    if P.warmup > 0:
        cur_iter = cur_epoch * n_total + n
        if cur_iter <= P.warmup:
            lr *= cur_iter / float(P.warmup)

    # Step decay at 50% and 75% of training; the two decays compound.
    if cur_epoch >= 0.5 * P.epochs:
        lr *= DECAY_RATIO
    if cur_epoch >= 0.75 * P.epochs:
        lr *= DECAY_RATIO

    for param_group in optimizer.param_groups:
        param_group['lr'] = lr

    return lr


def _cross_entropy(input, targets, reduction='mean'):
    # Cross-entropy with soft targets: the target logits are softened via
    # softmax and used as the label distribution.
    targets_prob = F.softmax(targets, dim=1)
    xent = (-targets_prob * F.log_softmax(input, dim=1)).sum(1)
    if reduction == 'sum':
        return xent.sum()
    elif reduction == 'mean':
        return xent.mean()
    elif reduction == 'none':
        return xent
    else:
        raise NotImplementedError()


def _entropy(input, reduction='mean'):
    # Entropy of the predicted distribution (cross-entropy of input with itself).
    return _cross_entropy(input, input, reduction)


def cross_entropy_soft(input, targets, reduction='mean'):
    # Public soft-label cross-entropy; identical to _cross_entropy.
    return _cross_entropy(input, targets, reduction)


def kl_div(input, targets, reduction='batchmean'):
    # KL(softmax(targets) || softmax(input)): how far the predictions are
    # from the (softened) target distribution.
    return F.kl_div(F.log_softmax(input, dim=1), F.softmax(targets, dim=1),
                    reduction=reduction)


def target_nll_loss(inputs, targets, reduction='none'):
    # Logit margin: highest non-target logit minus the target logit.
    # nll_loss with reduction='none' acts as a gather here, so
    # inputs_t[i] == inputs[i, targets[i]].
    inputs_t = -F.nll_loss(inputs, targets, reduction='none')
    logit_diff = inputs - inputs_t.view(-1, 1)
    # Mask the target class with -1e8 so the max runs over the other classes.
    logit_diff = logit_diff.scatter(1, targets.view(-1, 1), -1e8)
    diff_max = logit_diff.max(1)[0]
    if reduction == 'sum':
        return diff_max.sum()
    elif reduction == 'mean':
        return diff_max.mean()
    elif reduction == 'none':
        return diff_max
    else:
        raise NotImplementedError()


def target_nll_c(inputs, targets, reduction='none'):
    # Same margin as target_nll_loss, but in probability space: highest
    # non-target confidence minus the target confidence.
    conf = torch.softmax(inputs, dim=1)
    conf_t = -F.nll_loss(conf, targets, reduction='none')  # conf[i, targets[i]]
    conf_diff = conf - conf_t.view(-1, 1)
    # Mask the target class with -1 so it cannot win the max.
    conf_diff = conf_diff.scatter(1, targets.view(-1, 1), -1)
    diff_max = conf_diff.max(1)[0]
    if reduction == 'sum':
        return diff_max.sum()
    elif reduction == 'mean':
        return diff_max.mean()
    elif reduction == 'none':
        return diff_max
    else:
        raise NotImplementedError()
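
Below is a minimal usage sketch of these helpers, not part of `__init__.py`. The hyperparameter namespace `P`, the dummy model, and the tensors are illustrative assumptions; only the attributes that `update_learning_rate` actually reads (`lr_init`, `optimizer`, `warmup`, `epochs`) are taken from the code above.

# Usage sketch -- assumes the helpers above are in scope; P, the model
# and all tensors are dummies introduced only for illustration.
import argparse

import torch

P = argparse.Namespace(lr_init=0.1, optimizer='sgd', warmup=10, epochs=100)

model = torch.nn.Linear(32, 10)
opt = torch.optim.SGD(model.parameters(), lr=P.lr_init)

logits = model(torch.randn(8, 32))     # 8 dummy samples, 10 classes
soft_targets = torch.randn(8, 10)      # dummy teacher logits
labels = torch.randint(0, 10, (8,))    # dummy hard labels

lr = update_learning_rate(P, opt, cur_epoch=1, n=5, n_total=50)  # mid-warmup: 0.05
loss_soft = cross_entropy_soft(logits, soft_targets)   # soft-label cross-entropy
loss_kl = kl_div(logits, soft_targets)                 # KL(soft targets || predictions)
ent = _entropy(logits)                                 # entropy of the predictions
margin = target_nll_loss(logits, labels)               # per-sample logit margins

Note that `target_nll_loss` and `target_nll_c` return margins, not losses in the usual sense: the value is negative exactly when the target class has the highest logit (or confidence), so smaller values indicate more confident correct predictions.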