In der Masterarbeit "Anomalie-Detektion in Zellbildern zur Anwendung der Leukämieerkennung" verwendete CSI-Methode.
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long. 2.9KB

  # Evaluation setup for the CSI method: parse CLI options into ``P``, pick the
# torch device, build the train/test/OOD data loaders and the shift classifier,
# and optionally restore a checkpoint.
from copy import deepcopy

import torch
import torch.nn as nn
from torch.utils.data import DataLoader

from common.common import parse_args
import models.classifier as C
from datasets import get_dataset, get_superclass_list, get_subclass_dataset

P = parse_args()

### Set torch device ###

P.n_gpus = torch.cuda.device_count()
# Only single-device evaluation is supported. An explicit check (instead of
# ``assert``) keeps this guard active when Python runs with ``-O``.
if P.n_gpus > 1:
    raise RuntimeError(
        f"multi-GPU evaluation is not supported (found {P.n_gpus} CUDA devices)"
    )
P.multi_gpu = False

if torch.cuda.is_available():
    # P.local_rank comes from parse_args(); it selects the CUDA device used below.
    torch.cuda.set_device(P.local_rank)
# Fall back to CPU when CUDA is unavailable.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
  ### Initialize dataset ###
ood_eval = P.mode == 'ood_pre'
# In 'ood_pre' mode these datasets are processed one sample per batch.
# (Same truth table as the former `A and ood_eval or B and ood_eval or ...` chain.)
if ood_eval and P.dataset in ('imagenet', 'CNMC', 'CNMC_grayscale'):
    P.batch_size = 1
    P.test_batch_size = 1
train_set, test_set, image_size, n_classes = get_dataset(P, dataset=P.dataset, eval=ood_eval)

P.image_size = image_size
P.n_classes = n_classes

# One-class setting: restrict training and in-distribution testing to the
# superclass selected by P.one_class_idx. An untouched copy of the full test
# set is kept so the other superclasses can later be used as OOD data.
if P.one_class_idx is not None:
    cls_list = get_superclass_list(P.dataset)
    P.n_superclasses = len(cls_list)

    full_test_set = deepcopy(test_set)  # test set of full classes
    train_set = get_subclass_dataset(train_set, classes=cls_list[P.one_class_idx])
    test_set = get_subclass_dataset(test_set, classes=cls_list[P.one_class_idx])
  # DataLoader options reused by every loader in this script.
kwargs = dict(pin_memory=False, num_workers=2)

# Training data is reshuffled each epoch; evaluation data keeps a fixed order.
train_loader = DataLoader(train_set, batch_size=P.batch_size, shuffle=True, **kwargs)
test_loader = DataLoader(test_set, batch_size=P.test_batch_size, shuffle=False, **kwargs)
  # Choose default OOD targets when none were supplied.
if P.ood_dataset is None:
    if P.one_class_idx is not None:
        # One-class setting: every superclass index except the in-distribution
        # one is an OOD target.
        P.ood_dataset = list(range(P.n_superclasses))
        P.ood_dataset.pop(P.one_class_idx)
    elif P.dataset == 'cifar10':
        P.ood_dataset = ['svhn', 'lsun_resize', 'imagenet_resize', 'lsun_fix', 'imagenet_fix', 'cifar100', 'interp']
    elif P.dataset == 'imagenet':
        P.ood_dataset = ['cub', 'stanford_dogs', 'flowers102', 'places365', 'food_101', 'caltech_256', 'dtd', 'pets']
    else:
        # Without a default the OOD loop would iterate over None and fail with
        # an unhelpful TypeError; report the actual problem instead.
        raise ValueError(
            f"no default OOD datasets for dataset '{P.dataset}'; "
            "specify the OOD datasets explicitly or use the one-class setting"
        )
  # Build one evaluation loader per OOD target, keyed by its save name.
ood_test_loader = {}
for ood in P.ood_dataset:
    # 'interp' gets a placeholder entry instead of a real loader.
    if ood == 'interp':
        ood_test_loader[ood] = None  # dummy loader
        continue

    if P.one_class_idx is None:
        ood_test_set = get_dataset(P, dataset=ood, test_only=True, image_size=P.image_size, eval=ood_eval)
    else:
        # ``ood`` is a superclass index here; take that superclass from the full test set.
        ood_test_set = get_subclass_dataset(full_test_set, classes=cls_list[ood])
        ood = f'one_class_{ood}'  # change save name

    ood_test_loader[ood] = DataLoader(ood_test_set, batch_size=P.test_batch_size, shuffle=False, **kwargs)
  ### Initialize model ###
simclr_aug = C.get_simclr_augmentation(P, image_size=P.image_size).to(device)
P.shift_trans, P.K_shift = C.get_shift_module(P, eval=True)
# Move the shift transform to the evaluation device, like every other module
# here (assumes get_shift_module returns an nn.Module-like object — TODO confirm).
P.shift_trans = P.shift_trans.to(device)

model = C.get_classifier(P.model, n_classes=P.n_classes).to(device)
# Wrap the classifier using P.K_shift (presumably adds a shift-prediction head — verify).
model = C.get_shift_classifer(model, P.K_shift).to(device)
criterion = nn.CrossEntropyLoss().to(device)
  # Optionally restore model weights from P.load_path.
if P.load_path is not None:
    # map_location remaps stored tensors onto ``device`` so a checkpoint saved
    # on a GPU can also be loaded when ``device`` is the CPU.
    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
    checkpoint = torch.load(P.load_path, map_location=device)
    model.load_state_dict(checkpoint, strict=not P.no_strict)