123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566 |
- import os
- import time
- import random
-
- import cv2
- import numpy as np
- import torch
-
- import torch.nn.functional as F
- from torchvision import datasets, transforms
- from torch.utils.data import DataLoader
- from torchvision.utils import save_image
-
- from datasets import get_subclass_dataset
-
def set_random_seed(seed):
    """Seed every random number generator this script relies on.

    Seeds Python's ``random`` module, NumPy's global generator and PyTorch's
    CPU and CUDA generators so that repeated runs draw the same samples.

    Args:
        seed: Integer seed applied to all generators.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Seed all CUDA devices, not only the current one; this call is
    # silently ignored when CUDA is unavailable.
    torch.cuda.manual_seed_all(seed)
-
# Root directory of the ImageNet data. The leading '~' is expanded here
# because neither os.path.join nor the dataset loader expands it, which would
# otherwise make the path refer to a literal '~' directory.
IMAGENET_PATH = os.path.expanduser('~/data/ImageNet')
-
-
# Start of the preprocessing timer; read back when the elapsed time is printed.
check = time.time()

# Resize to 256, take a 256x256 center crop, then resize down to 32 and
# convert the image to a tensor.
transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(256),
    transforms.Resize(32),
    transforms.ToTensor(),
])
-
# Drop airliner(1), ambulance(2), parking_meter(18) and schooner(22) from the
# 30 class indices, since similar classes exist in CIFAR-10.
remove_idx_list = [1, 2, 18, 22]
class_idx_list = [idx for idx in range(30) if idx not in remove_idx_list]
-
# Fix the seed before building the shuffled loader so the sampled images are
# the same on every run.
set_random_seed(0)
train_dir = os.path.join(IMAGENET_PATH, 'one_class_train')
Imagenet_set = datasets.ImageFolder(train_dir, transform=transform)
# Keep only the classes listed in class_idx_list.
Imagenet_set = get_subclass_dataset(Imagenet_set, class_idx_list)
Imagenet_dataloader = DataLoader(
    Imagenet_set,
    batch_size=100,
    shuffle=True,
    pin_memory=False,
)
-
# Collect shuffled batches until at least 10000 images have been gathered.
# The batches are kept in a list and concatenated once at the end, instead of
# re-copying the whole accumulated tensor with torch.cat on every iteration.
collected_batches = []
num_collected = 0
for test_image, _ in Imagenet_dataloader:
    collected_batches.append(test_image)
    num_collected += test_image.size(0)
    if num_collected >= 10000:
        break

# Remains None when the loader yields no batches at all.
total_test_image = (
    torch.cat(collected_batches, dim=0).cpu() if collected_batches else None
)
-
# Report the time elapsed since the preprocessing timer was started.
print(f'Preprocessing time {time.time()-check}')

# Create the output directory; exist_ok avoids the check-then-create race and
# makes reruns a no-op.
os.makedirs('./Imagenet_fix', exist_ok=True)
-
# Restart the timer to measure the saving phase on its own.
check = time.time()
# Write at most 10000 images. Slicing (rather than a fixed range(10000))
# avoids an IndexError when fewer images were collected.
for i, image in enumerate(total_test_image[:10000]):
    save_image(image, f'Imagenet_fix/correct_resize_{i}.png')
print(f'Saving time {time.time()-check}')
|