1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192 |
- import argparse
- import os
- import zipfile
- from os.path import join
-
- import torch
- from PIL import Image
- from torch.utils.data import DataLoader, Dataset
- from tqdm import tqdm
- import numpy as np
-
- from model import get_model
- from dataset import get_tf_vaild_rot_transform
-
- from sklearn import metrics
- import matplotlib.pyplot as plt
- import csv
- from sklearn.metrics import roc_curve, roc_auc_score
-
class OrderedImages(Dataset):
    """Dataset of numbered ``.bmp`` images read in file-name order.

    Item ``index`` is the file ``f'{index + 1}.bmp'`` inside ``root``
    (0-based indices map to 1-based file names). The opened PIL image is
    passed through ``transform`` and whatever ``transform`` returns is the item.

    Args:
        root: directory containing ``1.bmp`` .. ``<length>.bmp``.
        transform: callable applied to each opened PIL image.
        length: number of images in ``root``. Defaults to 1867, the value
            that used to be hard-coded (it equals 1219 + 648, the
            VALIDATION_ALL + VALIDATION_HEM constants of this script).
    """

    def __init__(self, root, transform, length=1867):
        super().__init__()
        self.root = root
        self.transform = transform
        self.length = length

    def __len__(self):
        return self.length

    def __getitem__(self, index):
        path = os.path.join(self.root, f'{index + 1}.bmp')
        # NOTE(review): the image is handed to the transform in its stored
        # mode; a `.convert('RGB')` was commented out here. Confirm that the
        # transform expects that.
        img = Image.open(path)
        return self.transform(img)
-
# Per-class image counts, presumably of the validation split (ALL / HEM);
# TODO confirm. Their sum, 1867, matches the default OrderedImages length.
# Neither name is referenced elsewhere in this script.
VALIDATION_ALL, VALIDATION_HEM = 1219, 648
-
# Evaluation script: classify the ordered image folder with a trained model,
# print positive/negative counts and the ROC AUC against a labels CSV, then
# write submission.csv and submission.zip into --modelroot.
parser = argparse.ArgumentParser()
parser.add_argument('--batch-size', type=int, default=64)
parser.add_argument('--modelroot', default='results/20190313T101236Z.LGJL', help='path to model')
parser.add_argument('--dataroot', default='data/phase3', help='path to dataset')
# Integer default (the old string '450' only worked because argparse converts string defaults).
parser.add_argument('--res', type=int, default=450, help='Desired input resolution')
# NOTE(review): the default labels file is "phase2" while --dataroot defaults
# to phase3. Confirm that they describe the same images, in the same order.
parser.add_argument('--labels', default=r'/home/feoktistovar67431/data/resources/phase2_labels.csv',
                    help='CSV with one ground-truth label per row, in image order')
parser.add_argument('--workers', type=int, default=6, help='number of DataLoader worker processes')
args = parser.parse_args()

device = 'cuda:0'

dataset = OrderedImages(args.dataroot, get_tf_vaild_rot_transform(args.res))

print("Loading model")
model = get_model().to(device)
model = torch.nn.DataParallel(model)
# map_location lets a checkpoint saved on another device load onto `device`.
model.load_state_dict(torch.load(join(args.modelroot, 'model.pt'), map_location=device))
model.eval()

dataloader = DataLoader(dataset, batch_size=args.batch_size, num_workers=args.workers)

print("Classifying")
all_scores = []
with torch.no_grad():
    # tqdm takes its total from len(dataloader), which counts the final partial
    # batch; len(dataset) // batch_size undercounted it.
    for x in tqdm(dataloader):
        x = x.to(device)
        # Each sample is (nrot, c, h, w), presumably rotated views produced by
        # the "rot" transform. All views are scored and averaged per image.
        bs, nrot, c, h, w = x.size()
        y = model(x.view(-1, c, h, w))
        all_scores.append(y.view(bs, nrot).mean(1))

all_scores = torch.cat(all_scores).cpu()
# Threshold at 0 (the scores are presumably logits) for the hard predictions.
all_labels = all_scores > 0
print("Positive:", all_labels.sum().item())
print("Negative:", len(all_labels) - all_labels.sum().item())

# Ground truth: one label per non-empty row. This assumes the label is in the
# first column and there is no header row (TODO confirm). Labels are kept as
# strings; roc_auc_score binarizes them by sorted order.
true_labels = []
with open(args.labels, "r", newline='') as labels_file:
    for row in csv.reader(labels_file, delimiter=','):
        if row:
            true_labels.append(row[0])

# ROC AUC needs the continuous scores. Using the thresholded booleans would
# collapse the curve to a single operating point.
print(f'AUC: {roc_auc_score(true_labels, all_scores.numpy())}')  # area under the ROC curve

csv_path = join(args.modelroot, 'submission.csv')
zip_path = join(args.modelroot, 'submission.zip')
np.savetxt(csv_path, all_labels.numpy(), '%d')
with zipfile.ZipFile(zip_path, 'w', zipfile.ZIP_DEFLATED) as zipf:
    zipf.write(csv_path, 'isbi_valid.predict')
|