93 lines
2.7 KiB
Python
93 lines
2.7 KiB
Python
import argparse
|
|
import os
|
|
import zipfile
|
|
from os.path import join
|
|
|
|
import torch
|
|
from PIL import Image
|
|
from torch.utils.data import DataLoader, Dataset
|
|
from tqdm import tqdm
|
|
import numpy as np
|
|
|
|
from model import get_model
|
|
from dataset import get_tf_vaild_rot_transform
|
|
|
|
from sklearn import metrics
|
|
import matplotlib.pyplot as plt
|
|
import csv
|
|
from sklearn.metrics import roc_curve, roc_auc_score
|
|
|
|
class OrderedImages(Dataset):
    """Dataset of numbered bitmap files ``1.bmp`` ... ``<length>.bmp`` in numeric order.

    Item ``i`` is ``transform`` applied to the image opened from
    ``<root>/<i + 1>.bmp``, so dataset order follows the file numbering.

    Args:
        root: Directory containing the numbered ``.bmp`` files.
        transform: Callable applied to each opened PIL image; its return value
            is what ``__getitem__`` yields.
        length: Number of images in the dataset. Defaults to 1867, the value
            that used to be hard-coded.

    Raises:
        ValueError: if ``length`` is negative.
    """

    def __init__(self, root, transform, length=1867):
        super().__init__()
        if length < 0:
            raise ValueError(f'length must be non-negative, got {length}')
        self.root = root
        self.transform = transform
        self.length = length

    def __len__(self):
        """Return the number of images (``self.length``)."""
        return self.length

    def __getitem__(self, index):
        """Load ``<root>/<index + 1>.bmp`` and return ``transform(image)``.

        Negative indices count from the end, like a sequence.

        Raises:
            IndexError: if ``index`` is outside ``[-length, length)``. This also
                makes plain iteration over the dataset stop at the end.
        """
        position = index + self.length if index < 0 else index
        if not 0 <= position < self.length:
            raise IndexError(f'index {index} out of range for {self.length} images')
        path = os.path.join(self.root, f'{position + 1}.bmp')
        # Context manager releases the file handle even if transform never
        # forces a full read of the image.
        with Image.open(path) as img:
            # Pull the pixel data into memory before the file is closed, in case
            # transform returns (or keeps a reference to) the image itself.
            img.load()
            # NOTE(review): `.convert('RGB')` was commented out in the original,
            # so the image reaches transform in whatever mode the BMP file has.
            return self.transform(img)
|
|
|
|
# NOTE(review): presumably the number of ALL and HEM images in the validation
# set (judging by the names) -- confirm. Not referenced below; note that
# 1219 + 648 = 1867, the default length of OrderedImages.
VALIDATION_ALL = 1219
VALIDATION_HEM = 648


def _read_labels(path):
    """Read ground-truth labels from the CSV at *path*.

    Each non-empty row must hold exactly one field, parsed as an integer label.
    Rows are returned in file order.
    NOTE(review): this assumes no header row and that rows are ordered like the
    images 1.bmp, 2.bmp, ... -- confirm against the labels file.

    Raises:
        ValueError: if a row has more than one field or a non-integer label.
    """
    labels = []
    with open(path, newline='') as labels_file:
        for line_no, row in enumerate(csv.reader(labels_file, delimiter=','), start=1):
            if not row:
                continue
            if len(row) != 1:
                raise ValueError(f'{path}:{line_no}: expected one label per row, got {row!r}')
            labels.append(int(row[0]))
    return labels


parser = argparse.ArgumentParser()
parser.add_argument('--batch-size', type=int, default=64)
parser.add_argument('--modelroot', default='results/20190313T101236Z.LGJL', help='path to model')
parser.add_argument('--dataroot', default='data/phase3', help='path to dataset')
parser.add_argument('--res', type=int, default=450, help='Desired input resolution')
parser.add_argument('--labels', default='/home/feoktistovar67431/data/resources/phase2_labels.csv',
                    help='CSV with one integer ground-truth label per row, in image order; '
                         'pass an empty string to skip the AUC evaluation')
parser.add_argument('--device', default='cuda:0', help='torch device used for inference')
args = parser.parse_args()

dataset = OrderedImages(args.dataroot, get_tf_vaild_rot_transform(args.res))

print("Loading model")
model = get_model().to(args.device)
# Wrapped before loading, presumably because the checkpoint was saved from a
# DataParallel model (state-dict keys prefixed with "module.").
model = torch.nn.DataParallel(model)
model.load_state_dict(torch.load(join(args.modelroot, 'model.pt'), map_location=args.device))
model.eval()

dataloader = DataLoader(dataset, batch_size=args.batch_size, num_workers=6)

print("Classifying")
all_scores = []
with torch.no_grad():
    # tqdm takes its total from len(dataloader), which counts the final partial batch.
    for x in tqdm(dataloader):
        x = x.to(args.device)
        # Each sample is a stack of `nrot` transformed copies of one image
        # (presumably rotations, judging by the transform's name).
        bs, nrot, c, h, w = x.size()
        # Score every copy, then average the model outputs over the copies.
        y = model(x.view(-1, c, h, w)).view(bs, nrot).mean(1)
        all_scores.append(y.cpu())

all_scores = torch.cat(all_scores)
# An averaged output above 0 is predicted positive (presumably the outputs are logits).
all_labels = all_scores > 0

num_positive = all_labels.sum().item()
print("Positive:", num_positive)
print("Negative:", len(all_labels) - num_positive)

# Write the submission before evaluating, so the predictions are saved even if
# the ground-truth labels are missing or malformed.
csv_path = join(args.modelroot, 'submission.csv')
zip_path = join(args.modelroot, 'submission.zip')
np.savetxt(csv_path, all_labels.numpy(), '%d')
with zipfile.ZipFile(zip_path, 'w', zipfile.ZIP_DEFLATED) as zipf:
    zipf.write(csv_path, 'isbi_valid.predict')

if args.labels:
    true_labels = _read_labels(args.labels)
    if len(true_labels) != len(all_scores):
        raise ValueError(f'{args.labels} has {len(true_labels)} labels, '
                         f'but {len(all_scores)} images were classified')
    # ROC AUC needs continuous scores; hard 0/1 predictions collapse the ROC
    # curve to a single operating point.
    print(f'AUC: {roc_auc_score(true_labels, all_scores.numpy())}')  # area under the ROC curve
    # To plot the ROC curve:
    #   fpr, tpr, _ = roc_curve(true_labels, all_scores.numpy()); plt.plot(fpr, tpr); plt.show()
|