In der Masterarbeit „Anomalie-Detektion in Zellbildern zur Anwendung der Leukämieerkennung“ verwendete Methode des 3. Platzes der ISBI 2019.
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

submission.py 2.7KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192
  1. import argparse
  2. import os
  3. import zipfile
  4. from os.path import join
  5. import torch
  6. from PIL import Image
  7. from torch.utils.data import DataLoader, Dataset
  8. from tqdm import tqdm
  9. import numpy as np
  10. from model import get_model
  11. from dataset import get_tf_vaild_rot_transform
  12. from sklearn import metrics
  13. import matplotlib.pyplot as plt
  14. import csv
  15. from sklearn.metrics import roc_curve, roc_auc_score
  16. class OrderedImages(Dataset):
  17. def __init__(self, root, transform):
  18. super().__init__()
  19. self.root = root
  20. self.transform = transform
  21. def __len__(self):
  22. return 1867
  23. def __getitem__(self, index):
  24. img = Image.open(os.path.join(self.root, f'{index + 1}.bmp'))#.convert('RGB')
  25. return self.transform(img)
  26. VALIDATION_ALL = 1219
  27. VALIDATION_HEM = 648
  28. parser = argparse.ArgumentParser()
  29. parser.add_argument('--batch-size', type=int, default=64)
  30. parser.add_argument('--modelroot', default='results/20190313T101236Z.LGJL', help='path to model')
  31. parser.add_argument('--dataroot', default='data/phase3', help='path to dataset')
  32. parser.add_argument('--res', type=int, default='450', help='Desired input resolution')
  33. args = parser.parse_args()
  34. dataset = OrderedImages(args.dataroot, get_tf_vaild_rot_transform(args.res))
  35. print(f"Loading model")
  36. model = get_model().to('cuda:0')
  37. model = torch.nn.DataParallel(model)
  38. model.load_state_dict(torch.load(join(args.modelroot, 'model.pt')))
  39. model.eval()
  40. dataloader = DataLoader(dataset, batch_size=args.batch_size, num_workers=6)
  41. print("Classifying")
  42. all_labels = []
  43. for x in tqdm(dataloader, total=len(dataset) // args.batch_size):
  44. x = x.to('cuda:0')
  45. bs, nrot, c, h, w = x.size()
  46. with torch.no_grad():
  47. y = model(x.view(-1, c, h, w))
  48. y = y.view(bs, nrot).mean(1)
  49. labels = y > 0
  50. all_labels.append(labels)
  51. all_labels = torch.cat(all_labels)
  52. print("Positive:", all_labels.sum().item())
  53. print("Negative:", len(all_labels) - all_labels.sum().item())
  54. file_w = open(r'/home/feoktistovar67431/data/resources/phase2_labels.csv', "r")
  55. true_labels = []
  56. reader = csv.reader(file_w, delimiter=',')
  57. for row in reader:
  58. true_labels.append(row)
  59. print(f'AUC: {roc_auc_score(true_labels, all_labels.cpu())}') # Zeige Flaeche unter der Kurve an
  60. #print("Accuracy", metrics.accuracy_score(y_test, y_pred))
  61. #import matplotlib.pyplot as plt
  62. #import numpy as np
  63. #x = # false_positive_rate
  64. #y = # true_positive_rate
  65. # This is the ROC curve
  66. #plt.plot(x,y)
  67. #plt.show()
  68. # This is the AUC
  69. #auc = np.trapz(y,x)
  70. csv_path = join(args.modelroot, 'submission.csv')
  71. zip_path = join(args.modelroot, 'submission.zip')
  72. np.savetxt(csv_path, all_labels.cpu().numpy(), '%d')
  73. with zipfile.ZipFile(zip_path, 'w', zipfile.ZIP_DEFLATED) as zipf:
  74. zipf.write(csv_path, 'isbi_valid.predict')