In Masterarbeit: "Anomalie-Detektion in Zellbildern zur Anwendung der Leukämieerkennung" verwendete CSI-Methode.
You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

utils.py 6.4KB

2 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205
  1. import os
  2. import pickle
  3. import random
  4. import shutil
  5. import sys
  6. from datetime import datetime
  7. import numpy as np
  8. import torch
  9. from matplotlib import pyplot as plt
  10. from tensorboardX import SummaryWriter
  11. class Logger(object):
  12. """Reference: https://gist.github.com/gyglim/1f8dfb1b5c82627ae3efcfbbadb9f514"""
  13. def __init__(self, fn, ask=True, local_rank=0):
  14. self.local_rank = local_rank
  15. if self.local_rank == 0:
  16. if not os.path.exists("./logs/"):
  17. os.mkdir("./logs/")
  18. logdir = self._make_dir(fn)
  19. if not os.path.exists(logdir):
  20. os.mkdir(logdir)
  21. if len(os.listdir(logdir)) != 0 and ask:
  22. ans = input("log_dir is not empty. All data inside log_dir will be deleted. "
  23. "Will you proceed [y/N]? ")
  24. if ans in ['y', 'Y']:
  25. shutil.rmtree(logdir)
  26. else:
  27. exit(1)
  28. self.set_dir(logdir)
  29. def _make_dir(self, fn):
  30. today = datetime.today().strftime("%y%m%d")
  31. logdir = 'logs/' + fn
  32. return logdir
  33. def set_dir(self, logdir, log_fn='log.txt'):
  34. self.logdir = logdir
  35. if not os.path.exists(logdir):
  36. os.mkdir(logdir)
  37. self.writer = SummaryWriter(logdir)
  38. self.log_file = open(os.path.join(logdir, log_fn), 'a')
  39. def log(self, string):
  40. if self.local_rank == 0:
  41. self.log_file.write('[%s] %s' % (datetime.now(), string) + '\n')
  42. self.log_file.flush()
  43. print('[%s] %s' % (datetime.now(), string))
  44. sys.stdout.flush()
  45. def log_dirname(self, string):
  46. if self.local_rank == 0:
  47. self.log_file.write('%s (%s)' % (string, self.logdir) + '\n')
  48. self.log_file.flush()
  49. print('%s (%s)' % (string, self.logdir))
  50. sys.stdout.flush()
  51. def scalar_summary(self, tag, value, step):
  52. """Log a scalar variable."""
  53. if self.local_rank == 0:
  54. self.writer.add_scalar(tag, value, step)
  55. def image_summary(self, tag, images, step):
  56. """Log a list of images."""
  57. if self.local_rank == 0:
  58. self.writer.add_image(tag, images, step)
  59. def histo_summary(self, tag, values, step):
  60. """Log a histogram of the tensor of values."""
  61. if self.local_rank == 0:
  62. self.writer.add_histogram(tag, values, step, bins='auto')
  63. class AverageMeter(object):
  64. """Computes and stores the average and current value"""
  65. def __init__(self):
  66. self.value = 0
  67. self.average = 0
  68. self.sum = 0
  69. self.count = 0
  70. def reset(self):
  71. self.value = 0
  72. self.average = 0
  73. self.sum = 0
  74. self.count = 0
  75. def update(self, value, n=1):
  76. self.value = value
  77. self.sum += value * n
  78. self.count += n
  79. self.average = self.sum / self.count
  80. def load_checkpoint(logdir, mode='last'):
  81. if mode == 'last':
  82. model_path = os.path.join(logdir, 'last.model')
  83. optim_path = os.path.join(logdir, 'last.optim')
  84. config_path = os.path.join(logdir, 'last.config')
  85. elif mode == 'best':
  86. model_path = os.path.join(logdir, 'best.model')
  87. optim_path = os.path.join(logdir, 'best.optim')
  88. config_path = os.path.join(logdir, 'best.config')
  89. else:
  90. raise NotImplementedError()
  91. print("=> Loading checkpoint from '{}'".format(logdir))
  92. if os.path.exists(model_path):
  93. model_state = torch.load(model_path)
  94. optim_state = torch.load(optim_path)
  95. with open(config_path, 'rb') as handle:
  96. cfg = pickle.load(handle)
  97. else:
  98. return None, None, None
  99. return model_state, optim_state, cfg
  100. def save_checkpoint(epoch, model_state, optim_state, logdir):
  101. last_model = os.path.join(logdir, 'last.model')
  102. last_optim = os.path.join(logdir, 'last.optim')
  103. last_config = os.path.join(logdir, 'last.config')
  104. opt = {
  105. 'epoch': epoch,
  106. }
  107. torch.save(model_state, last_model)
  108. torch.save(optim_state, last_optim)
  109. with open(last_config, 'wb') as handle:
  110. pickle.dump(opt, handle, protocol=pickle.HIGHEST_PROTOCOL)
  111. def load_linear_checkpoint(logdir, mode='last'):
  112. if mode == 'last':
  113. linear_optim_path = os.path.join(logdir, 'last.linear_optim')
  114. elif mode == 'best':
  115. linear_optim_path = os.path.join(logdir, 'best.linear_optim')
  116. else:
  117. raise NotImplementedError()
  118. print("=> Loading linear optimizer checkpoint from '{}'".format(logdir))
  119. if os.path.exists(linear_optim_path):
  120. linear_optim_state = torch.load(linear_optim_path)
  121. return linear_optim_state
  122. else:
  123. return None
  124. def save_linear_checkpoint(linear_optim_state, logdir):
  125. last_linear_optim = os.path.join(logdir, 'last.linear_optim')
  126. torch.save(linear_optim_state, last_linear_optim)
  127. def set_random_seed(seed):
  128. random.seed(seed)
  129. np.random.seed(seed)
  130. torch.manual_seed(seed)
  131. torch.cuda.manual_seed(seed)
  132. def normalize(x, dim=1, eps=1e-8):
  133. return x / (x.norm(dim=dim, keepdim=True) + eps)
  134. def make_model_diagrams(probs, labels, n_bins=10):
  135. """
  136. outputs - a torch tensor (size n x num_classes) with the outputs from the final linear layer
  137. - NOT the softmaxes
  138. labels - a torch tensor (size n) with the labels
  139. """
  140. confidences, predictions = probs.max(1)
  141. accuracies = torch.eq(predictions, labels)
  142. f, rel_ax = plt.subplots(1, 2, figsize=(4, 2.5))
  143. # Reliability diagram
  144. bins = torch.linspace(0, 1, n_bins + 1)
  145. bins[-1] = 1.0001
  146. width = bins[1] - bins[0]
  147. bin_indices = [confidences.ge(bin_lower) * confidences.lt(bin_upper) for bin_lower, bin_upper in
  148. zip(bins[:-1], bins[1:])]
  149. bin_corrects = [torch.mean(accuracies[bin_index]) for bin_index in bin_indices]
  150. bin_scores = [torch.mean(confidences[bin_index]) for bin_index in bin_indices]
  151. confs = rel_ax.bar(bins[:-1], bin_corrects.numpy(), width=width)
  152. gaps = rel_ax.bar(bins[:-1], (bin_scores - bin_corrects).numpy(), bottom=bin_corrects.numpy(), color=[1, 0.7, 0.7],
  153. alpha=0.5, width=width, hatch='//', edgecolor='r')
  154. rel_ax.plot([0, 1], [0, 1], '--', color='gray')
  155. rel_ax.legend([confs, gaps], ['Outputs', 'Gap'], loc='best', fontsize='small')
  156. # Clean up
  157. rel_ax.set_ylabel('Accuracy')
  158. rel_ax.set_xlabel('Confidence')
  159. f.tight_layout()
  160. return f