123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205 |
- import os
- import pickle
- import random
- import shutil
- import sys
- from datetime import datetime
-
- import numpy as np
- import torch
- from matplotlib import pyplot as plt
- from tensorboardX import SummaryWriter
-
-
class Logger(object):
    """Experiment logger writing to stdout, a text log file, and TensorBoard.

    Only rank 0 creates directories and writes anything, so the class is safe
    to instantiate on every process of a distributed run.

    Reference: https://gist.github.com/gyglim/1f8dfb1b5c82627ae3efcfbbadb9f514
    """

    def __init__(self, fn, ask=True, local_rank=0):
        """
        Args:
            fn: experiment name; the log directory becomes ./logs/<fn>.
            ask: if True, prompt before wiping a non-empty log directory.
            local_rank: distributed rank; non-zero ranks become no-op loggers.
        """
        self.local_rank = local_rank
        if self.local_rank == 0:
            # makedirs(exist_ok=True) avoids the check-then-create race of
            # os.path.exists + os.mkdir.
            os.makedirs("./logs/", exist_ok=True)

            logdir = self._make_dir(fn)
            os.makedirs(logdir, exist_ok=True)

            if len(os.listdir(logdir)) != 0 and ask:
                ans = input("log_dir is not empty. All data inside log_dir will be deleted. "
                            "Will you proceed [y/N]? ")
                if ans in ['y', 'Y']:
                    shutil.rmtree(logdir)
                else:
                    # sys.exit instead of the built-in exit(), which is meant
                    # for interactive sessions only.
                    sys.exit(1)

            self.set_dir(logdir)

    def _make_dir(self, fn):
        """Return the log directory path for experiment name `fn`."""
        # NOTE(review): the original computed today's date here but never used
        # it; that dead code is removed.
        return 'logs/' + fn

    def set_dir(self, logdir, log_fn='log.txt'):
        """Point the logger at `logdir`, (re)opening the TensorBoard writer
        and the text log file (append mode)."""
        self.logdir = logdir
        os.makedirs(logdir, exist_ok=True)
        self.writer = SummaryWriter(logdir)
        self.log_file = open(os.path.join(logdir, log_fn), 'a')

    def log(self, string):
        """Write a timestamped message to the log file and stdout (rank 0 only)."""
        if self.local_rank == 0:
            self.log_file.write('[%s] %s' % (datetime.now(), string) + '\n')
            self.log_file.flush()

            print('[%s] %s' % (datetime.now(), string))
            sys.stdout.flush()

    def log_dirname(self, string):
        """Write a message suffixed with the current log dir (rank 0 only)."""
        if self.local_rank == 0:
            self.log_file.write('%s (%s)' % (string, self.logdir) + '\n')
            self.log_file.flush()

            print('%s (%s)' % (string, self.logdir))
            sys.stdout.flush()

    def scalar_summary(self, tag, value, step):
        """Log a scalar variable."""
        if self.local_rank == 0:
            self.writer.add_scalar(tag, value, step)

    def image_summary(self, tag, images, step):
        """Log a list of images."""
        if self.local_rank == 0:
            self.writer.add_image(tag, images, step)

    def histo_summary(self, tag, values, step):
        """Log a histogram of the tensor of values."""
        if self.local_rank == 0:
            self.writer.add_histogram(tag, values, step, bins='auto')
-
-
class AverageMeter(object):
    """Tracks the most recent value plus a running sum, count, and mean."""

    def __init__(self):
        self.reset()

    def reset(self):
        """Clear all accumulated statistics back to zero."""
        self.value = 0
        self.sum = 0
        self.count = 0
        self.average = 0

    def update(self, value, n=1):
        """Record `value` observed `n` times and refresh the running mean."""
        self.value = value
        self.sum += value * n
        self.count += n
        self.average = self.sum / self.count
-
-
def load_checkpoint(logdir, mode='last'):
    """Load model/optimizer states and the pickled config from `logdir`.

    Args:
        logdir: checkpoint directory.
        mode: 'last' or 'best' — which checkpoint triple to read.

    Returns:
        (model_state, optim_state, cfg), or (None, None, None) when the
        checkpoint files are missing.

    Raises:
        NotImplementedError: for any mode other than 'last'/'best'.
    """
    if mode not in ('last', 'best'):
        raise NotImplementedError()

    model_path = os.path.join(logdir, mode + '.model')
    optim_path = os.path.join(logdir, mode + '.optim')
    config_path = os.path.join(logdir, mode + '.config')

    print("=> Loading checkpoint from '{}'".format(logdir))
    # Require all three files, not just the model: a partially written
    # checkpoint previously crashed on torch.load of a missing file.
    if not all(os.path.exists(p) for p in (model_path, optim_path, config_path)):
        return None, None, None

    model_state = torch.load(model_path)
    optim_state = torch.load(optim_path)
    # NOTE: pickle.load is unsafe on untrusted files; checkpoints are assumed
    # to be produced by this module's own save_checkpoint.
    with open(config_path, 'rb') as handle:
        cfg = pickle.load(handle)
    return model_state, optim_state, cfg
-
-
def save_checkpoint(epoch, model_state, optim_state, logdir):
    """Write the 'last' checkpoint triple (model, optimizer, config) to `logdir`.

    The config is a pickled dict currently holding only the epoch number.
    """
    paths = {name: os.path.join(logdir, 'last.' + name)
             for name in ('model', 'optim', 'config')}

    torch.save(model_state, paths['model'])
    torch.save(optim_state, paths['optim'])
    with open(paths['config'], 'wb') as handle:
        pickle.dump({'epoch': epoch}, handle, protocol=pickle.HIGHEST_PROTOCOL)
-
-
def load_linear_checkpoint(logdir, mode='last'):
    """Load the linear-eval optimizer state from `logdir`.

    Returns the loaded state, or None when the file does not exist.
    Raises NotImplementedError for modes other than 'last'/'best'.
    """
    if mode == 'last':
        path = os.path.join(logdir, 'last.linear_optim')
    elif mode == 'best':
        path = os.path.join(logdir, 'best.linear_optim')
    else:
        raise NotImplementedError()

    print("=> Loading linear optimizer checkpoint from '{}'".format(logdir))
    return torch.load(path) if os.path.exists(path) else None
-
-
def save_linear_checkpoint(linear_optim_state, logdir):
    """Persist the linear-eval optimizer state as 'last.linear_optim' in `logdir`."""
    torch.save(linear_optim_state, os.path.join(logdir, 'last.linear_optim'))
-
-
def set_random_seed(seed):
    """Seed the python, numpy, and torch (CPU + all CUDA devices) RNGs.

    Uses torch.cuda.manual_seed_all rather than manual_seed so that every GPU
    is seeded in multi-GPU runs, not just the current device (a no-op when
    CUDA is unavailable). Note this only seeds the generators; bitwise GPU
    determinism additionally needs the cudnn deterministic/benchmark flags.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
-
-
def normalize(x, dim=1, eps=1e-8):
    """L2-normalize `x` along `dim`; `eps` guards against division by zero."""
    denom = x.norm(dim=dim, keepdim=True) + eps
    return x / denom
-
-
def make_model_diagrams(probs, labels, n_bins=10):
    """Plot a reliability diagram from predicted class probabilities.

    Args:
        probs: torch tensor (n x num_classes) of softmax probabilities.
        labels: torch tensor (n) of integer class labels.
        n_bins: number of equal-width confidence bins.

    Returns:
        The matplotlib figure containing the reliability diagram.
    """
    confidences, predictions = probs.max(1)
    # float, not bool: torch.mean is undefined on bool tensors
    accuracies = torch.eq(predictions, labels).float()
    # Single axis: the original requested subplots(1, 2) and then called .bar
    # on the returned *array* of axes, which raises; only one axis is used.
    f, rel_ax = plt.subplots(figsize=(4, 2.5))

    # Reliability diagram
    bins = torch.linspace(0, 1, n_bins + 1)
    bins[-1] = 1.0001  # include confidence == 1.0 in the last bin
    width = bins[1] - bins[0]
    bin_indices = [confidences.ge(bin_lower) * confidences.lt(bin_upper)
                   for bin_lower, bin_upper in zip(bins[:-1], bins[1:])]
    # Stack into tensors so elementwise subtraction and .numpy() work (the
    # original called .numpy() on python lists, which raises). Empty bins
    # would yield NaN from mean(); map them to 0 instead.
    bin_corrects = torch.stack([
        accuracies[idx].mean() if idx.any() else torch.tensor(0.0)
        for idx in bin_indices])
    bin_scores = torch.stack([
        confidences[idx].mean() if idx.any() else torch.tensor(0.0)
        for idx in bin_indices])

    confs = rel_ax.bar(bins[:-1], bin_corrects.numpy(), width=width)
    gaps = rel_ax.bar(bins[:-1], (bin_scores - bin_corrects).numpy(),
                      bottom=bin_corrects.numpy(), color=[1, 0.7, 0.7],
                      alpha=0.5, width=width, hatch='//', edgecolor='r')
    rel_ax.plot([0, 1], [0, 1], '--', color='gray')
    rel_ax.legend([confs, gaps], ['Outputs', 'Gap'], loc='best', fontsize='small')

    # Clean up
    rel_ax.set_ylabel('Accuracy')
    rel_ax.set_xlabel('Confidence')
    f.tight_layout()
    return f
-
|