Method of the 3rd-place entry of the ISBI 2019 challenge, as used in the Master's thesis "Anomalie-Detektion in Zellbildern zur Anwendung der Leukämieerkennung" (anomaly detection in cell images for leukemia detection).

main_manual_abl_testrot.py
import argparse
import os
from collections import defaultdict

import numpy as np
import torch
import torch.nn.functional as F
from sklearn.metrics import roc_auc_score, confusion_matrix, precision_recall_fscore_support, accuracy_score
from tensorboardX import SummaryWriter
from torch.optim.lr_scheduler import LambdaLR
from torch.utils.data import DataLoader
from tqdm import tqdm, trange

from dataset import get_dataset, get_tf_valid_norot_transform, get_tf_train_transform
from model import get_model
from utils import IncrementalAverage, to_device, set_seeds, unique_string, count_parameters


def evaluate(model, valid_loader, class_weights, device):
    """Run the model over the validation set; return loss and classification metrics."""
    model.eval()
    all_labels = []
    all_preds = []
    loss_avg = IncrementalAverage()
    for img, label in tqdm(valid_loader, leave=False):
        img, label = to_device(device, img, label)
        with torch.no_grad():
            pred = model(img).view(-1)
            loss = lossfn(pred, label.to(pred.dtype), class_weights)
        all_labels.append(label.cpu())
        all_preds.append(pred.cpu())
        loss_avg.update(loss.item())
    all_labels = torch.cat(all_labels).numpy()
    all_preds = torch.cat(all_preds).numpy()
    all_preds_binary = all_preds > 0
    cm = confusion_matrix(all_labels, all_preds_binary)
    auc = roc_auc_score(all_labels, all_preds)
    prec, rec, f1, _ = precision_recall_fscore_support(all_labels, all_preds_binary, average='weighted')
    return loss_avg.value, cm, auc, prec, rec, f1
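
# Note: the model outputs raw logits. Thresholding them at 0 is equivalent to
# thresholding sigmoid probabilities at 0.5, and roc_auc_score can consume the
# unthresholded logits directly because sigmoid is monotonic.
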
def train(model, opt, train_loader, class_weights, device):
    """One epoch of training; returns the average training loss."""
    model.train()
    loss_avg = IncrementalAverage()
    for img, label in tqdm(train_loader, leave=False):
        img, label = to_device(device, img, label)
        pred = model(img).view(-1)
        loss = lossfn(pred, label.to(pred.dtype), class_weights)
        loss_avg.update(loss.item())
        opt.zero_grad()
        loss.backward()
        opt.step()
    return loss_avg.value


def lossfn(prediction, target, class_weights):
    """Class-balanced binary cross-entropy on raw logits."""
    pos_weight = (class_weights[0] / class_weights[1]).expand(len(target))
    return F.binary_cross_entropy_with_logits(prediction, target, pos_weight=pos_weight)
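
# pos_weight scales the loss contribution of positive (label 1) examples,
# counteracting class imbalance (standard PyTorch semantics). A minimal
# numeric sketch with hypothetical class weights, not values from this repo:
#
#   w = torch.tensor([2.0, 1.0])
#   F.binary_cross_entropy_with_logits(
#       torch.zeros(1), torch.ones(1), pos_weight=(w[0] / w[1]).expand(1))
#   # -> tensor(1.3863), i.e. 2 * log(2): the positive term is doubled.
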
def schedule(epoch):
    """Step-wise LR multiplier: 1 for epochs 0-1, 0.1 for epochs 2-3, 0.01 afterwards."""
    if epoch < 2:
        return 1
    elif epoch < 4:
        return 0.1
    return 0.01
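
# Combined with the per-group base learning rates set in train_validate below,
# schedule(e) yields these effective learning rates:
#   epochs 0-1: 1e-6 / 1e-4 / 1e-2
#   epochs 2-3: 1e-7 / 1e-5 / 1e-3
#   epoch 4+ :  1e-8 / 1e-6 / 1e-4
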
def train_validate(args):
    model = get_model().to(args.device)
    print("Model parameters:", count_parameters(model))
    trainset, validset, validset_subjects, class_weights = get_dataset(
        args.dataroot,
        tf_valid=get_tf_valid_norot_transform(args.res),
        tf_train=get_tf_train_transform(args.res))
    class_weights = class_weights.to(args.device)
    print(f"Trainset length: {len(trainset)}")
    print(f"Validset length: {len(validset)}")
    print(f"class_weights = {class_weights}")
    train_loader = DataLoader(trainset, batch_size=args.batch_size, num_workers=6, shuffle=True, drop_last=True)
    valid_loader = DataLoader(validset, batch_size=args.batch_size, num_workers=6, shuffle=False)
    # Discriminative learning rates: tiny steps for the earliest (pretrained)
    # blocks, larger ones for later blocks, the largest for the classifier.
    opt = torch.optim.Adam([
        {'params': model.paramgroup01(), 'lr': 1e-6},
        {'params': model.paramgroup234(), 'lr': 1e-4},
        {'params': model.parameters_classifier(), 'lr': 1e-2},
    ])
    scheduler = LambdaLR(opt, lr_lambda=[schedule, schedule, schedule])
    summarywriter = SummaryWriter(args.out)
    recorded_data = defaultdict(list)

    def logged_eval(e):
        valid_loss, cm, auc, prec, rec, f1 = evaluate(model, valid_loader, class_weights, args.device)
        # Derive accuracy metrics from the confusion matrix.
        tn, fp, fn, tp = cm.ravel()
        acc = (tp + tn) / cm.sum()
        acc_hem = tn / (tn + fp)  # accuracy on negative (HEM) samples
        acc_all = tp / (tp + fn)  # accuracy on positive (ALL) samples
        print(f"epoch={e} f1={f1:.4f}")
        summarywriter.add_scalar('loss/train', train_loss, e)
        summarywriter.add_scalar('loss/valid', valid_loss, e)
        summarywriter.add_scalar('cm/tn', tn, e)
        summarywriter.add_scalar('cm/fp', fp, e)
        summarywriter.add_scalar('cm/fn', fn, e)
        summarywriter.add_scalar('cm/tp', tp, e)
        summarywriter.add_scalar('metrics/precision', prec, e)
        summarywriter.add_scalar('metrics/recall', rec, e)
        summarywriter.add_scalar('metrics/f1', f1, e)
        summarywriter.add_scalar('metrics/auc', auc, e)
        summarywriter.add_scalar('acc/acc', acc, e)
        summarywriter.add_scalar('acc/hem', acc_hem, e)
        summarywriter.add_scalar('acc/all', acc_all, e)
        recorded_data['loss_train'].append(train_loss)
        recorded_data['loss_valid'].append(valid_loss)
        recorded_data['tn'].append(tn)
        recorded_data['fp'].append(fp)
        recorded_data['fn'].append(fn)
        recorded_data['tp'].append(tp)
        recorded_data['precision'].append(prec)
        recorded_data['recall'].append(rec)
        recorded_data['f1'].append(f1)
        recorded_data['auc'].append(auc)
        recorded_data['acc'].append(acc)
        recorded_data['acc_hem'].append(acc_hem)
        recorded_data['acc_all'].append(acc_all)
        np.savez(f'{args.out}/results', **recorded_data)

    model = torch.nn.DataParallel(model)
    train_loss = np.nan  # no training loss yet for the epoch-0 evaluation
    logged_eval(0)
    for e in trange(args.epochs, desc='Epoch'):
        scheduler.step(e)  # select this epoch's LR multiplier before training
        train_loss = train(model, opt, train_loader, class_weights, args.device)
        logged_eval(e + 1)
    torch.save(model.state_dict(), f'{args.out}/model.pt')
    summarywriter.close()
    subj_acc = evaluate_subj_acc(model, validset, validset_subjects, args.device)
    np.savez(f'{args.out}/subj_acc', **subj_acc)
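
# Each run writes four kinds of artifacts to args.out: TensorBoard event
# files, results.npz (per-epoch metric history), model.pt (a DataParallel
# state dict, so keys carry a 'module.' prefix), and subj_acc.npz
# (per-subject accuracies).
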
def evaluate_subj_acc(model, dataset, subjects, device):
    """Per-subject accuracy with test-time rotation averaging."""
    model.eval()
    subj_pred = defaultdict(list)
    subj_label = defaultdict(list)
    dataloader = DataLoader(dataset, batch_size=1, num_workers=1, shuffle=False)
    for (img, cls), subj in tqdm(zip(dataloader, subjects), total=len(subjects), leave=False):
        img, cls = to_device(device, img, cls)
        bs, nrot, c, h, w = img.size()  # each sample is a stack of nrot rotated views
        with torch.no_grad():
            cls_hat = model(img.view(-1, c, h, w))
            cls_hat = cls_hat.view(bs, nrot).mean(1)  # average logits over rotations
        subj_label[subj].append(cls.cpu())
        subj_pred[subj].append(cls_hat.cpu())
    for k in subj_label:
        subj_label[k] = torch.cat(subj_label[k]).numpy()
        subj_pred[k] = torch.cat(subj_pred[k]).numpy() > 0
    subj_acc = {}
    for k in subj_label:
        subj_acc[k] = accuracy_score(subj_label[k], subj_pred[k])
    return subj_acc
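
# A small usage sketch for the saved per-subject accuracies (the run folder
# name comes from unique_string(), so '<run-id>' is a hypothetical placeholder):
#
#   data = np.load('results/<run-id>/subj_acc.npz')
#   for subj in data.files:
#       print(subj, float(data[subj]))
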
def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument('--dataroot', default='data', help='path to dataset')
    parser.add_argument('--batch-size', type=int, default=16)
    parser.add_argument('--epochs', type=int, default=6)
    parser.add_argument('--seed', type=int, default=1, help='random seed')
    parser.add_argument('--device', default='cuda' if torch.cuda.is_available() else 'cpu')
    parser.add_argument('--out', default='results', help='output folder')
    parser.add_argument('--res', type=int, default=450, help='desired input resolution')
    args = parser.parse_args()
    args.out = os.path.join(args.out, unique_string())  # unique subfolder per run
    return args
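
# Example invocation (dataset path hypothetical):
#   python main_manual_abl_testrot.py --dataroot /path/to/C-NMC --epochs 6 --res 450
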
if __name__ == '__main__':
    args = parse_args()
    print(args)
    os.makedirs(args.out, exist_ok=True)
    set_seeds(args.seed)
    torch.backends.cudnn.benchmark = True  # autotune convolutions for fixed input sizes
    train_validate(args)
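
The helpers imported from utils are defined elsewhere in the repository. Below is a minimal sketch of plausible stand-ins, inferred purely from their call sites in this script; the bodies are assumptions, not the repository's code.

# Hypothetical stand-ins for the utils helpers; inferred from call sites only.
import random
import uuid

import numpy as np
import torch


class IncrementalAverage:
    """Running mean, updated one value at a time (used for epoch losses)."""
    def __init__(self):
        self.value = 0.0
        self.n = 0

    def update(self, x):
        self.n += 1
        self.value += (x - self.value) / self.n


def to_device(device, *tensors):
    """Move any number of tensors to the given device."""
    return tuple(t.to(device) for t in tensors)


def set_seeds(seed):
    """Seed Python, NumPy and PyTorch RNGs for reproducibility."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)


def unique_string():
    """Unique identifier used to name the output subfolder of a run."""
    return uuid.uuid4().hex


def count_parameters(model):
    """Number of trainable model parameters."""
    return sum(p.numel() for p in model.parameters() if p.requires_grad)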