CSI method as used in the master's thesis "Anomalie-Detektion in Zellbildern zur Anwendung der Leukämieerkennung" (anomaly detection in cell images for leukemia detection).

train.py (5.4 KB)

from copy import deepcopy

import torch
import torch.nn as nn
import torch.optim as optim
import torch.optim.lr_scheduler as lr_scheduler
from torch.utils.data import DataLoader

from common.common import parse_args
import models.classifier as C
from datasets import get_dataset, get_superclass_list, get_subclass_dataset
from utils.utils import load_checkpoint

P = parse_args()

### Set torch device ###

if torch.cuda.is_available():
    torch.cuda.set_device(P.local_rank)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

P.n_gpus = torch.cuda.device_count()

if P.n_gpus > 1:
    import apex
    import torch.distributed as dist
    from torch.utils.data.distributed import DistributedSampler

    P.multi_gpu = True
    torch.distributed.init_process_group(
        'nccl',
        init_method='env://',
        world_size=P.n_gpus,
        rank=P.local_rank,
    )
else:
    P.multi_gpu = False

### only use one ood_layer while training
P.ood_layer = P.ood_layer[0]

### Initialize dataset ###
train_set, test_set, image_size, n_classes = get_dataset(P, dataset=P.dataset)
P.image_size = image_size
P.n_classes = n_classes

if P.one_class_idx is not None:
    # one-class setting: keep only the chosen (super)class for training and testing
    cls_list = get_superclass_list(P.dataset)
    P.n_superclasses = len(cls_list)

    full_test_set = deepcopy(test_set)  # test set of full classes
    train_set = get_subclass_dataset(train_set, classes=cls_list[P.one_class_idx])
    test_set = get_subclass_dataset(test_set, classes=cls_list[P.one_class_idx])

kwargs = {'pin_memory': False, 'num_workers': 2}

if P.multi_gpu:
    train_sampler = DistributedSampler(train_set, num_replicas=P.n_gpus, rank=P.local_rank)
    test_sampler = DistributedSampler(test_set, num_replicas=P.n_gpus, rank=P.local_rank)
    train_loader = DataLoader(train_set, sampler=train_sampler, batch_size=P.batch_size, **kwargs)
    test_loader = DataLoader(test_set, sampler=test_sampler, batch_size=P.test_batch_size, **kwargs)
else:
    train_loader = DataLoader(train_set, shuffle=True, batch_size=P.batch_size, **kwargs)
    test_loader = DataLoader(test_set, shuffle=False, batch_size=P.test_batch_size, **kwargs)

# Default OOD evaluation sets: the held-out classes (one-class setting) or external datasets
if P.ood_dataset is None:
    if P.one_class_idx is not None:
        P.ood_dataset = list(range(P.n_superclasses))
        P.ood_dataset.pop(P.one_class_idx)
    elif P.dataset == 'cifar10':
        P.ood_dataset = ['svhn', 'lsun_resize', 'imagenet_resize', 'lsun_fix', 'imagenet_fix', 'cifar100', 'interp']
    elif P.dataset == 'imagenet':
        P.ood_dataset = ['cub', 'stanford_dogs', 'flowers102']

# Build one evaluation loader per OOD dataset
ood_test_loader = dict()
for ood in P.ood_dataset:
    if ood == 'interp':
        ood_test_loader[ood] = None  # dummy loader
        continue

    if P.one_class_idx is not None:
        ood_test_set = get_subclass_dataset(full_test_set, classes=cls_list[ood])
        ood = f'one_class_{ood}'  # change save name
    else:
        ood_test_set = get_dataset(P, dataset=ood, test_only=True, image_size=P.image_size)

    if P.multi_gpu:
        ood_sampler = DistributedSampler(ood_test_set, num_replicas=P.n_gpus, rank=P.local_rank)
        ood_test_loader[ood] = DataLoader(ood_test_set, sampler=ood_sampler, batch_size=P.test_batch_size, **kwargs)
    else:
        ood_test_loader[ood] = DataLoader(ood_test_set, shuffle=False, batch_size=P.test_batch_size, **kwargs)

### Initialize model ###

simclr_aug = C.get_simclr_augmentation(P, image_size=P.image_size).to(device)
# Shifting transformation (e.g. rotations) and number of shifted instances K used by CSI
P.shift_trans, P.K_shift = C.get_shift_module(P, eval=True)
P.shift_trans = P.shift_trans.to(device)

model = C.get_classifier(P.model, n_classes=P.n_classes).to(device)
model = C.get_shift_classifer(model, P.K_shift).to(device)

criterion = nn.CrossEntropyLoss().to(device)

if P.optimizer == 'sgd':
    optimizer = optim.SGD(model.parameters(), lr=P.lr_init, momentum=0.9, weight_decay=P.weight_decay)
    lr_decay_gamma = 0.1
elif P.optimizer == 'lars':
    from torchlars import LARS
    base_optimizer = optim.SGD(model.parameters(), lr=P.lr_init, momentum=0.9, weight_decay=P.weight_decay)
    optimizer = LARS(base_optimizer, eps=1e-8, trust_coef=0.001)
    lr_decay_gamma = 0.1
else:
    raise NotImplementedError()

if P.lr_scheduler == 'cosine':
    scheduler = lr_scheduler.CosineAnnealingLR(optimizer, P.epochs)
elif P.lr_scheduler == 'step_decay':
    milestones = [int(0.5 * P.epochs), int(0.75 * P.epochs)]
    scheduler = lr_scheduler.MultiStepLR(optimizer, gamma=lr_decay_gamma, milestones=milestones)
else:
    raise NotImplementedError()

# Warm up the learning rate for P.warmup epochs, then follow the scheduler chosen above
from training.scheduler import GradualWarmupScheduler
scheduler_warmup = GradualWarmupScheduler(optimizer, multiplier=10.0, total_epoch=P.warmup, after_scheduler=scheduler)

if P.resume_path is not None:
    # Resume training from the last checkpoint
    resume = True
    model_state, optim_state, config = load_checkpoint(P.resume_path, mode='last')
    model.load_state_dict(model_state, strict=not P.no_strict)
    optimizer.load_state_dict(optim_state)
    start_epoch = config['epoch']
    best = config['best']
    error = 100.0
else:
    resume = False
    start_epoch = 1
    best = 100.0
    error = 100.0

# Linear-evaluation modes require a pretrained checkpoint to load
if P.mode == 'sup_linear' or P.mode == 'sup_CSI_linear':
    assert P.load_path is not None
    checkpoint = torch.load(P.load_path)
    model.load_state_dict(checkpoint, strict=not P.no_strict)

# Wrap the model and augmentation module for multi-GPU training with apex
if P.multi_gpu:
    simclr_aug = apex.parallel.DistributedDataParallel(simclr_aug, delay_allreduce=True)
    model = apex.parallel.convert_syncbn_model(model)
    model = apex.parallel.DistributedDataParallel(model, delay_allreduce=True)
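
The GradualWarmupScheduler imported from training/scheduler.py is not shown in this file. As a rough illustration of the schedule shape it sets up (a minimal sketch only, assuming multiplier=10.0 means the learning rate ramps from P.lr_init up to 10 × P.lr_init over P.warmup epochs and then follows the cosine schedule; the repository's implementation may differ in detail, and lr_init, warmup, epochs, and the placeholder model below are hypothetical stand-ins for values parse_args() would provide), the same behaviour can be expressed with standard PyTorch schedulers:

    import torch.nn as nn
    import torch.optim as optim
    import torch.optim.lr_scheduler as lr_scheduler

    lr_init, warmup, epochs = 0.1, 10, 1000   # hypothetical values
    model = nn.Linear(8, 2)                   # placeholder model

    # Peak learning rate = 10 * lr_init, reached at the end of the warmup phase
    optimizer = optim.SGD(model.parameters(), lr=10.0 * lr_init, momentum=0.9)

    # Ramp linearly from lr_init to 10 * lr_init, then decay with cosine annealing
    warmup_part = lr_scheduler.LinearLR(optimizer, start_factor=0.1, end_factor=1.0, total_iters=warmup)
    cosine_part = lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs - warmup)
    schedule = lr_scheduler.SequentialLR(optimizer, [warmup_part, cosine_part], milestones=[warmup])

    for epoch in range(epochs):
        # ... one training epoch would go here ...
        schedule.step()

Note also that the multi-GPU branch initializes torch.distributed with init_method='env://' and reads P.local_rank, so it is presumably meant to be launched with one process per GPU (e.g. via torch.distributed.launch or torchrun), while single-GPU runs can invoke the script directly.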