In der Masterarbeit "Anomalie-Detektion in Zellbildern zur Anwendung der Leukämieerkennung" verwendete CSI-Methode.
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

eval.py 2.9KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081
  1. from copy import deepcopy
  2. import torch
  3. import torch.nn as nn
  4. from torch.utils.data import DataLoader
  5. from common.common import parse_args
  6. import models.classifier as C
  7. from datasets import get_dataset, get_superclass_list, get_subclass_dataset
  8. P = parse_args()
  9. ### Set torch device ###
  10. P.n_gpus = torch.cuda.device_count()
  11. assert P.n_gpus <= 1 # no multi GPU
  12. P.multi_gpu = False
  13. if torch.cuda.is_available():
  14. torch.cuda.set_device(P.local_rank)
  15. device = torch.device(f"cuda" if torch.cuda.is_available() else "cpu")
  16. ### Initialize dataset ###
  17. ood_eval = P.mode == 'ood_pre'
  18. if P.dataset == 'imagenet' and ood_eval or P.dataset == 'CNMC' and ood_eval or P.dataset == 'CNMC_grayscale' and ood_eval:
  19. P.batch_size = 1
  20. P.test_batch_size = 1
  21. train_set, test_set, image_size, n_classes = get_dataset(P, dataset=P.dataset, eval=ood_eval)
  22. P.image_size = image_size
  23. P.n_classes = n_classes
  24. if P.one_class_idx is not None:
  25. cls_list = get_superclass_list(P.dataset)
  26. P.n_superclasses = len(cls_list)
  27. full_test_set = deepcopy(test_set) # test set of full classes
  28. train_set = get_subclass_dataset(train_set, classes=cls_list[P.one_class_idx])
  29. test_set = get_subclass_dataset(test_set, classes=cls_list[P.one_class_idx])
  30. kwargs = {'pin_memory': False, 'num_workers': 2}
  31. train_loader = DataLoader(train_set, shuffle=True, batch_size=P.batch_size, **kwargs)
  32. test_loader = DataLoader(test_set, shuffle=False, batch_size=P.test_batch_size, **kwargs)
  33. if P.ood_dataset is None:
  34. if P.one_class_idx is not None:
  35. P.ood_dataset = list(range(P.n_superclasses))
  36. P.ood_dataset.pop(P.one_class_idx)
  37. elif P.dataset == 'cifar10':
  38. P.ood_dataset = ['svhn', 'lsun_resize', 'imagenet_resize', 'lsun_fix', 'imagenet_fix', 'cifar100', 'interp']
  39. elif P.dataset == 'imagenet':
  40. P.ood_dataset = ['cub', 'stanford_dogs', 'flowers102', 'places365', 'food_101', 'caltech_256', 'dtd', 'pets']
  41. ood_test_loader = dict()
  42. for ood in P.ood_dataset:
  43. if ood == 'interp':
  44. ood_test_loader[ood] = None # dummy loader
  45. continue
  46. if P.one_class_idx is not None:
  47. ood_test_set = get_subclass_dataset(full_test_set, classes=cls_list[ood])
  48. ood = f'one_class_{ood}' # change save name
  49. else:
  50. ood_test_set = get_dataset(P, dataset=ood, test_only=True, image_size=P.image_size, eval=ood_eval)
  51. ood_test_loader[ood] = DataLoader(ood_test_set, shuffle=False, batch_size=P.test_batch_size, **kwargs)
  52. ### Initialize model ###
  53. simclr_aug = C.get_simclr_augmentation(P, image_size=P.image_size).to(device)
  54. P.shift_trans, P.K_shift = C.get_shift_module(P, eval=True)
  55. P.shift_trans = P.shift_trans.to(device)
  56. model = C.get_classifier(P.model, n_classes=P.n_classes).to(device)
  57. model = C.get_shift_classifer(model, P.K_shift).to(device)
  58. criterion = nn.CrossEntropyLoss().to(device)
  59. if P.load_path is not None:
  60. checkpoint = torch.load(P.load_path)
  61. model.load_state_dict(checkpoint, strict=not P.no_strict)