In der Masterarbeit "Anomalie-Detektion in Zellbildern zur Anwendung der Leukämieerkennung" verwendete Methode des 3. Platzes der ISBI 2019.
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

main_manual_abl_layerlr.py 7.6KB

2 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210
  1. import argparse
  2. import os
  3. from collections import defaultdict
  4. import numpy as np
  5. import torch
  6. import torch.nn.functional as F
  7. from sklearn.metrics import roc_auc_score, confusion_matrix, precision_recall_fscore_support, accuracy_score
  8. from tensorboardX import SummaryWriter
  9. from torch.optim.lr_scheduler import StepLR, LambdaLR
  10. from torch.utils.data import DataLoader
  11. from tqdm import tqdm, trange
  12. from dataset import get_dataset, get_tf_train_transform, get_tf_vaild_rot_transform
  13. from model import get_model
  14. from utils import IncrementalAverage, to_device, set_seeds, unique_string, count_parameters
  15. def evaluate(model, valid_loader, class_weights, device):
  16. model.eval()
  17. all_labels = []
  18. all_preds = []
  19. loss_avg = IncrementalAverage()
  20. for img, label in tqdm(valid_loader, leave=False):
  21. img, label = to_device(device, img, label)
  22. bs, nrot, c, h, w = img.size()
  23. with torch.no_grad():
  24. pred = model(img.view(-1, c, h, w))
  25. pred = pred.view(bs, nrot).mean(1)
  26. loss = lossfn(pred, label.to(pred.dtype), class_weights)
  27. all_labels.append(label.cpu())
  28. all_preds.append(pred.cpu())
  29. loss_avg.update(loss.item())
  30. all_labels = torch.cat(all_labels).numpy()
  31. all_preds = torch.cat(all_preds).numpy()
  32. all_preds_binary = all_preds > 0
  33. cm = confusion_matrix(all_labels, all_preds_binary)
  34. auc = roc_auc_score(all_labels, all_preds)
  35. prec, rec, f1, _ = precision_recall_fscore_support(all_labels, all_preds_binary, average='weighted')
  36. return loss_avg.value, cm, auc, prec, rec, f1
  37. def train(model, opt, train_loader, class_weights, device):
  38. model.train()
  39. loss_avg = IncrementalAverage()
  40. for img, label in tqdm(train_loader, leave=False):
  41. img, label = to_device(device, img, label)
  42. pred = model(img)
  43. pred = pred.view(-1)
  44. loss = lossfn(pred, label.to(pred.dtype), class_weights)
  45. loss_avg.update(loss.item())
  46. opt.zero_grad()
  47. loss.backward()
  48. opt.step()
  49. return loss_avg.value
  50. def lossfn(prediction, target, class_weights):
  51. pos_weight = (class_weights[0] / class_weights[1]).expand(len(target))
  52. return F.binary_cross_entropy_with_logits(prediction, target, pos_weight=pos_weight)
  53. def schedule(epoch):
  54. if epoch < 2:
  55. ub = 1
  56. elif epoch < 4:
  57. ub = 0.1
  58. else:
  59. ub = 0.01
  60. return ub
  61. def train_validate(args):
  62. model = get_model().to(args.device)
  63. print("Model parameters:", count_parameters(model))
  64. trainset, validset, validset_subjects, class_weights = get_dataset(args.dataroot,
  65. tf_train=get_tf_train_transform(args.res),
  66. tf_valid=get_tf_vaild_rot_transform(args.res))
  67. class_weights = class_weights.to(args.device)
  68. print(f"Trainset length: {len(trainset)}")
  69. print(f"Validset length: {len(validset)}")
  70. print(f"class_weights = {class_weights}")
  71. train_loader = DataLoader(trainset, batch_size=args.batch_size, num_workers=6, shuffle=True, drop_last=True)
  72. valid_loader = DataLoader(validset, batch_size=args.batch_size, num_workers=6, shuffle=False)
  73. opt = torch.optim.Adam([
  74. {'params': model.paramgroup01(), 'lr': args.lr},
  75. {'params': model.paramgroup234(), 'lr': args.lr},
  76. {'params': model.parameters_classifier(), 'lr': args.lr},
  77. ])
  78. scheduler = LambdaLR(opt, lr_lambda=[lambda e: schedule(e),
  79. lambda e: schedule(e),
  80. lambda e: schedule(e)])
  81. summarywriter = SummaryWriter(args.out)
  82. recorded_data = defaultdict(list)
  83. def logged_eval(e):
  84. valid_loss, cm, auc, prec, rec, f1 = evaluate(model, valid_loader, class_weights, args.device)
  85. # Derive some accuracy metrics from confusion matrix
  86. tn, fp, fn, tp = cm.ravel()
  87. acc = (tp + tn) / cm.sum()
  88. acc_hem = tn / (tn + fp)
  89. acc_all = tp / (tp + fn)
  90. print(f"epoch={e} f1={f1:.4f}")
  91. summarywriter.add_scalar('loss/train', train_loss, e)
  92. summarywriter.add_scalar('loss/valid', valid_loss, e)
  93. summarywriter.add_scalar('cm/tn', tn, e)
  94. summarywriter.add_scalar('cm/fp', fp, e)
  95. summarywriter.add_scalar('cm/fn', fn, e)
  96. summarywriter.add_scalar('cm/tp', tp, e)
  97. summarywriter.add_scalar('metrics/precision', prec, e)
  98. summarywriter.add_scalar('metrics/recall', rec, e)
  99. summarywriter.add_scalar('metrics/f1', f1, e)
  100. summarywriter.add_scalar('metrics/auc', auc, e)
  101. summarywriter.add_scalar('acc/acc', acc, e)
  102. summarywriter.add_scalar('acc/hem', acc_hem, e)
  103. summarywriter.add_scalar('acc/all', acc_all, e)
  104. recorded_data['loss_train'].append(train_loss)
  105. recorded_data['loss_valid'].append(valid_loss)
  106. recorded_data['tn'].append(tn)
  107. recorded_data['tn'].append(tn)
  108. recorded_data['fp'].append(fp)
  109. recorded_data['fn'].append(fn)
  110. recorded_data['tp'].append(tp)
  111. recorded_data['precision'].append(prec)
  112. recorded_data['recall'].append(rec)
  113. recorded_data['f1'].append(f1)
  114. recorded_data['auc'].append(auc)
  115. recorded_data['acc'].append(acc)
  116. recorded_data['acc_hem'].append(acc_hem)
  117. recorded_data['acc_all'].append(acc_all)
  118. np.savez(f'{args.out}/results', **recorded_data)
  119. model = torch.nn.DataParallel(model)
  120. train_loss = np.nan
  121. logged_eval(0)
  122. for e in trange(args.epochs, desc='Epoch'):
  123. scheduler.step(e)
  124. train_loss = train(model, opt, train_loader, class_weights, args.device)
  125. logged_eval(e + 1)
  126. summarywriter.close()
  127. subj_acc = evaluate_subj_acc(model, validset, validset_subjects, args.device)
  128. np.savez(f'{args.out}/subj_acc', **subj_acc)
  129. def evaluate_subj_acc(model, dataset, subjects, device):
  130. model.eval()
  131. subj_pred = defaultdict(list)
  132. subj_label = defaultdict(list)
  133. dataloader = DataLoader(dataset, batch_size=1, num_workers=1, shuffle=False)
  134. for (img, cls), subj in tqdm(zip(dataloader, subjects), total=len(subjects), leave=False):
  135. img, cls = to_device(device, img, cls)
  136. bs, nrot, c, h, w = img.size()
  137. with torch.no_grad():
  138. cls_hat = model(img.view(-1, c, h, w))
  139. cls_hat = cls_hat.view(bs, nrot).mean(1)
  140. subj_label[subj].append(cls.cpu())
  141. subj_pred[subj].append(cls_hat.cpu())
  142. for k in subj_label:
  143. subj_label[k] = torch.cat(subj_label[k]).numpy()
  144. subj_pred[k] = torch.cat(subj_pred[k]).numpy() > 0
  145. subj_acc = {}
  146. for k in subj_label:
  147. subj_acc[k] = accuracy_score(subj_label[k], subj_pred[k])
  148. return subj_acc
  149. def parse_args():
  150. parser = argparse.ArgumentParser()
  151. parser.add_argument('--dataroot', default='data', help='path to dataset')
  152. parser.add_argument('--lr', type=float, default=1e-4)
  153. parser.add_argument('--batch-size', type=int, default=16)
  154. parser.add_argument('--epochs', type=int, default=6)
  155. parser.add_argument('--seed', default=1, type=int, help='random seed')
  156. parser.add_argument('--device', default='cuda' if torch.cuda.is_available() else 'cpu')
  157. parser.add_argument('--out', default='results', help='output folder')
  158. parser.add_argument('--res', type=int, default='450', help='Desired input resolution')
  159. args = parser.parse_args()
  160. args.out = os.path.join(args.out, unique_string())
  161. return args
  162. if __name__ == '__main__':
  163. args = parse_args()
  164. print(args)
  165. os.makedirs(args.out, exist_ok=True)
  166. set_seeds(args.seed)
  167. torch.backends.cudnn.benchmark = True
  168. train_validate(args)