commit c7a6997e9f69441f5de2dbb7ebbb8b2e49240530 Author: Artur Feoktistov Date: Fri Apr 29 19:33:43 2022 +0200 init diff --git a/.ipynb_checkpoints/run-checkpoint.ipynb b/.ipynb_checkpoints/run-checkpoint.ipynb new file mode 100644 index 0000000..45a419e --- /dev/null +++ b/.ipynb_checkpoints/run-checkpoint.ipynb @@ -0,0 +1,4395 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "id": "cd8aaf96", + "metadata": {}, + "outputs": [], + "source": [ + "!pip install pandas tqdm" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "26bd5e25", + "metadata": { + "scrolled": true + }, + "outputs": [], + "source": [ + "!python3 main_manual.py --dataroot \"/home/feoktistovar67431/data/isbi2019\" --batch-size 32 --epochs 100 --seed 30042022 --device cuda --out results" + ] + }, + { + "cell_type": "code", + "execution_count": 55, + "id": "b753e6b8", + "metadata": { + "scrolled": false + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Loading model\n", + "Classifying\n", + "59it [00:41, 1.43it/s] \n", + "Positive: 1234\n", + "Negative: 633\n", + "AUC: 0.8797024225483345\n" + ] + } + ], + "source": [ + "!python3 submission.py --modelroot \"/home/feoktistovar67431/isbi2019cancer-master/results/20220216T154306Z.AZHL\" --dataroot \"/home/feoktistovar67431/data/isbi2019/CNMC/phase2\" --batch-size 32" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "3246460b", + "metadata": {}, + "outputs": [], + "source": [ + "# TRAIN\n", + "# dataset : CNMC\n", + "# res : 32\n", + "# epochs : 100\n", + "!python3 main_manual.py --dataroot \"/home/feoktistovar67431/data/isbi2019/CNMC\" --batch-size 32 --epochs 100 --seed 30042022 --device cuda --out results --res 32" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "8a953a39", + "metadata": {}, + "outputs": [], + "source": [ + "# TRAIN\n", + "# dataset : CNMC\n", + "# res : 128\n", + "# epochs : 100\n", + "!python3 main_manual.py 
--dataroot \"/home/feoktistovar67431/data/isbi2019/CNMC\" --batch-size 32 --epochs 100 --seed 30042022 --device cuda --out results --res 128" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "12c15b33", + "metadata": {}, + "outputs": [], + "source": [ + "# TRAIN\n", + "# dataset : CNMC\n", + "# res : 224\n", + "# epochs : 100\n", + "!python3 main_manual.py --dataroot \"/home/feoktistovar67431/data/isbi2019/CNMC\" --batch-size 32 --epochs 100 --seed 30042022 --device cuda --out results --res 224" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "08ba15b4", + "metadata": {}, + "outputs": [], + "source": [ + "# TRAIN\n", + "# dataset : CNMC\n", + "# res : 256\n", + "# epochs : 100\n", + "!python3 main_manual.py --dataroot \"/home/feoktistovar67431/data/isbi2019/CNMC\" --batch-size 32 --epochs 100 --seed 30042022 --device cuda --out results --res 256" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "3cf25ec3", + "metadata": {}, + "outputs": [], + "source": [ + "# TRAIN\n", + "# dataset : CNMC\n", + "# res : 450\n", + "# epochs : 100\n", + "!python3 main_manual.py --dataroot \"/home/feoktistovar67431/data/isbi2019/CNMC\" --batch-size 32 --epochs 100 --seed 30042022 --device cuda --out results --res 450" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "73b9d9d3", + "metadata": {}, + "outputs": [], + "source": [ + "# TRAIN\n", + "# dataset : CNMC_Grayscale\n", + "# res : 450\n", + "# epochs : 100\n", + "!python3 main_manual.py --dataroot \"/home/feoktistovar67431/data/isbi2019/CNMC_grayscale\" --batch-size 32 --epochs 100 --seed 30042022 --device cuda --out results --res 450" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "ce16353c", + "metadata": {}, + "outputs": [], + "source": [ + "# TRAIN\n", + "# dataset : CNMC_no_red\n", + "# res : 450\n", + "# epochs : 100\n", + "!python3 main_manual.py --dataroot \"/home/feoktistovar67431/data/isbi2019/CNMC_no_red\" 
--batch-size 32 --epochs 100 --seed 30042022 --device cuda --out results --res 450" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "959ab837", + "metadata": {}, + "outputs": [], + "source": [ + "# TRAIN\n", + "# dataset : CNMC_no_green\n", + "# res : 450\n", + "# epochs : 100\n", + "!python3 main_manual.py --dataroot \"/home/feoktistovar67431/data/isbi2019/CNMC_no_green\" --batch-size 32 --epochs 100 --seed 30042022 --device cuda --out results --res 450" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "879beb46", + "metadata": {}, + "outputs": [], + "source": [ + "# TRAIN\n", + "# dataset : CNMC_no_blue\n", + "# res : 450\n", + "# epochs : 100\n", + "!python3 main_manual.py --dataroot \"/home/feoktistovar67431/data/isbi2019/CNMC_no_blue\" --batch-size 32 --epochs 100 --seed 30042022 --device cuda --out results --res 450" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "6d545dce", + "metadata": {}, + "outputs": [], + "source": [ + "# TRAIN\n", + "# dataset : CNMC_red_only\n", + "# res : 450\n", + "# epochs : 100\n", + "!python3 main_manual.py --dataroot \"/home/feoktistovar67431/data/isbi2019/CNMC_red_only\" --batch-size 32 --epochs 100 --seed 30042022 --device cuda --out results --res 450" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "25480226", + "metadata": {}, + "outputs": [], + "source": [ + "# TRAIN\n", + "# dataset : CNMC_green_only\n", + "# res : 450\n", + "# epochs : 100\n", + "!python3 main_manual.py --dataroot \"/home/feoktistovar67431/data/isbi2019/CNMC_green_only\" --batch-size 32 --epochs 100 --seed 30042022 --device cuda --out results --res 450" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "a064d169", + "metadata": {}, + "outputs": [], + "source": [ + "# TRAIN\n", + "# dataset : CNMC_blue_only\n", + "# res : 450\n", + "# epochs : 100\n", + "!python3 main_manual.py --dataroot \"/home/feoktistovar67431/data/isbi2019/CNMC_blue_only\" 
--batch-size 32 --epochs 100 --seed 30042022 --device cuda --out results --res 450" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "8d53828a", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Namespace(batch_size=32, dataroot='/home/feoktistovar67431/data/isbi2019/CNMC', device='cuda', epochs=100, out='results/20220221T191207Z.OEHG', res=450, seed=30042022)\n", + "Model parameters: 25512945\n", + "Trainset length: 10625\n", + "Validset length: 1867\n", + "class_weights = tensor([0.3156, 0.6844], device='cuda:0')\n", + "/home/feoktistovar67431/.local/lib/python3.6/site-packages/sklearn/metrics/_classification.py:1248: UndefinedMetricWarning: Precision and F-score are ill-defined and being set to 0.0 in labels with no predicted samples. Use `zero_division` parameter to control this behavior.\n", + " _warn_prf(average, modifier, msg_start, len(result))\n", + "epoch=0 f1=0.1789\n", + "WARNING:root:NaN or Inf found in input tensor.\n", + "Epoch: 0%| | 0/100 [00:00.*(?P[a-zA-Z0-9_]+)/' + r'(?Phem|all)/' + r'UID_(?PH?\d+)_(?P\d+)_(?P\d+)_(all|hem).bmp)') + for file in file_iter(dataroot): + match = pattern.match(file) + if match is not None: + yield file, match + + +def to_dataframe(dataroot): + data = defaultdict(list) + keys = ['file', 'fold', 'subject', 'class', 'image', 'cell'] + + # Load data from the three training folds + for file, match in file_match_iter(dataroot): + for key in keys: + data[key].append(match.group(key)) + + # Load data from the phase2 validation set + phase2 = pd.read_csv(join(dataroot, 'phase2.csv'), header=0, names=['file_id', 'file', 'class']) + pattern = re.compile(r'UID_(?PH?\d+)_(?P\d+)_(?P\d+)_(all|hem).bmp') + for i, row in phase2.iterrows(): + match = pattern.match(row['file_id']) + data['file'].append(join(dataroot, f'phase2/{i+1}.bmp')) + data['fold'].append('3') + data['subject'].append(match.group('subject')) + data['class'].append('hem' if row['class'] == 0 
else 'all') + data['image'].append(match.group('image')) + data['cell'].append(match.group('cell')) + + # Convert to dataframe + df = pd.DataFrame(data) + df = df.apply(pd.to_numeric, errors='ignore') + return df + + +class ISBI2019(Dataset): + def __init__(self, df, transform=None): + super().__init__() + self.transform = transform + self.df = df + + def __len__(self): + return len(self.df) + + def __getitem__(self, index): + # Convert tensors to int because pandas screws up otherwise + index = int(index) + file, cls = self.df.iloc[index][['file', 'class']] + img = Image.open(file)#.convert('RGB') + cls = 0 if cls == 'hem' else 1 + if self.transform is not None: + img = self.transform(img) + return img, cls + + +def get_class_weights(df): + class_weights = torch.FloatTensor([ + df.loc[df['class'] == 'hem']['file'].count() / len(df), + df.loc[df['class'] == 'all']['file'].count() / len(df), + ]).to(dtype=torch.float32) + return class_weights + + +def tf_rotation_stack(x, num_rotations=8): + xs = [] + for i in range(num_rotations): + angle = 360 * i / num_rotations + xrot = TF.rotate(x, angle) + xrot = TF.to_tensor(xrot) + xs.append(xrot) + xs = torch.stack(xs) + return xs + + +def get_tf_train_transform(res): + size_factor = int(STD_RES/res) + center_crop = int(STD_CENTER_CROP/size_factor) + tf_train = transforms.Compose([ + transforms.Resize(res), + #transforms.CenterCrop(center_crop), + transforms.RandomVerticalFlip(), + transforms.RandomHorizontalFlip(), + transforms.RandomAffine(degrees=360, translate=(0.2, 0.2)), + # transforms.Lambda(tf_rotation_stack), + transforms.ToTensor(), + ]) + return tf_train + + +def get_tf_vaild_rot_transform(res): + size_factor = int(STD_RES/res) + center_crop = int(STD_CENTER_CROP/size_factor) + tf_valid_rot = transforms.Compose([ + transforms.Resize(res), + #transforms.CenterCrop(center_crop), + transforms.Lambda(tf_rotation_stack), + ]) + return tf_valid_rot + + +def get_tf_valid_norot_transform(res): + size_factor = 
int(STD_RES/res) + center_crop = int(STD_CENTER_CROP/size_factor) + tf_valid_norot = transforms.Compose([ + transforms.Resize(res), + #transforms.CenterCrop(center_crop), + transforms.ToTensor(), + ]) + return tf_valid_norot + + +def get_dataset(dataroot, folds_train=(0, 1, 2), folds_valid=(3,), tf_train=None, tf_valid=None): + if tf_train is None or tf_valid is None: + sys.exit("Tranformation is None") + df = to_dataframe(dataroot) + df_trainset = df.loc[df['fold'].isin(folds_train)] + trainset = ISBI2019(df_trainset, transform=tf_train) + class_weights = get_class_weights(df_trainset) + + if folds_valid is not None: + df_validset = df.loc[df['fold'].isin(folds_valid)] + validset_subjects = df_validset['subject'].values + validset = ISBI2019(df_validset, transform=tf_valid) + return trainset, validset, validset_subjects, class_weights + else: + return trainset, class_weights + + +if __name__ == '__main__': + import math + from tqdm import tqdm + + df = to_dataframe('data') + print(df) + print("Examples by fold and class") + print(df.groupby(['fold', 'class'])['file'].count()) + + dataset = ISBI2019(df) + mean_height, mean_width = 0, 0 + weird_files = [] + bound_left, bound_upper, bound_right, bound_lower = math.inf, math.inf, 0, 0 + for i, (img, label) in tqdm(enumerate(dataset), total=len(dataset)): + left, upper, right, lower = img.getbbox() + if left == 0 or upper == 0 or right == 450 or lower == 450: + weird_files.append(df.iloc[i]['file']) + height = lower - upper + width = right - left + mean_height = mean_height + (height - mean_height) / (i + 1) + mean_width = mean_width + (width - mean_width) / (i + 1) + bound_left = min(bound_left, left) + bound_upper = min(bound_upper, upper) + bound_right = max(bound_right, right) + bound_lower = max(bound_lower, lower) + print(f"mean_height = {mean_height:.2f}") + print(f"mean_width = {mean_width:.2f}") + print(f"bound_left = {bound_left:d}") + print(f"bound_upper = {bound_upper:d}") + print(f"bound_right = 
{bound_right:d}") + print(f"bound_lower = {bound_lower:d}") + print("Files that max out at least one border:") + for f in weird_files: + print(f) diff --git a/main_manual.py b/main_manual.py new file mode 100644 index 0000000..dc2245b --- /dev/null +++ b/main_manual.py @@ -0,0 +1,246 @@ +import argparse +import os +from collections import defaultdict + +import numpy as np +import torch +import torch.nn.functional as F +from sklearn.metrics import roc_auc_score, confusion_matrix, precision_recall_fscore_support, accuracy_score +from tensorboardX import SummaryWriter +from torch.optim.lr_scheduler import StepLR, LambdaLR +from torch.utils.data import DataLoader +from tqdm import tqdm, trange + +from dataset import get_dataset, get_tf_train_transform, get_tf_vaild_rot_transform +from model import get_model +from utils import IncrementalAverage, to_device, set_seeds, unique_string, count_parameters + + +def evaluate(model, valid_loader, class_weights, device): + model.eval() + + all_labels = [] + all_preds = [] + loss_avg = IncrementalAverage() + for img, label in tqdm(valid_loader, leave=False): + img, label = to_device(device, img, label) + bs, nrot, c, h, w = img.size() + with torch.no_grad(): + pred = model(img.view(-1, c, h, w)) + pred = pred.view(bs, nrot).mean(1) + loss = lossfn(pred, label.to(pred.dtype), class_weights) + all_labels.append(label.cpu()) + all_preds.append(pred.cpu()) + loss_avg.update(loss.item()) + + all_labels = torch.cat(all_labels).numpy() + all_preds = torch.cat(all_preds).numpy() + all_preds_binary = all_preds > 0 + + cm = confusion_matrix(all_labels, all_preds_binary) + auc = roc_auc_score(all_labels, all_preds) + prec, rec, f1, _ = precision_recall_fscore_support(all_labels, all_preds_binary, average='weighted') + return loss_avg.value, cm, auc, prec, rec, f1 + + +def train(model, opt, train_loader, class_weights, device): + model.train() + loss_avg = IncrementalAverage() + for img, label in tqdm(train_loader, leave=False): + img, label 
= to_device(device, img, label) + pred = model(img) + pred = pred.view(-1) + loss = lossfn(pred, label.to(pred.dtype), class_weights) + loss_avg.update(loss.item()) + + opt.zero_grad() + loss.backward() + opt.step() + return loss_avg.value + + +def lossfn(prediction, target, class_weights): + pos_weight = (class_weights[0] / class_weights[1]).expand(len(target)) + return F.binary_cross_entropy_with_logits(prediction, target, pos_weight=pos_weight) + + +def schedule(epoch): + if epoch < 2: + ub = 1 + elif epoch < 4: + ub = 0.1 + else: + ub = 0.01 + return ub + + +def train_validate(args): + model = get_model().to(args.device) + print("Model parameters:", count_parameters(model)) + + trainset, validset, validset_subjects, class_weights = get_dataset(args.dataroot, + tf_train=get_tf_train_transform(args.res), + tf_valid=get_tf_vaild_rot_transform(args.res)) + class_weights = class_weights.to(args.device) + print(f"Trainset length: {len(trainset)}") + print(f"Validset length: {len(validset)}") + print(f"class_weights = {class_weights}") + + train_loader = DataLoader(trainset, batch_size=args.batch_size, num_workers=6, shuffle=True, drop_last=True) + valid_loader = DataLoader(validset, batch_size=args.batch_size, num_workers=6, shuffle=False) + + opt = torch.optim.Adam([ + {'params': model.paramgroup01(), 'lr': 1e-6}, + {'params': model.paramgroup234(), 'lr': 1e-4}, + {'params': model.parameters_classifier(), 'lr': 1e-2}, + ]) + scheduler = LambdaLR(opt, lr_lambda=[lambda e: schedule(e), + lambda e: schedule(e), + lambda e: schedule(e)]) + + summarywriter = SummaryWriter(args.out) + recorded_data = defaultdict(list) + + def logged_eval(e): + valid_loss, cm, auc, prec, rec, f1 = evaluate(model, valid_loader, class_weights, args.device) + + # Derive some accuracy metrics from confusion matrix + tn, fp, fn, tp = cm.ravel() + acc = (tp + tn) / cm.sum() + acc_hem = tn / (tn + fp) + acc_all = tp / (tp + fn) + + print(f"epoch={e} f1={f1:.4f}") + + 
summarywriter.add_scalar('loss/train', train_loss, e) + summarywriter.add_scalar('loss/valid', valid_loss, e) + summarywriter.add_scalar('cm/tn', tn, e) + summarywriter.add_scalar('cm/fp', fp, e) + summarywriter.add_scalar('cm/fn', fn, e) + summarywriter.add_scalar('cm/tp', tp, e) + summarywriter.add_scalar('metrics/precision', prec, e) + summarywriter.add_scalar('metrics/recall', rec, e) + summarywriter.add_scalar('metrics/f1', f1, e) + summarywriter.add_scalar('metrics/auc', auc, e) + summarywriter.add_scalar('acc/acc', acc, e) + summarywriter.add_scalar('acc/hem', acc_hem, e) + summarywriter.add_scalar('acc/all', acc_all, e) + + recorded_data['loss_train'].append(train_loss) + recorded_data['loss_valid'].append(valid_loss) + recorded_data['tn'].append(tn) + recorded_data['tn'].append(tn) + recorded_data['fp'].append(fp) + recorded_data['fn'].append(fn) + recorded_data['tp'].append(tp) + recorded_data['precision'].append(prec) + recorded_data['recall'].append(rec) + recorded_data['f1'].append(f1) + recorded_data['auc'].append(auc) + recorded_data['acc'].append(acc) + recorded_data['acc_hem'].append(acc_hem) + recorded_data['acc_all'].append(acc_all) + np.savez(f'{args.out}/results', **recorded_data) + + return f1 + + model = torch.nn.DataParallel(model) + train_loss = np.nan + best_val_f1 = logged_eval(0) + for e in trange(args.epochs, desc='Epoch'): + scheduler.step(e) + train_loss = train(model, opt, train_loader, class_weights, args.device) + val_f1 = logged_eval(e + 1) + + if val_f1 > best_val_f1: + print(f"New best model at {val_f1:.6f}") + torch.save(model.state_dict(), f'{args.out}/model.pt') + best_val_f1 = val_f1 + + summarywriter.close() + + subj_acc = evaluate_subj_acc(model, validset, validset_subjects, args.device) + np.savez(f'{args.out}/subj_acc', **subj_acc) + + +def evaluate_subj_acc(model, dataset, subjects, device): + model.eval() + + subj_pred = defaultdict(list) + subj_label = defaultdict(list) + + dataloader = DataLoader(dataset, 
batch_size=1, num_workers=1, shuffle=False) + + for (img, cls), subj in tqdm(zip(dataloader, subjects), total=len(subjects), leave=False): + img, cls = to_device(device, img, cls) + bs, nrot, c, h, w = img.size() + with torch.no_grad(): + cls_hat = model(img.view(-1, c, h, w)) + cls_hat = cls_hat.view(bs, nrot).mean(1) + subj_label[subj].append(cls.cpu()) + subj_pred[subj].append(cls_hat.cpu()) + + for k in subj_label: + subj_label[k] = torch.cat(subj_label[k]).numpy() + subj_pred[k] = torch.cat(subj_pred[k]).numpy() > 0 + + subj_acc = {} + for k in subj_label: + subj_acc[k] = accuracy_score(subj_label[k], subj_pred[k]) + + return subj_acc + + +def train_test(args): + model = get_model().to(args.device) + print("Model parameters:", count_parameters(model)) + + trainset, class_weights = get_dataset(args.dataroot, folds_train=(0, 1, 2, 3), + folds_valid=None, + tf_train=get_tf_train_transform(args.res), + tf_valid=get_tf_vaild_rot_transform(args.res)) + class_weights = class_weights.to(args.device) + print(f"Trainset length: {len(trainset)}") + print(f"class_weights = {class_weights}") + + train_loader = DataLoader(trainset, batch_size=args.batch_size, num_workers=6, shuffle=True, drop_last=True) + + opt = torch.optim.Adam([ + {'params': model.paramgroup01(), 'lr': 1e-6}, + {'params': model.paramgroup234(), 'lr': 1e-4}, + {'params': model.parameters_classifier(), 'lr': 1e-2}, + ]) + scheduler = LambdaLR(opt, lr_lambda=[lambda e: schedule(e), + lambda e: schedule(e), + lambda e: schedule(e)]) + + model = torch.nn.DataParallel(model) + for e in trange(args.epochs, desc='Epoch'): + scheduler.step(e) + train(model, opt, train_loader, class_weights, args.device) + torch.save(model.state_dict(), f'{args.out}/model.pt') + + +def parse_args(): + parser = argparse.ArgumentParser() + parser.add_argument('--dataroot', default='data', help='path to dataset') + parser.add_argument('--batch-size', type=int, default=16) + parser.add_argument('--epochs', type=int, default=6) + 
parser.add_argument('--seed', default=1, type=int, help='random seed') + parser.add_argument('--device', default='cuda' if torch.cuda.is_available() else 'cpu') + parser.add_argument('--out', default='results', help='output folder') + parser.add_argument('--res', type=int, default='450', help='Desired input resolution') + args = parser.parse_args() + args.out = os.path.join(args.out, unique_string()) + return args + + +if __name__ == '__main__': + args = parse_args() + print(args) + + os.makedirs(args.out, exist_ok=True) + set_seeds(args.seed) + torch.backends.cudnn.benchmark = True + + train_validate(args) diff --git a/main_manual_abl_layerlr.py b/main_manual_abl_layerlr.py new file mode 100644 index 0000000..c7bf4cd --- /dev/null +++ b/main_manual_abl_layerlr.py @@ -0,0 +1,210 @@ +import argparse +import os +from collections import defaultdict + +import numpy as np +import torch +import torch.nn.functional as F +from sklearn.metrics import roc_auc_score, confusion_matrix, precision_recall_fscore_support, accuracy_score +from tensorboardX import SummaryWriter +from torch.optim.lr_scheduler import StepLR, LambdaLR +from torch.utils.data import DataLoader +from tqdm import tqdm, trange + +from dataset import get_dataset, get_tf_train_transform, get_tf_vaild_rot_transform +from model import get_model +from utils import IncrementalAverage, to_device, set_seeds, unique_string, count_parameters + + +def evaluate(model, valid_loader, class_weights, device): + model.eval() + + all_labels = [] + all_preds = [] + loss_avg = IncrementalAverage() + for img, label in tqdm(valid_loader, leave=False): + img, label = to_device(device, img, label) + bs, nrot, c, h, w = img.size() + with torch.no_grad(): + pred = model(img.view(-1, c, h, w)) + pred = pred.view(bs, nrot).mean(1) + loss = lossfn(pred, label.to(pred.dtype), class_weights) + all_labels.append(label.cpu()) + all_preds.append(pred.cpu()) + loss_avg.update(loss.item()) + + all_labels = torch.cat(all_labels).numpy() + 
all_preds = torch.cat(all_preds).numpy() + all_preds_binary = all_preds > 0 + + cm = confusion_matrix(all_labels, all_preds_binary) + auc = roc_auc_score(all_labels, all_preds) + prec, rec, f1, _ = precision_recall_fscore_support(all_labels, all_preds_binary, average='weighted') + return loss_avg.value, cm, auc, prec, rec, f1 + + +def train(model, opt, train_loader, class_weights, device): + model.train() + loss_avg = IncrementalAverage() + for img, label in tqdm(train_loader, leave=False): + img, label = to_device(device, img, label) + pred = model(img) + pred = pred.view(-1) + loss = lossfn(pred, label.to(pred.dtype), class_weights) + loss_avg.update(loss.item()) + + opt.zero_grad() + loss.backward() + opt.step() + return loss_avg.value + + +def lossfn(prediction, target, class_weights): + pos_weight = (class_weights[0] / class_weights[1]).expand(len(target)) + return F.binary_cross_entropy_with_logits(prediction, target, pos_weight=pos_weight) + + +def schedule(epoch): + if epoch < 2: + ub = 1 + elif epoch < 4: + ub = 0.1 + else: + ub = 0.01 + return ub + + +def train_validate(args): + model = get_model().to(args.device) + print("Model parameters:", count_parameters(model)) + + trainset, validset, validset_subjects, class_weights = get_dataset(args.dataroot, + tf_train=get_tf_train_transform(args.res), + tf_valid=get_tf_vaild_rot_transform(args.res)) + class_weights = class_weights.to(args.device) + print(f"Trainset length: {len(trainset)}") + print(f"Validset length: {len(validset)}") + print(f"class_weights = {class_weights}") + + train_loader = DataLoader(trainset, batch_size=args.batch_size, num_workers=6, shuffle=True, drop_last=True) + valid_loader = DataLoader(validset, batch_size=args.batch_size, num_workers=6, shuffle=False) + + opt = torch.optim.Adam([ + {'params': model.paramgroup01(), 'lr': args.lr}, + {'params': model.paramgroup234(), 'lr': args.lr}, + {'params': model.parameters_classifier(), 'lr': args.lr}, + ]) + scheduler = LambdaLR(opt, 
lr_lambda=[lambda e: schedule(e), + lambda e: schedule(e), + lambda e: schedule(e)]) + + summarywriter = SummaryWriter(args.out) + recorded_data = defaultdict(list) + + def logged_eval(e): + valid_loss, cm, auc, prec, rec, f1 = evaluate(model, valid_loader, class_weights, args.device) + + # Derive some accuracy metrics from confusion matrix + tn, fp, fn, tp = cm.ravel() + acc = (tp + tn) / cm.sum() + acc_hem = tn / (tn + fp) + acc_all = tp / (tp + fn) + + print(f"epoch={e} f1={f1:.4f}") + + summarywriter.add_scalar('loss/train', train_loss, e) + summarywriter.add_scalar('loss/valid', valid_loss, e) + summarywriter.add_scalar('cm/tn', tn, e) + summarywriter.add_scalar('cm/fp', fp, e) + summarywriter.add_scalar('cm/fn', fn, e) + summarywriter.add_scalar('cm/tp', tp, e) + summarywriter.add_scalar('metrics/precision', prec, e) + summarywriter.add_scalar('metrics/recall', rec, e) + summarywriter.add_scalar('metrics/f1', f1, e) + summarywriter.add_scalar('metrics/auc', auc, e) + summarywriter.add_scalar('acc/acc', acc, e) + summarywriter.add_scalar('acc/hem', acc_hem, e) + summarywriter.add_scalar('acc/all', acc_all, e) + + recorded_data['loss_train'].append(train_loss) + recorded_data['loss_valid'].append(valid_loss) + recorded_data['tn'].append(tn) + recorded_data['tn'].append(tn) + recorded_data['fp'].append(fp) + recorded_data['fn'].append(fn) + recorded_data['tp'].append(tp) + recorded_data['precision'].append(prec) + recorded_data['recall'].append(rec) + recorded_data['f1'].append(f1) + recorded_data['auc'].append(auc) + recorded_data['acc'].append(acc) + recorded_data['acc_hem'].append(acc_hem) + recorded_data['acc_all'].append(acc_all) + np.savez(f'{args.out}/results', **recorded_data) + + model = torch.nn.DataParallel(model) + train_loss = np.nan + logged_eval(0) + for e in trange(args.epochs, desc='Epoch'): + scheduler.step(e) + train_loss = train(model, opt, train_loader, class_weights, args.device) + logged_eval(e + 1) + + summarywriter.close() + + subj_acc = 
evaluate_subj_acc(model, validset, validset_subjects, args.device) + np.savez(f'{args.out}/subj_acc', **subj_acc) + + +def evaluate_subj_acc(model, dataset, subjects, device): + model.eval() + + subj_pred = defaultdict(list) + subj_label = defaultdict(list) + + dataloader = DataLoader(dataset, batch_size=1, num_workers=1, shuffle=False) + + for (img, cls), subj in tqdm(zip(dataloader, subjects), total=len(subjects), leave=False): + img, cls = to_device(device, img, cls) + bs, nrot, c, h, w = img.size() + with torch.no_grad(): + cls_hat = model(img.view(-1, c, h, w)) + cls_hat = cls_hat.view(bs, nrot).mean(1) + subj_label[subj].append(cls.cpu()) + subj_pred[subj].append(cls_hat.cpu()) + + for k in subj_label: + subj_label[k] = torch.cat(subj_label[k]).numpy() + subj_pred[k] = torch.cat(subj_pred[k]).numpy() > 0 + + subj_acc = {} + for k in subj_label: + subj_acc[k] = accuracy_score(subj_label[k], subj_pred[k]) + + return subj_acc + + +def parse_args(): + parser = argparse.ArgumentParser() + parser.add_argument('--dataroot', default='data', help='path to dataset') + parser.add_argument('--lr', type=float, default=1e-4) + parser.add_argument('--batch-size', type=int, default=16) + parser.add_argument('--epochs', type=int, default=6) + parser.add_argument('--seed', default=1, type=int, help='random seed') + parser.add_argument('--device', default='cuda' if torch.cuda.is_available() else 'cpu') + parser.add_argument('--out', default='results', help='output folder') + parser.add_argument('--res', type=int, default='450', help='Desired input resolution') + args = parser.parse_args() + args.out = os.path.join(args.out, unique_string()) + return args + + +if __name__ == '__main__': + args = parse_args() + print(args) + + os.makedirs(args.out, exist_ok=True) + set_seeds(args.seed) + torch.backends.cudnn.benchmark = True + + train_validate(args) diff --git a/main_manual_abl_testrot.py b/main_manual_abl_testrot.py new file mode 100644 index 0000000..8426496 --- /dev/null +++ 
b/main_manual_abl_testrot.py @@ -0,0 +1,208 @@ +import argparse +import os +from collections import defaultdict + +import numpy as np +import torch +import torch.nn.functional as F +from sklearn.metrics import roc_auc_score, confusion_matrix, precision_recall_fscore_support, accuracy_score +from tensorboardX import SummaryWriter +from torch.optim.lr_scheduler import StepLR, LambdaLR +from torch.utils.data import DataLoader +from tqdm import tqdm, trange + +from dataset import get_dataset, get_tf_valid_norot_transform, get_tf_train_transform +from model import get_model +from utils import IncrementalAverage, to_device, set_seeds, unique_string, count_parameters + + +def evaluate(model, valid_loader, class_weights, device): + model.eval() + + all_labels = [] + all_preds = [] + loss_avg = IncrementalAverage() + for img, label in tqdm(valid_loader, leave=False): + img, label = to_device(device, img, label) + with torch.no_grad(): + pred = model(img).view(-1) + loss = lossfn(pred, label.to(pred.dtype), class_weights) + all_labels.append(label.cpu()) + all_preds.append(pred.cpu()) + loss_avg.update(loss.item()) + + all_labels = torch.cat(all_labels).numpy() + all_preds = torch.cat(all_preds).numpy() + all_preds_binary = all_preds > 0 + + cm = confusion_matrix(all_labels, all_preds_binary) + auc = roc_auc_score(all_labels, all_preds) + prec, rec, f1, _ = precision_recall_fscore_support(all_labels, all_preds_binary, average='weighted') + return loss_avg.value, cm, auc, prec, rec, f1 + + +def train(model, opt, train_loader, class_weights, device): + model.train() + loss_avg = IncrementalAverage() + for img, label in tqdm(train_loader, leave=False): + img, label = to_device(device, img, label) + pred = model(img) + pred = pred.view(-1) + loss = lossfn(pred, label.to(pred.dtype), class_weights) + loss_avg.update(loss.item()) + + opt.zero_grad() + loss.backward() + opt.step() + return loss_avg.value + + +def lossfn(prediction, target, class_weights): + pos_weight = 
(class_weights[0] / class_weights[1]).expand(len(target)) + return F.binary_cross_entropy_with_logits(prediction, target, pos_weight=pos_weight) + + +def schedule(epoch): + if epoch < 2: + ub = 1 + elif epoch < 4: + ub = 0.1 + else: + ub = 0.01 + return ub + + +def train_validate(args): + model = get_model().to(args.device) + print("Model parameters:", count_parameters(model)) + + trainset, validset, validset_subjects, class_weights = get_dataset(args.dataroot, + tf_valid=get_tf_valid_norot_transform(args.res), + tf_train=get_tf_train_transform(args.res)) + class_weights = class_weights.to(args.device) + print(f"Trainset length: {len(trainset)}") + print(f"Validset length: {len(validset)}") + print(f"class_weights = {class_weights}") + + train_loader = DataLoader(trainset, batch_size=args.batch_size, num_workers=6, shuffle=True, drop_last=True) + valid_loader = DataLoader(validset, batch_size=args.batch_size, num_workers=6, shuffle=False) + + opt = torch.optim.Adam([ + {'params': model.paramgroup01(), 'lr': 1e-6}, + {'params': model.paramgroup234(), 'lr': 1e-4}, + {'params': model.parameters_classifier(), 'lr': 1e-2}, + ]) + scheduler = LambdaLR(opt, lr_lambda=[lambda e: schedule(e), + lambda e: schedule(e), + lambda e: schedule(e)]) + + summarywriter = SummaryWriter(args.out) + recorded_data = defaultdict(list) + + def logged_eval(e): + valid_loss, cm, auc, prec, rec, f1 = evaluate(model, valid_loader, class_weights, args.device) + + # Derive some accuracy metrics from confusion matrix + tn, fp, fn, tp = cm.ravel() + acc = (tp + tn) / cm.sum() + acc_hem = tn / (tn + fp) + acc_all = tp / (tp + fn) + + print(f"epoch={e} f1={f1:.4f}") + + summarywriter.add_scalar('loss/train', train_loss, e) + summarywriter.add_scalar('loss/valid', valid_loss, e) + summarywriter.add_scalar('cm/tn', tn, e) + summarywriter.add_scalar('cm/fp', fp, e) + summarywriter.add_scalar('cm/fn', fn, e) + summarywriter.add_scalar('cm/tp', tp, e) + summarywriter.add_scalar('metrics/precision', 
# ---- main_manual.py (tail) ----

def evaluate_subj_acc(model, dataset, subjects, device):
    """Compute the classification accuracy separately for every subject.

    Args:
        model: network mapping a ``(N, c, h, w)`` batch to one logit per image.
        dataset: dataset yielding ``(img, cls)`` pairs; once batched, ``img``
            is 5-D ``(bs, nrot, c, h, w)``, i.e. several views per sample.
        subjects: one subject identifier per dataset item, in dataset order
            (the loader below is unshuffled with batch size 1, so ``zip``
            pairs item *i* with ``subjects[i]``).
        device: device the tensors are moved to before the forward pass.

    Returns:
        dict mapping each subject identifier to its accuracy score.
    """
    model.eval()

    subj_pred = defaultdict(list)
    subj_label = defaultdict(list)

    dataloader = DataLoader(dataset, batch_size=1, num_workers=1, shuffle=False)

    for (img, cls), subj in tqdm(zip(dataloader, subjects), total=len(subjects), leave=False):
        img, cls = to_device(device, img, cls)
        bs, nrot, c, h, w = img.size()
        with torch.no_grad():
            # Fold the view axis into the batch axis for the forward pass,
            # then average the logits of all views of the same sample.
            cls_hat = model(img.view(-1, c, h, w))
            cls_hat = cls_hat.view(bs, nrot).mean(1)
        subj_label[subj].append(cls.cpu())
        subj_pred[subj].append(cls_hat.cpu())

    subj_acc = {}
    for subj, labels in subj_label.items():
        y_true = torch.cat(labels).numpy()
        # A mean logit above zero counts as a positive prediction.
        y_pred = torch.cat(subj_pred[subj]).numpy() > 0
        subj_acc[subj] = accuracy_score(y_true, y_pred)

    return subj_acc


def parse_args(argv=None):
    """Parse the training command line.

    Args:
        argv: optional list of argument strings; ``None`` (the default) makes
            argparse read ``sys.argv[1:]`` exactly as before.

    Returns:
        The parsed namespace. ``args.out`` is extended with a fresh
        ``unique_string()`` component so every run gets its own folder.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--dataroot', default='data', help='path to dataset')
    parser.add_argument('--batch-size', type=int, default=16)
    parser.add_argument('--epochs', type=int, default=6)
    parser.add_argument('--seed', default=1, type=int, help='random seed')
    parser.add_argument('--device', default='cuda' if torch.cuda.is_available() else 'cpu')
    parser.add_argument('--out', default='results', help='output folder')
    # The default is a real int now; it used to be the string '450' and only
    # worked because argparse runs string defaults through `type`.
    parser.add_argument('--res', type=int, default=450, help='Desired input resolution')
    args = parser.parse_args(argv)
    args.out = os.path.join(args.out, unique_string())
    return args


if __name__ == '__main__':
    args = parse_args()
    print(args)

    os.makedirs(args.out, exist_ok=True)
    set_seeds(args.seed)
    torch.backends.cudnn.benchmark = True

    train_validate(args)


# ---- model.py ----
# Code adapted from: https://github.com/Cadene/pretrained-models.pytorch

class SEModule(nn.Module):
    """Squeeze-and-Excitation gate: rescales every channel of its input.

    The input is globally average-pooled to one value per channel, passed
    through a two-layer 1x1-conv bottleneck (``channels // reduction`` hidden
    channels, ReLU in between) and a sigmoid; the resulting per-channel
    weights multiply the original input.
    """

    def __init__(self, channels, reduction):
        super().__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc1 = nn.Conv2d(channels, channels // reduction, kernel_size=1, padding=0)
        self.relu = nn.ReLU(inplace=True)
        self.fc2 = nn.Conv2d(channels // reduction, channels, kernel_size=1, padding=0)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Return ``x`` scaled channel-wise by the learned gate (same shape as ``x``)."""
        module_input = x
        x = self.avg_pool(x)
        x = self.fc1(x)
        x = self.relu(x)
        x = self.fc2(x)
        x = self.sigmoid(x)
        return module_input * x
# ---- model.py (continued) ----

class SEResNeXtBottleneck(nn.Module):
    """
    ResNeXt bottleneck type C with a Squeeze-and-Excitation module.

    The block is 1x1 conv -> grouped 3x3 conv (carries the stride) -> 1x1 conv,
    each followed by batch norm; the SE-gated result is added to the residual
    (optionally passed through ``downsample``) and sent through a ReLU.
    The output has ``planes * expansion`` channels.
    """
    expansion = 4

    def __init__(self, inplanes, planes, groups, reduction, stride=1, downsample=None, base_width=4):
        super().__init__()
        width = math.floor(planes * (base_width / 64)) * groups
        # Use the class attribute instead of a literal 4 so the block agrees
        # with SENet, which sizes its layers from `block.expansion`.
        out_planes = planes * self.expansion
        self.conv1 = nn.Conv2d(inplanes, width, kernel_size=1, bias=False, stride=1)
        self.bn1 = nn.BatchNorm2d(width)
        self.conv2 = nn.Conv2d(width, width, kernel_size=3, stride=stride, padding=1, groups=groups, bias=False)
        self.bn2 = nn.BatchNorm2d(width)
        self.conv3 = nn.Conv2d(width, out_planes, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(out_planes)
        self.relu = nn.ReLU(inplace=True)
        self.se_module = SEModule(out_planes, reduction=reduction)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        """Apply the bottleneck to ``x`` and return the activated residual sum."""
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out = self.se_module(out) + residual
        out = self.relu(out)

        return out


class SENet(nn.Module):
    """SE-ResNe(X)t backbone with a single-logit classification head.

    ``layer0`` is the stem (7x7 stride-2 conv, batch norm, ReLU, max-pool);
    ``layer1``..``layer4`` are stacks of ``block`` with 64/128/256/512 planes,
    the last three with stride 2; ``cls`` pools globally, flattens and maps
    ``512 * block.expansion`` features to one output.
    """

    def __init__(self, block, layers, groups, reduction, inplanes=128,
                 downsample_kernel_size=3, downsample_padding=1):
        super().__init__()
        self.inplanes = inplanes

        layer0_modules = [
            ('conv1', nn.Conv2d(3, inplanes, kernel_size=7, stride=2, padding=3, bias=False)),
            ('bn1', nn.BatchNorm2d(inplanes)),
            ('relu1', nn.ReLU(inplace=True)),
            # To preserve compatibility with Caffe weights `ceil_mode=True`
            # is used instead of `padding=1`.
            ('pool', nn.MaxPool2d(3, stride=2, ceil_mode=True))
        ]
        self.layer0 = nn.Sequential(OrderedDict(layer0_modules))
        self.layer1 = self._make_layer(
            block,
            planes=64,
            blocks=layers[0],
            groups=groups,
            reduction=reduction,
            downsample_kernel_size=1,
            downsample_padding=0
        )
        self.layer2 = self._make_layer(
            block,
            planes=128,
            blocks=layers[1],
            stride=2,
            groups=groups,
            reduction=reduction,
            downsample_kernel_size=downsample_kernel_size,
            downsample_padding=downsample_padding
        )
        self.layer3 = self._make_layer(
            block,
            planes=256,
            blocks=layers[2],
            stride=2,
            groups=groups,
            reduction=reduction,
            downsample_kernel_size=downsample_kernel_size,
            downsample_padding=downsample_padding
        )
        self.layer4 = self._make_layer(
            block,
            planes=512,
            blocks=layers[3],
            stride=2,
            groups=groups,
            reduction=reduction,
            downsample_kernel_size=downsample_kernel_size,
            downsample_padding=downsample_padding
        )
        self.cls = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            Flatten(),
            nn.Linear(512 * block.expansion, 1)
        )

    def _make_layer(self, block, planes, blocks, groups, reduction, stride=1,
                    downsample_kernel_size=1, downsample_padding=0):
        """Build one stage of ``blocks`` consecutive ``block`` modules.

        Only the first block gets ``stride`` and, when the shape changes, a
        conv + batch-norm projection on the residual path. ``self.inplanes``
        is updated to the stage's output width as a side effect.
        """
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, planes * block.expansion,
                          kernel_size=downsample_kernel_size, stride=stride,
                          padding=downsample_padding, bias=False),
                nn.BatchNorm2d(planes * block.expansion),
            )

        layers = [block(self.inplanes, planes, groups, reduction, stride, downsample)]
        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(block(self.inplanes, planes, groups, reduction))

        return nn.Sequential(*layers)

    def paramgroup01(self):
        """Iterate over the parameters of ``layer0`` and ``layer1``."""
        return chain(
            self.layer0.parameters(),
            self.layer1.parameters(),
        )

    def paramgroup234(self):
        """Iterate over the parameters of ``layer2``, ``layer3`` and ``layer4``."""
        return chain(
            self.layer2.parameters(),
            self.layer3.parameters(),
            self.layer4.parameters(),
        )

    def parameters_classifier(self):
        """Iterate over the parameters of the classification head ``cls``."""
        return self.cls.parameters()

    def forward(self, x):
        """Run the stem, the four stages and the head; returns the head output."""
        x = self.layer0(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        c = self.cls(x)
        return c


def get_model(pretrained=True):
    """Build the SE-ResNeXt-50 (32x4d) network used by this project.

    Args:
        pretrained: when true (default, the previous behaviour) download the
            published se_resnext50_32x4d checkpoint and load it. Pass
            ``False`` to skip the download, e.g. when a trained state dict is
            loaded right afterwards.

    Returns:
        The constructed ``SENet``.
    """
    model = SENet(SEResNeXtBottleneck, [3, 4, 6, 3], groups=32, reduction=16, inplanes=64,
                  downsample_kernel_size=1, downsample_padding=0)
    if pretrained:
        checkpoint = model_zoo.load_url('http://data.lip6.fr/cadene/pretrainedmodels/se_resnext50_32x4d-a260b3a4.pth')
        # strict=False: key mismatches between the checkpoint and this model
        # (which ends in its own `cls` head) are tolerated instead of raising.
        model.load_state_dict(checkpoint, strict=False)
    return model


# ---- plot.py (head) ----

# Display name -> folder that contains one sub-folder per run, each holding a
# `results.npz`. Commented entries are alternative experiments to plot.
dataroots = {
    'PROPOSAL'                      : 'results',
    #'model_cnmc_res_128'            : 'results/model_cnmc_res_128',
    #'model_cnmc_res_224'            : 'results/model_cnmc_res_224',
    #'model_cnmc_res_256'            : 'results/model_cnmc_res_256',
    #'model_cnmc_res_450'            : 'results/model_cnmc_res_450',
    #'model_cnmc_res_450_blue_only'  : 'results/model_cnmc_res_450_blue_only',
    #'model_cnmc_res_450_green_only' : 'results/model_cnmc_res_450_green_only',
    #'model_cnmc_res_450_red_only'   : 'results/model_cnmc_res_450_red_only',
    #'model_cnmc_res_450_no_blue'    : 'results/model_cnmc_res_450_no_blue',
    #'model_cnmc_res_450_no_green'   : 'results/model_cnmc_res_450_no_green',
    #'model_cnmc_res_450_no_red'     : 'results/model_cnmc_res_450_no_red',
    #'model_cnmc_res_450_grayscale'  : 'results/model_cnmc_res_450_grayscale',
}


def get_values(dataroot, key):
    """Stack the array stored under ``key`` from every run below ``dataroot``.

    Args:
        dataroot: folder whose direct sub-folders each contain a ``results.npz``.
        key: name of the array to read from every archive.

    Returns:
        Array with one leading row per run, in sorted path order so that
        repeated calls line up row by row.

    Raises:
        ValueError: if no ``results.npz`` is found (``np.stack`` raised the
            same exception type before, with a less helpful message).
    """
    npz_paths = sorted(glob(join(dataroot, '*', 'results.npz')))
    if not npz_paths:
        raise ValueError(f'no results.npz found below {dataroot!r}')
    vals = []
    for path in npz_paths:
        # NpzFile keeps the archive open; close it as soon as the array is read.
        with np.load(path) as recorded_data:
            vals.append(recorded_data[key])
    return np.stack(vals, 0)
# ---- plot.py (continued) ----

def plot_mean_std(dataroot, key, ax, **kwargs):
    """Plot the across-run mean of ``key`` per epoch with a +/- one std band.

    ``kwargs`` are forwarded to ``ax.plot``. The first entry (index 0) is
    dropped from the plot.
    """
    vals = get_values(dataroot, key)
    mean = np.mean(vals, 0)
    std = np.std(vals, 0)
    epochs = np.arange(len(mean))

    # Offset by 1 so that we have nicely zoomed plots
    mean = mean[1:]
    std = std[1:]
    epochs = epochs[1:]

    ax.plot(epochs, mean, **kwargs)
    ax.fill_between(epochs, mean - std, mean + std, alpha=0.2)


def plot3(key, ax):
    """Draw ``key`` for every entry of ``dataroots`` on ``ax``, labelled by its name."""
    for k, v in dataroots.items():
        plot_mean_std(v, key, ax, label=k)


def print_final_min_mean_max(dataroot, key, model_epochs):
    """Print min, ``mean ± std`` and max (in percent) of ``key`` over all runs.

    ``model_epochs`` holds, per run, the epoch index whose value is reported.
    """
    vals = get_values(dataroot, key) * 100
    vals = vals[np.arange(len(vals)), model_epochs]
    # Local names no longer shadow the builtins `min` / `max`.
    lowest = np.min(vals)
    mean = np.mean(vals)
    std = np.std(vals)
    highest = np.max(vals)
    print(f'{lowest:.2f}', f'{mean:.2f} ± {std:.2f}', f'{highest:.2f}', sep='\t')


def print_final_table(dataroot):
    """Print one summary row per metric, each taken at every run's best-F1 epoch."""
    best_model_epochs = np.argmax(get_values(dataroot, 'f1'), axis=1)

    for key in ('acc', 'acc_all', 'acc_hem', 'f1', 'precision', 'recall'):
        print_final_min_mean_max(dataroot, key, best_model_epochs)


def get_best_f1_scores(dataroot):
    """Return the highest F1 score reached by each run below ``dataroot``."""
    f1_scores = get_values(dataroot, 'f1')
    best_model_epochs = np.argmax(f1_scores, axis=1)
    return f1_scores[np.arange(len(f1_scores)), best_model_epochs]


def is_statistically_greater(dataroot1, dataroot2):
    """One-sided Mann-Whitney U test on the best F1 scores of two experiments.

    Returns:
        ``(u, p)`` for the alternative "dataroot1 is greater than dataroot2".
    """
    # Tests if F1-score of dataroot1 is greater than dataroot2
    a = get_best_f1_scores(dataroot1)
    b = get_best_f1_scores(dataroot2)
    u, p = mannwhitneyu(a, b, alternative='greater')
    return u, p


def main():
    """Print the result tables and write the three PDF figures.

    This used to run at module level; it is now behind a ``__main__`` guard so
    the helpers above can be imported without side effects. Running
    ``python3 plot.py`` does the same work as before.
    """
    for k, v in dataroots.items():
        print(k)
        print_final_table(v)
        print()

    # Metric curves, one panel per recorded metric.
    fig, ax = plt.subplots(nrows=2, ncols=3, figsize=(9, 5))
    panels = [
        ('Accuracy', 'acc'),
        ('Sensitivity', 'acc_all'),
        ('Specificity', 'acc_hem'),
        ('F1 score', 'f1'),
        ('Precision', 'precision'),
        ('Recall', 'recall'),
    ]
    for axis, (title, key) in zip(ax.flat, panels):
        axis.set_title(title)
        plot3(key, axis)

    # Every panel carries the same labels; take the handles from one panel so
    # the shared legend does not list each experiment six times.
    handles, labels = ax[0, 0].get_legend_handles_labels()
    fig.legend(handles, labels, loc='lower center', ncol=3)
    fig.tight_layout()
    fig.subplots_adjust(bottom=0.12)
    fig.savefig('results/plot_ablations.pdf')

    # Per-subject accuracy of one selected run.
    npload = 'results/model_cnmc_res_128'
    npload_sub = npload + '/subj_acc.npz'
    npload_res = npload + '/results.npz'
    subj_acc = np.load(npload_sub)
    subj = list(sorted(subj_acc.keys()))
    acc = [subj_acc[k] for k in subj]
    fig, ax = plt.subplots(figsize=(9, 2))
    ax.bar(range(len(acc)), acc, width=0.3, tick_label=subj)
    fig.tight_layout()
    fig.savefig('results/plot_subj_acc.pdf')

    # Loss and F1 curves of the same run; the validation series skip their
    # first entry and are therefore plotted starting at x = 1.
    data = np.load(npload_res)
    loss_train = data['loss_train']
    loss_valid = data['loss_valid'][1:]
    f1_valid = data['f1'][1:]
    fig, ax = plt.subplots(ncols=3, figsize=(9, 2))
    ax[0].plot(range(len(loss_train)), loss_train)
    ax[0].set_title("Training set loss")
    ax[1].plot(range(1, len(loss_valid) + 1), loss_valid)
    ax[1].set_title("Preliminary test set loss")
    ax[2].plot(range(1, len(f1_valid) + 1), f1_valid)
    ax[2].set_title("Preliminary test set F1-score")
    fig.tight_layout()
    fig.savefig('results/plot_curves.pdf')

    plt.show()


if __name__ == '__main__':
    main()
\"/home/feoktistovar67431/data/isbi2019\" --batch-size 32 --epochs 100 --seed 30042022 --device cuda --out results" + ] + }, + { + "cell_type": "code", + "execution_count": 55, + "id": "b753e6b8", + "metadata": { + "scrolled": false + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Loading model\n", + "Classifying\n", + "59it [00:41, 1.43it/s] \n", + "Positive: 1234\n", + "Negative: 633\n", + "AUC: 0.8797024225483345\n" + ] + } + ], + "source": [ + "!python3 submission.py --modelroot \"/home/feoktistovar67431/isbi2019cancer-master/results/20220216T154306Z.AZHL\" --dataroot \"/home/feoktistovar67431/data/isbi2019/CNMC/phase2\" --batch-size 32" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "3246460b", + "metadata": {}, + "outputs": [], + "source": [ + "# TRAIN\n", + "# dataset : CNMC\n", + "# res : 32\n", + "# epochs : 100\n", + "!python3 main_manual.py --dataroot \"/home/feoktistovar67431/data/isbi2019/CNMC\" --batch-size 32 --epochs 100 --seed 30042022 --device cuda --out results --res 32" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "8a953a39", + "metadata": {}, + "outputs": [], + "source": [ + "# TRAIN\n", + "# dataset : CNMC\n", + "# res : 128\n", + "# epochs : 100\n", + "!python3 main_manual.py --dataroot \"/home/feoktistovar67431/data/isbi2019/CNMC\" --batch-size 32 --epochs 100 --seed 30042022 --device cuda --out results --res 128" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "12c15b33", + "metadata": {}, + "outputs": [], + "source": [ + "# TRAIN\n", + "# dataset : CNMC\n", + "# res : 224\n", + "# epochs : 100\n", + "!python3 main_manual.py --dataroot \"/home/feoktistovar67431/data/isbi2019/CNMC\" --batch-size 32 --epochs 100 --seed 30042022 --device cuda --out results --res 224" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "08ba15b4", + "metadata": {}, + "outputs": [], + "source": [ + "# TRAIN\n", + "# dataset : CNMC\n", 
+ "# res : 256\n", + "# epochs : 100\n", + "!python3 main_manual.py --dataroot \"/home/feoktistovar67431/data/isbi2019/CNMC\" --batch-size 32 --epochs 100 --seed 30042022 --device cuda --out results --res 256" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "3cf25ec3", + "metadata": {}, + "outputs": [], + "source": [ + "# TRAIN\n", + "# dataset : CNMC\n", + "# res : 450\n", + "# epochs : 100\n", + "!python3 main_manual.py --dataroot \"/home/feoktistovar67431/data/isbi2019/CNMC\" --batch-size 32 --epochs 100 --seed 30042022 --device cuda --out results --res 450" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "73b9d9d3", + "metadata": {}, + "outputs": [], + "source": [ + "# TRAIN\n", + "# dataset : CNMC_Grayscale\n", + "# res : 450\n", + "# epochs : 100\n", + "!python3 main_manual.py --dataroot \"/home/feoktistovar67431/data/isbi2019/CNMC_grayscale\" --batch-size 32 --epochs 100 --seed 30042022 --device cuda --out results --res 450" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "ce16353c", + "metadata": {}, + "outputs": [], + "source": [ + "# TRAIN\n", + "# dataset : CNMC_no_red\n", + "# res : 450\n", + "# epochs : 100\n", + "!python3 main_manual.py --dataroot \"/home/feoktistovar67431/data/isbi2019/CNMC_no_red\" --batch-size 32 --epochs 100 --seed 30042022 --device cuda --out results --res 450" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "959ab837", + "metadata": {}, + "outputs": [], + "source": [ + "# TRAIN\n", + "# dataset : CNMC_no_green\n", + "# res : 450\n", + "# epochs : 100\n", + "!python3 main_manual.py --dataroot \"/home/feoktistovar67431/data/isbi2019/CNMC_no_green\" --batch-size 32 --epochs 100 --seed 30042022 --device cuda --out results --res 450" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "879beb46", + "metadata": {}, + "outputs": [], + "source": [ + "# TRAIN\n", + "# dataset : CNMC_no_blue\n", + "# res : 450\n", + "# epochs : 
100\n", + "!python3 main_manual.py --dataroot \"/home/feoktistovar67431/data/isbi2019/CNMC_no_blue\" --batch-size 32 --epochs 100 --seed 30042022 --device cuda --out results --res 450" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "6d545dce", + "metadata": {}, + "outputs": [], + "source": [ + "# TRAIN\n", + "# dataset : CNMC_red_only\n", + "# res : 450\n", + "# epochs : 100\n", + "!python3 main_manual.py --dataroot \"/home/feoktistovar67431/data/isbi2019/CNMC_red_only\" --batch-size 32 --epochs 100 --seed 30042022 --device cuda --out results --res 450" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "25480226", + "metadata": {}, + "outputs": [], + "source": [ + "# TRAIN\n", + "# dataset : CNMC_green_only\n", + "# res : 450\n", + "# epochs : 100\n", + "!python3 main_manual.py --dataroot \"/home/feoktistovar67431/data/isbi2019/CNMC_green_only\" --batch-size 32 --epochs 100 --seed 30042022 --device cuda --out results --res 450" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "a064d169", + "metadata": {}, + "outputs": [], + "source": [ + "# TRAIN\n", + "# dataset : CNMC_blue_only\n", + "# res : 450\n", + "# epochs : 100\n", + "!python3 main_manual.py --dataroot \"/home/feoktistovar67431/data/isbi2019/CNMC_blue_only\" --batch-size 32 --epochs 100 --seed 30042022 --device cuda --out results --res 450" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "8d53828a", + "metadata": {}, + "outputs": [], + "source": [ + "# TRAIN\n", + "# dataset : CNMC\n", + "# res : 450\n", + "# epochs : 100\n", + "!python3 main_manual.py --dataroot \"/home/feoktistovar67431/data/isbi2019/CNMC\" --batch-size 32 --epochs 100 --seed 30042022 --device cuda --out results --res 450" + ] + }, + { + "cell_type": "code", + "execution_count": 183, + "id": "ea9c2f23", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "PROPOSAL\n", + "68.51\t83.57 ± 5.16\t89.61\n", + 
"84.33\t89.06 ± 2.09\t92.95\n", + "38.73\t73.26 ± 11.77\t84.72\n", + "66.76\t83.35 ± 5.61\t89.57\n", + "66.81\t83.36 ± 5.60\t89.55\n", + "68.51\t83.57 ± 5.16\t89.61\n", + "\n", + "Figure(900x500)\n", + "Figure(900x200)\n", + "Figure(900x200)\n" + ] + } + ], + "source": [ + "# PLOT\n", + "# dataset : CNMC\n", + "# res : 450\n", + "# epochs : 100\n", + "!python3 plot.py" + ] + }, + { + "cell_type": "markdown", + "id": "8c92073d", + "metadata": {}, + "source": [ + "# EVALUATION" + ] + }, + { + "cell_type": "code", + "execution_count": 135, + "id": "b25a4267", + "metadata": { + "scrolled": false + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Loading model\n", + "Classifying\n", + "59it [00:05, 11.69it/s] \n", + "Positive: 1425\n", + "Negative: 442\n", + "AUC: 0.6153299354864846\n" + ] + } + ], + "source": [ + "# EVALUATION \n", + "# dataset : CNMC\n", + "# res : 32\n", + "# epochs : 100\n", + "!python3 submission.py --modelroot \"/home/feoktistovar67431/isbi2019cancer-master/results/model_cnmc_res_32\" --dataroot \"/home/feoktistovar67431/data/isbi2019/CNMC/phase2\" --batch-size 32 --res 32" + ] + }, + { + "cell_type": "code", + "execution_count": 136, + "id": "b14e3e67", + "metadata": { + "scrolled": false + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Loading model\n", + "Classifying\n", + "59it [00:09, 6.24it/s] \n", + "Positive: 1315\n", + "Negative: 552\n", + "AUC: 0.7711131113339208\n" + ] + } + ], + "source": [ + "# EVALUATION\n", + "# dataset : CNMC\n", + "# res : 128\n", + "# epochs : 100\n", + "!python3 submission.py --modelroot \"/home/feoktistovar67431/isbi2019cancer-master/results/model_cnmc_res_128\" --dataroot \"/home/feoktistovar67431/data/isbi2019/CNMC/phase2\" --batch-size 32 --res 128" + ] + }, + { + "cell_type": "code", + "execution_count": 137, + "id": "dfb25744", + "metadata": { + "scrolled": false + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + 
"text": [ + "Loading model\n", + "Classifying\n", + "59it [00:14, 4.19it/s] \n", + "Positive: 1262\n", + "Negative: 605\n", + "AUC: 0.8143717274835677\n" + ] + } + ], + "source": [ + "# EVALUATION\n", + "# dataset : CNMC\n", + "# res : 224\n", + "# epochs : 100\n", + "!python3 submission.py --modelroot \"/home/feoktistovar67431/isbi2019cancer-master/results/model_cnmc_res_224\" --dataroot \"/home/feoktistovar67431/data/isbi2019/CNMC/phase2\" --batch-size 32 --res 224" + ] + }, + { + "cell_type": "code", + "execution_count": 138, + "id": "68600db4", + "metadata": { + "scrolled": false + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Loading model\n", + "Classifying\n", + "59it [00:41, 1.44it/s] \n", + "Positive: 1195\n", + "Negative: 672\n", + "AUC: 0.8400701597139936\n" + ] + } + ], + "source": [ + "# EVALUATION\n", + "# dataset : CNMC\n", + "# res : 256\n", + "# epochs : 100\n", + "!python3 submission.py --modelroot \"/home/feoktistovar67431/isbi2019cancer-master/results/model_cnmc_res_256\" --dataroot \"/home/feoktistovar67431/data/isbi2019/CNMC/phase2\" --batch-size 32 --res 256" + ] + }, + { + "cell_type": "code", + "execution_count": 139, + "id": "71a5547e", + "metadata": { + "scrolled": false + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Loading model\n", + "Classifying\n", + "59it [00:41, 1.42it/s] \n", + "Positive: 1241\n", + "Negative: 626\n", + "AUC: 0.8813918512441892\n" + ] + } + ], + "source": [ + "# EVALUATION\n", + "# dataset : CNMC\n", + "# res : 450\n", + "# epochs : 100\n", + "!python3 submission.py --modelroot \"/home/feoktistovar67431/isbi2019cancer-master/results/model_cnmc_res_450\" --dataroot \"/home/feoktistovar67431/data/isbi2019/CNMC/phase2\" --batch-size 32 --res 450" + ] + }, + { + "cell_type": "code", + "execution_count": 154, + "id": "58450362", + "metadata": { + "scrolled": false + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": 
[ + "Loading model\n", + "Classifying\n", + "59it [00:41, 1.42it/s] \n", + "Positive: 1261\n", + "Negative: 606\n", + "AUC: 0.8045073375262055\n" + ] + } + ], + "source": [ + "# EVALUATION\n", + "# dataset : CNMC_Grayscale\n", + "# res : 450\n", + "# epochs : 100\n", + "!python3 submission.py --modelroot \"/home/feoktistovar67431/isbi2019cancer-master/results/model_cnmc_res_450_grayscale\" --dataroot \"/home/feoktistovar67431/data/isbi2019/CNMC_grayscale/phase2\" --batch-size 32 --res 450" + ] + }, + { + "cell_type": "code", + "execution_count": 155, + "id": "48c40f18", + "metadata": { + "scrolled": false + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Loading model\n", + "Classifying\n", + "59it [00:44, 1.33it/s] \n", + "Positive: 1178\n", + "Negative: 689\n", + "AUC: 0.8661869929814967\n" + ] + } + ], + "source": [ + "# EVALUATION\n", + "# dataset : CNMC_no_red\n", + "# res : 450\n", + "# epochs : 100\n", + "!python3 submission.py --modelroot \"/home/feoktistovar67431/isbi2019cancer-master/results/model_cnmc_res_450_no_red\" --dataroot \"/home/feoktistovar67431/data/isbi2019/CNMC_no_red/phase2\" --batch-size 32 --res 450" + ] + }, + { + "cell_type": "code", + "execution_count": 156, + "id": "b6ad9232", + "metadata": { + "scrolled": false + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Loading model\n", + "Classifying\n", + "59it [00:52, 1.12it/s] \n", + "Positive: 1266\n", + "Negative: 601\n", + "AUC: 0.8018310900454735\n" + ] + } + ], + "source": [ + "# EVALUATION\n", + "# dataset : CNMC_no_green\n", + "# res : 450\n", + "# epochs : 100\n", + "!python3 submission.py --modelroot \"/home/feoktistovar67431/isbi2019cancer-master/results/model_cnmc_res_450_no_green\" --dataroot \"/home/feoktistovar67431/data/isbi2019/CNMC_no_green/phase2\" --batch-size 32 --res 450" + ] + }, + { + "cell_type": "code", + "execution_count": 157, + "id": "1ba76d51", + "metadata": { + "scrolled": false + }, + 
"outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Loading model\n", + "Classifying\n", + "59it [00:48, 1.23it/s] \n", + "Positive: 1248\n", + "Negative: 619\n", + "AUC: 0.8570821813062721\n" + ] + } + ], + "source": [ + "# EVALUATION\n", + "# dataset : CNMC_no_blue\n", + "# res : 450\n", + "# epochs : 100\n", + "!python3 submission.py --modelroot \"/home/feoktistovar67431/isbi2019cancer-master/results/model_cnmc_res_450_no_blue\" --dataroot \"/home/feoktistovar67431/data/isbi2019/CNMC_no_blue/phase2\" --batch-size 32 --res 450" + ] + }, + { + "cell_type": "code", + "execution_count": 158, + "id": "05cfaf9c", + "metadata": { + "scrolled": false + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Loading model\n", + "Classifying\n", + "59it [00:52, 1.12it/s] \n", + "Positive: 1239\n", + "Negative: 628\n", + "AUC: 0.8013924335875389\n" + ] + } + ], + "source": [ + "# EVALUATION\n", + "# dataset : CNMC_red_only\n", + "# res : 450\n", + "# epochs : 100\n", + "!python3 submission.py --modelroot \"/home/feoktistovar67431/isbi2019cancer-master/results/model_cnmc_res_450_red_only\" --dataroot \"/home/feoktistovar67431/data/isbi2019/CNMC_red_only/phase2\" --batch-size 32 --res 450" + ] + }, + { + "cell_type": "code", + "execution_count": 159, + "id": "1ad09456", + "metadata": { + "scrolled": false + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Loading model\n", + "Classifying\n", + "59it [00:52, 1.13it/s] \n", + "Positive: 1221\n", + "Negative: 646\n", + "AUC: 0.8590070792695896\n" + ] + } + ], + "source": [ + "# EVALUATION\n", + "# dataset : CNMC_green_only\n", + "# res : 450\n", + "# epochs : 100\n", + "!python3 submission.py --modelroot \"/home/feoktistovar67431/isbi2019cancer-master/results/model_cnmc_res_450_green_only\" --dataroot \"/home/feoktistovar67431/data/isbi2019/CNMC_green_only/phase2\" --batch-size 32 --res 450" + ] + }, + { + "cell_type": "code", + 
"execution_count": 160, + "id": "41e8d3a0", + "metadata": { + "scrolled": false + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Loading model\n", + "Classifying\n", + "59it [00:52, 1.12it/s] \n", + "Positive: 1255\n", + "Negative: 612\n", + "AUC: 0.8268636253152251\n" + ] + } + ], + "source": [ + "# EVALUATION\n", + "# dataset : CNMC_blue_only\n", + "# res : 450\n", + "# epochs : 100\n", + "!python3 submission.py --modelroot \"/home/feoktistovar67431/isbi2019cancer-master/results/model_cnmc_res_450_blue_only\" --dataroot \"/home/feoktistovar67431/data/isbi2019/CNMC_blue_only/phase2\" --batch-size 32 --res 450" + ] + }, + { + "cell_type": "code", + "execution_count": 186, + "id": "88bc18db", + "metadata": { + "scrolled": false + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Loading model\n", + "Classifying\n", + "59it [01:24, 1.43s/it] \n", + "Positive: 1235\n", + "Negative: 632\n", + "AUC: 0.8588406050294211\n" + ] + } + ], + "source": [ + "# EVALUATION\n", + "# dataset : CNMC-blackborder\n", + "# res : 450\n", + "# epochs : 100\n", + "!python3 submission.py --modelroot \"/home/feoktistovar67431/isbi2019cancer-master/results/model_cnmc_res_450_w_blackborder\" --dataroot \"/home/feoktistovar67431/data/isbi2019/CNMC/phase2\" --batch-size 32 --res 450" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "ec31125a", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.6.9" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/submission.py b/submission.py new file mode 100644 index 0000000..d9cb520 --- 
# ---- submission.py ----

class OrderedImages(Dataset):
    """Dataset over the files ``1.bmp`` .. ``<length>.bmp`` inside ``root``.

    Args:
        root: folder containing the numbered bitmap files.
        transform: callable applied to every opened image.
        length: number of images; defaults to 1867, the value that used to be
            hard-coded in ``__len__``.
    """

    def __init__(self, root, transform, length=1867):
        super().__init__()
        self.root = root
        self.transform = transform
        self.length = length

    def __len__(self):
        return self.length

    def __getitem__(self, index):
        # Files are numbered from 1, dataset indices from 0.
        img = Image.open(os.path.join(self.root, f'{index + 1}.bmp'))#.convert('RGB')
        return self.transform(img)


VALIDATION_ALL = 1219
VALIDATION_HEM = 648


def main():
    """Classify every image, print the label counts and AUC, write the submission.

    This used to run at module level; wrapping it keeps ``python3 submission.py``
    unchanged while making the module importable without side effects.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--batch-size', type=int, default=64)
    parser.add_argument('--modelroot', default='results/20190313T101236Z.LGJL', help='path to model')
    parser.add_argument('--dataroot', default='data/phase3', help='path to dataset')
    parser.add_argument('--res', type=int, default=450, help='Desired input resolution')
    args = parser.parse_args()

    dataset = OrderedImages(args.dataroot, get_tf_vaild_rot_transform(args.res))

    print("Loading model")
    model = get_model().to('cuda:0')
    # The checkpoint is loaded into the DataParallel wrapper, so the model is
    # wrapped before `load_state_dict`.
    model = torch.nn.DataParallel(model)
    model.load_state_dict(torch.load(join(args.modelroot, 'model.pt')))
    model.eval()

    dataloader = DataLoader(dataset, batch_size=args.batch_size, num_workers=6)

    print("Classifying")
    all_labels = []
    # len(dataloader) counts the final partial batch as well; the previous
    # floor division under-reported the total by one.
    for x in tqdm(dataloader, total=len(dataloader)):
        x = x.to('cuda:0')
        bs, nrot, c, h, w = x.size()
        with torch.no_grad():
            # Average the logits over all views of an image.
            y = model(x.view(-1, c, h, w))
            y = y.view(bs, nrot).mean(1)
        labels = y > 0
        all_labels.append(labels)

    all_labels = torch.cat(all_labels)
    print("Positive:", all_labels.sum().item())
    print("Negative:", len(all_labels) - all_labels.sum().item())

    # NOTE(review): the label file location is hard-coded to one machine.
    # The file is now closed deterministically instead of leaking the handle.
    with open(r'/home/feoktistovar67431/data/resources/phase2_labels.csv', "r", newline='') as label_file:
        true_labels = list(csv.reader(label_file, delimiter=','))

    # NOTE(review): the AUC is computed from thresholded labels, not from the
    # continuous scores `y`; confirm that this is intended.
    print(f'AUC: {roc_auc_score(true_labels, all_labels.cpu())}')  # report the area under the ROC curve

    csv_path = join(args.modelroot, 'submission.csv')
    zip_path = join(args.modelroot, 'submission.zip')
    np.savetxt(csv_path, all_labels.cpu().numpy(), '%d')
    with zipfile.ZipFile(zip_path, 'w', zipfile.ZIP_DEFLATED) as zipf:
        zipf.write(csv_path, 'isbi_valid.predict')


if __name__ == '__main__':
    main()


# ---- utils.py ----

class IncrementalAverage:
    """Running arithmetic mean that is updated one value at a time."""

    def __init__(self):
        self.value = 0
        self.counter = 0

    def update(self, x):
        """Fold ``x`` into the mean without storing previous values."""
        self.counter += 1
        self.value += (x - self.value) / self.counter


class Flatten(nn.Module):
    """Collapse every dimension after the first (batch) dimension into one."""

    def forward(self, x):
        # reshape gives the same result as view for contiguous input and also
        # accepts non-contiguous tensors, for which view raises.
        return x.reshape(x.size(0), -1)


class SizePrinter(nn.Module):
    """Identity layer that prints the size of whatever passes through it."""

    def forward(self, x):
        print(x.size())
        return x


def count_parameters(model, grad_only=True):
    """Count the scalar parameters of ``model``.

    With ``grad_only`` (default) only parameters with ``requires_grad`` count.
    """
    return sum(p.numel() for p in model.parameters() if not grad_only or p.requires_grad)


def to_device(device, *tensors):
    """Move every tensor to ``device`` and return them as a tuple, in order."""
    return tuple(x.to(device) for x in tensors)


def loop_iter(iter):
    """Yield the items of ``iter`` forever, restarting it whenever it is exhausted."""
    while True:
        for item in iter:
            yield item


def unique_string():
    """Return ``<timestamp>Z.<four random uppercase letters>`` for naming runs.

    The timestamp is taken in UTC so that it matches its ``Z`` suffix; it was
    previously local time carrying the same suffix.
    """
    from datetime import timezone  # local import: the file only imports `datetime`

    return '{}.{}'.format(datetime.now(timezone.utc).strftime('%Y%m%dT%H%M%SZ'),
                          ''.join(random.choice(string.ascii_uppercase) for _ in range(4)))


def set_seeds(seed):
    """Seed Python's ``random`` and torch (CPU and all CUDA devices)."""
    random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)


def pickle_dump(obj, file):
    """Pickle ``obj`` into the file at path ``file`` (opened in binary write mode)."""
    with open(file, 'wb') as f:
        pickle.dump(obj, f)