"""Classify an ordered image folder with a trained model and write a submission.

The script loads ``model.pt`` from ``--modelroot``, scores every image in
``--dataroot`` (averaging the model output over the rotated views produced by
the transform), prints positive/negative counts and the ROC AUC against a
labels CSV, and writes ``submission.csv`` / ``submission.zip`` next to the
model.
"""
import argparse
import csv
import os
import zipfile
from os.path import join

import matplotlib.pyplot as plt
import numpy as np
import torch
from PIL import Image
from sklearn import metrics
from sklearn.metrics import roc_curve, roc_auc_score
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm

from dataset import get_tf_vaild_rot_transform
from model import get_model

VALIDATION_ALL = 1219
VALIDATION_HEM = 648
# The dataset length was previously hard-coded as 1867, which is exactly the
# sum of the two constants above.
NUM_IMAGES = VALIDATION_ALL + VALIDATION_HEM

# Device used for both the model and the input batches.
DEVICE = 'cuda:0'

# Default location of the ground-truth labels (override with --labels).
DEFAULT_LABELS_CSV = r'/home/feoktistovar67431/data/resources/phase2_labels.csv'


class OrderedImages(Dataset):
    """Images named ``1.bmp`` .. ``<NUM_IMAGES>.bmp`` inside a single directory.

    Args:
        root: Directory that contains the numbered ``.bmp`` files.
        transform: Callable applied to each opened PIL image; its result is
            what ``__getitem__`` returns.
    """

    def __init__(self, root, transform):
        super().__init__()
        self.root = root
        self.transform = transform

    def __len__(self):
        return NUM_IMAGES

    def __getitem__(self, index):
        # Files are 1-based on disk while dataset indices are 0-based.
        # The image is passed on in its stored mode (no RGB conversion).
        img = Image.open(os.path.join(self.root, f'{index + 1}.bmp'))
        return self.transform(img)


def _parse_args():
    """Build the command-line parser and return the parsed arguments."""
    parser = argparse.ArgumentParser()
    parser.add_argument('--batch-size', type=int, default=64)
    parser.add_argument('--modelroot', default='results/20190313T101236Z.LGJL',
                        help='path to model')
    parser.add_argument('--dataroot', default='data/phase3',
                        help='path to dataset')
    parser.add_argument('--res', type=int, default=450,
                        help='Desired input resolution')
    parser.add_argument('--labels', default=DEFAULT_LABELS_CSV,
                        help='path to the CSV file with the true labels')
    return parser.parse_args()


def _classify(model, dataloader):
    """Run the model over every batch.

    Each batch is unpacked as ``(bs, nrot, c, h, w)``; the ``nrot`` views of
    one image are scored individually and their outputs averaged.

    Returns:
        A ``(scores, labels)`` pair of 1-D tensors on ``DEVICE``: the averaged
        raw model outputs and the boolean decision ``score > 0``.
    """
    scores = []
    labels = []
    # len(dataloader) also counts the final partial batch, unlike
    # len(dataset) // batch_size.
    for x in tqdm(dataloader, total=len(dataloader)):
        x = x.to(DEVICE)
        bs, nrot, c, h, w = x.size()
        with torch.no_grad():
            y = model(x.view(-1, c, h, w))
            y = y.view(bs, nrot).mean(1)
        scores.append(y)
        labels.append(y > 0)
    return torch.cat(scores), torch.cat(labels)


def _load_true_labels(path):
    """Return all rows of the labels CSV as a list of string lists."""
    # NOTE(review): rows are handed to scikit-learn unparsed, exactly as
    # before; this presumably relies on a header-less, single-column file.
    # Confirm the CSV schema and convert to ints if it differs.
    with open(path, 'r', newline='') as labels_file:
        return list(csv.reader(labels_file, delimiter=','))


def main():
    """Score the dataset, report metrics and write the submission files."""
    args = _parse_args()

    dataset = OrderedImages(args.dataroot, get_tf_vaild_rot_transform(args.res))

    print("Loading model")
    model = get_model().to(DEVICE)
    # The model is wrapped before loading the weights, as in the original
    # script, so the checkpoint keys are matched against the wrapped model.
    model = torch.nn.DataParallel(model)
    model.load_state_dict(torch.load(join(args.modelroot, 'model.pt')))
    model.eval()

    dataloader = DataLoader(dataset, batch_size=args.batch_size, num_workers=6)

    print("Classifying")
    all_scores, all_labels = _classify(model, dataloader)

    num_positive = all_labels.sum().item()
    print("Positive:", num_positive)
    print("Negative:", len(all_labels) - num_positive)

    true_labels = _load_true_labels(args.labels)
    # Show the area under the ROC curve. It is computed from the continuous
    # scores: thresholded labels would reduce the curve to a single point.
    print(f'AUC: {roc_auc_score(true_labels, all_scores.cpu().numpy())}')

    csv_path = join(args.modelroot, 'submission.csv')
    zip_path = join(args.modelroot, 'submission.zip')
    # One 0/1 prediction per line, in image order.
    np.savetxt(csv_path, all_labels.cpu().numpy(), '%d')
    with zipfile.ZipFile(zip_path, 'w', zipfile.ZIP_DEFLATED) as zipf:
        zipf.write(csv_path, 'isbi_valid.predict')


if __name__ == '__main__':
    # The guard keeps DataLoader worker processes from re-running the script
    # on platforms that start workers by re-importing the main module.
    main()