In Masterarbeit:"Anomalie-Detektion in Zellbildern zur Anwendung der Leukämieerkennung" verwendete Methode des 3. Platzes der ISBI2019.
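A typical invocation, assuming the challenge data has been prepared under data/ in the layout expected by dataset.get_dataset (the flags below mirror the argparse defaults at the bottom of the script):

    python main_manual.py --dataroot data --batch-size 16 --epochs 6 --res 450 --out results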

main_manual.py

import argparse
import os
from collections import defaultdict

import numpy as np
import torch
import torch.nn.functional as F
from sklearn.metrics import roc_auc_score, confusion_matrix, precision_recall_fscore_support, accuracy_score
from tensorboardX import SummaryWriter
from torch.optim.lr_scheduler import LambdaLR
from torch.utils.data import DataLoader
from tqdm import tqdm, trange

from dataset import get_dataset, get_tf_train_transform, get_tf_vaild_rot_transform
from model import get_model
from utils import IncrementalAverage, to_device, set_seeds, unique_string, count_parameters
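
# Note: utils is not part of this file. For reading purposes, the two helpers
# used below are assumed to behave roughly as sketched here (an assumption,
# not the actual implementation):
#
#   class IncrementalAverage:          # running mean without storing all values
#       def __init__(self): self.value, self.n = 0.0, 0
#       def update(self, x):
#           self.n += 1
#           self.value += (x - self.value) / self.n
#
#   def to_device(device, *tensors):   # move a group of tensors to one device
#       return [t.to(device) for t in tensors]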

def evaluate(model, valid_loader, class_weights, device):
    model.eval()
    all_labels = []
    all_preds = []
    loss_avg = IncrementalAverage()
    for img, label in tqdm(valid_loader, leave=False):
        img, label = to_device(device, img, label)
        # The validation transform yields a stack of rotated views per sample:
        # (batch, n_rotations, channels, height, width).
        bs, nrot, c, h, w = img.size()
        with torch.no_grad():
            pred = model(img.view(-1, c, h, w))
        pred = pred.view(bs, nrot).mean(1)  # average logits over rotations (test-time augmentation)
        loss = lossfn(pred, label.to(pred.dtype), class_weights)
        all_labels.append(label.cpu())
        all_preds.append(pred.cpu())
        loss_avg.update(loss.item())
    all_labels = torch.cat(all_labels).numpy()
    all_preds = torch.cat(all_preds).numpy()
    all_preds_binary = all_preds > 0  # a logit above 0 corresponds to probability > 0.5
    cm = confusion_matrix(all_labels, all_preds_binary)
    auc = roc_auc_score(all_labels, all_preds)
    prec, rec, f1, _ = precision_recall_fscore_support(all_labels, all_preds_binary, average='weighted')
    return loss_avg.value, cm, auc, prec, rec, f1

def train(model, opt, train_loader, class_weights, device):
    model.train()
    loss_avg = IncrementalAverage()
    for img, label in tqdm(train_loader, leave=False):
        img, label = to_device(device, img, label)
        pred = model(img)
        pred = pred.view(-1)
        loss = lossfn(pred, label.to(pred.dtype), class_weights)
        loss_avg.update(loss.item())
        opt.zero_grad()
        loss.backward()
        opt.step()
    return loss_avg.value

def lossfn(prediction, target, class_weights):
    # Reweight the positive class relative to the negative class to compensate
    # for class imbalance in the training data.
    pos_weight = (class_weights[0] / class_weights[1]).expand(len(target))
    return F.binary_cross_entropy_with_logits(prediction, target, pos_weight=pos_weight)
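
# With pos_weight p, F.binary_cross_entropy_with_logits computes, per element,
#     loss = -[ p * y * log(sigmoid(x)) + (1 - y) * log(1 - sigmoid(x)) ]
# so positive samples are up-weighted by p = class_weights[0] / class_weights[1].
# Assuming class_weights[0] and class_weights[1] are proportional to the
# negative- and positive-class frequencies (an assumption; they come from
# get_dataset, which is not shown), a 2:1 negative:positive imbalance gives
# p = 2, i.e. one positive sample counts as much as two negative ones.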

def schedule(epoch):
    # Stepwise learning-rate decay: multiply each base learning rate by 1 for
    # epochs 0-1, by 0.1 for epochs 2-3, and by 0.01 from epoch 4 on.
    if epoch < 2:
        ub = 1
    elif epoch < 4:
        ub = 0.1
    else:
        ub = 0.01
    return ub

def train_validate(args):
    model = get_model().to(args.device)
    print("Model parameters:", count_parameters(model))
    trainset, validset, validset_subjects, class_weights = get_dataset(
        args.dataroot,
        tf_train=get_tf_train_transform(args.res),
        tf_valid=get_tf_vaild_rot_transform(args.res))
    class_weights = class_weights.to(args.device)
    print(f"Trainset length: {len(trainset)}")
    print(f"Validset length: {len(validset)}")
    print(f"class_weights = {class_weights}")
    train_loader = DataLoader(trainset, batch_size=args.batch_size, num_workers=6, shuffle=True, drop_last=True)
    valid_loader = DataLoader(validset, batch_size=args.batch_size, num_workers=6, shuffle=False)
    # Per-group learning rates: the lower layer groups are fine-tuned gently,
    # while the classifier head gets a much larger rate.
    opt = torch.optim.Adam([
        {'params': model.paramgroup01(), 'lr': 1e-6},
        {'params': model.paramgroup234(), 'lr': 1e-4},
        {'params': model.parameters_classifier(), 'lr': 1e-2},
    ])
    scheduler = LambdaLR(opt, lr_lambda=[schedule, schedule, schedule])
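    # Combined with schedule(), the effective learning rates per group are:
    #   epochs 0-1: 1e-6 / 1e-4 / 1e-2
    #   epochs 2-3: 1e-7 / 1e-5 / 1e-3
    #   epochs 4+ : 1e-8 / 1e-6 / 1e-4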
    summarywriter = SummaryWriter(args.out)
    recorded_data = defaultdict(list)

    def logged_eval(e):
        valid_loss, cm, auc, prec, rec, f1 = evaluate(model, valid_loader, class_weights, args.device)
        # Derive overall and per-class accuracies from the confusion matrix.
        tn, fp, fn, tp = cm.ravel()
        acc = (tp + tn) / cm.sum()
        acc_hem = tn / (tn + fp)  # accuracy on the negative (HEM) class, i.e. specificity
        acc_all = tp / (tp + fn)  # accuracy on the positive (ALL) class, i.e. sensitivity
        print(f"epoch={e} f1={f1:.4f}")
        summarywriter.add_scalar('loss/train', train_loss, e)
        summarywriter.add_scalar('loss/valid', valid_loss, e)
        summarywriter.add_scalar('cm/tn', tn, e)
        summarywriter.add_scalar('cm/fp', fp, e)
        summarywriter.add_scalar('cm/fn', fn, e)
        summarywriter.add_scalar('cm/tp', tp, e)
        summarywriter.add_scalar('metrics/precision', prec, e)
        summarywriter.add_scalar('metrics/recall', rec, e)
        summarywriter.add_scalar('metrics/f1', f1, e)
        summarywriter.add_scalar('metrics/auc', auc, e)
        summarywriter.add_scalar('acc/acc', acc, e)
        summarywriter.add_scalar('acc/hem', acc_hem, e)
        summarywriter.add_scalar('acc/all', acc_all, e)
        recorded_data['loss_train'].append(train_loss)
        recorded_data['loss_valid'].append(valid_loss)
        recorded_data['tn'].append(tn)
        recorded_data['fp'].append(fp)
        recorded_data['fn'].append(fn)
        recorded_data['tp'].append(tp)
        recorded_data['precision'].append(prec)
        recorded_data['recall'].append(rec)
        recorded_data['f1'].append(f1)
        recorded_data['auc'].append(auc)
        recorded_data['acc'].append(acc)
        recorded_data['acc_hem'].append(acc_hem)
        recorded_data['acc_all'].append(acc_all)
        np.savez(f'{args.out}/results', **recorded_data)
        return f1
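
    # Worked example of the metrics derived above: for cm = [[80, 5], [10, 105]]
    # (rows: true HEM, true ALL; columns: predicted HEM, predicted ALL),
    # tn=80, fp=5, fn=10, tp=105, so
    #   acc     = (105 + 80) / 200 = 0.925
    #   acc_hem = 80 / (80 + 5)    ≈ 0.941
    #   acc_all = 105 / (105 + 10) ≈ 0.913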
    model = torch.nn.DataParallel(model)
    train_loss = np.nan  # no training loss exists yet for the epoch-0 evaluation
    best_val_f1 = logged_eval(0)
    for e in trange(args.epochs, desc='Epoch'):
        scheduler.step(e)  # explicit epoch argument (deprecated in newer PyTorch releases)
        train_loss = train(model, opt, train_loader, class_weights, args.device)
        val_f1 = logged_eval(e + 1)
        # Keep only the checkpoint with the best weighted validation F1.
        if val_f1 > best_val_f1:
            print(f"New best model at {val_f1:.6f}")
            torch.save(model.state_dict(), f'{args.out}/model.pt')
            best_val_f1 = val_f1
    summarywriter.close()
    subj_acc = evaluate_subj_acc(model, validset, validset_subjects, args.device)
    np.savez(f'{args.out}/subj_acc', **subj_acc)

def evaluate_subj_acc(model, dataset, subjects, device):
    # Group predictions by subject to obtain one accuracy value per subject.
    model.eval()
    subj_pred = defaultdict(list)
    subj_label = defaultdict(list)
    # batch_size=1 keeps the loader aligned one-to-one with the subjects list.
    dataloader = DataLoader(dataset, batch_size=1, num_workers=1, shuffle=False)
    for (img, cls), subj in tqdm(zip(dataloader, subjects), total=len(subjects), leave=False):
        img, cls = to_device(device, img, cls)
        bs, nrot, c, h, w = img.size()
        with torch.no_grad():
            cls_hat = model(img.view(-1, c, h, w))
        cls_hat = cls_hat.view(bs, nrot).mean(1)
        subj_label[subj].append(cls.cpu())
        subj_pred[subj].append(cls_hat.cpu())
    for k in subj_label:
        subj_label[k] = torch.cat(subj_label[k]).numpy()
        subj_pred[k] = torch.cat(subj_pred[k]).numpy() > 0
    subj_acc = {}
    for k in subj_label:
        subj_acc[k] = accuracy_score(subj_label[k], subj_pred[k])
    return subj_acc
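
# The per-subject accuracies end up in <out>/subj_acc.npz with one array per
# subject id. An illustrative way to inspect them afterwards (the run folder
# name comes from unique_string() and is a placeholder here):
#
#   import numpy as np
#   subj_acc = np.load('results/<run_id>/subj_acc.npz')
#   for subject in subj_acc.files:
#       print(subject, float(subj_acc[subject]))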

def train_test(args):
    # Variant that trains on all four folds without a validation split
    # (not called from __main__; see the note below).
    model = get_model().to(args.device)
    print("Model parameters:", count_parameters(model))
    trainset, class_weights = get_dataset(args.dataroot, folds_train=(0, 1, 2, 3),
                                          folds_valid=None,
                                          tf_train=get_tf_train_transform(args.res),
                                          tf_valid=get_tf_vaild_rot_transform(args.res))
    class_weights = class_weights.to(args.device)
    print(f"Trainset length: {len(trainset)}")
    print(f"class_weights = {class_weights}")
    train_loader = DataLoader(trainset, batch_size=args.batch_size, num_workers=6, shuffle=True, drop_last=True)
    opt = torch.optim.Adam([
        {'params': model.paramgroup01(), 'lr': 1e-6},
        {'params': model.paramgroup234(), 'lr': 1e-4},
        {'params': model.parameters_classifier(), 'lr': 1e-2},
    ])
    scheduler = LambdaLR(opt, lr_lambda=[schedule, schedule, schedule])
    model = torch.nn.DataParallel(model)
    for e in trange(args.epochs, desc='Epoch'):
        scheduler.step(e)
        train(model, opt, train_loader, class_weights, args.device)
    torch.save(model.state_dict(), f'{args.out}/model.pt')
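
# To train a final model on all folds, swap the entry point at the bottom of
# the file, e.g.:
#
#   if __name__ == '__main__':
#       ...
#       train_test(args)   # instead of train_validate(args)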

def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument('--dataroot', default='data', help='path to dataset')
    parser.add_argument('--batch-size', type=int, default=16)
    parser.add_argument('--epochs', type=int, default=6)
    parser.add_argument('--seed', default=1, type=int, help='random seed')
    parser.add_argument('--device', default='cuda' if torch.cuda.is_available() else 'cpu')
    parser.add_argument('--out', default='results', help='output folder')
    parser.add_argument('--res', type=int, default=450, help='desired input resolution')
    args = parser.parse_args()
    # Give every run its own uniquely named subfolder of the output directory.
    args.out = os.path.join(args.out, unique_string())
    return args

if __name__ == '__main__':
    args = parse_args()
    print(args)
    os.makedirs(args.out, exist_ok=True)
    set_seeds(args.seed)
    # Let cuDNN benchmark convolution algorithms; worthwhile here because the
    # input resolution is fixed.
    torch.backends.cudnn.benchmark = True
    train_validate(args)
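
Each run writes its artifacts into results/<unique run id>/: the best checkpoint model.pt, the metric history results.npz, the per-subject accuracies subj_acc.npz, and the TensorBoard event files produced by tensorboardX. Assuming a standard TensorBoard installation, the logged curves can be viewed with:

    tensorboard --logdir results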