In der Masterarbeit "Anomalie-Detektion in Zellbildern zur Anwendung der Leukämieerkennung" verwendete CSI-Methode.
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

imagenet_fix_preprocess.py 1.7KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566
  1. import os
  2. import time
  3. import random
  4. import cv2
  5. import numpy as np
  6. import torch
  7. import torch.nn.functional as F
  8. from torchvision import datasets, transforms
  9. from torch.utils.data import DataLoader
  10. from torchvision.utils import save_image
  11. from datasets import get_subclass_dataset
  12. def set_random_seed(seed):
  13. random.seed(seed)
  14. np.random.seed(seed)
  15. torch.manual_seed(seed)
  16. torch.cuda.manual_seed(seed)
  17. IMAGENET_PATH = '~/data/ImageNet'
  18. check = time.time()
  19. transform = transforms.Compose([
  20. transforms.Resize(256),
  21. transforms.CenterCrop(256),
  22. transforms.Resize(32),
  23. transforms.ToTensor(),
  24. ])
  25. # remove airliner(1), ambulance(2), parking_meter(18), schooner(22) since similar class exist in CIFAR-10
  26. class_idx_list = list(range(30))
  27. remove_idx_list = [1, 2, 18, 22]
  28. for remove_idx in remove_idx_list:
  29. class_idx_list.remove(remove_idx)
  30. set_random_seed(0)
  31. train_dir = os.path.join(IMAGENET_PATH, 'one_class_train')
  32. Imagenet_set = datasets.ImageFolder(train_dir, transform=transform)
  33. Imagenet_set = get_subclass_dataset(Imagenet_set, class_idx_list)
  34. Imagenet_dataloader = DataLoader(Imagenet_set, batch_size=100, shuffle=True, pin_memory=False)
  35. total_test_image = None
  36. for n, (test_image, target) in enumerate(Imagenet_dataloader):
  37. if n == 0:
  38. total_test_image = test_image
  39. else:
  40. total_test_image = torch.cat((total_test_image, test_image), dim=0).cpu()
  41. if total_test_image.size(0) >= 10000:
  42. break
  43. print (f'Preprocessing time {time.time()-check}')
  44. if not os.path.exists('./Imagenet_fix'):
  45. os.mkdir('./Imagenet_fix')
  46. check = time.time()
  47. for i in range(10000):
  48. save_image(total_test_image[i], f'Imagenet_fix/correct_resize_{i}.png')
  49. print (f'Saving time {time.time()-check}')