12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061 |
- import os
- import time
- import random
-
- import numpy as np
- import torch
-
- from torchvision import datasets, transforms
- from torch.utils.data import DataLoader
- from torchvision.utils import save_image
-
def set_random_seed(seed):
    """Seed every RNG this script relies on with the same *seed*.

    Covers Python's ``random`` module, NumPy's global RNG, PyTorch's default
    generator (``torch.manual_seed``) and the current CUDA device
    (``torch.cuda.manual_seed``, which PyTorch ignores when CUDA is absent).
    Returns ``None``.
    """
    seeders = (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed)
    for apply_seed in seeders:
        apply_seed(seed)
-
# Start the wall-clock timer for the loading/preprocessing phase.
check = time.time()

# Image pipeline: Resize(256) -> CenterCrop(256) -> Resize(32) -> ToTensor.
# Cropping to a 256x256 square first means the final Resize(32) operates on
# a square image instead of on the original aspect ratio.
transform = transforms.Compose(
    [
        transforms.Resize(256),
        transforms.CenterCrop(256),
        transforms.Resize(32),
        transforms.ToTensor(),
    ]
)

# The ten LSUN scene categories sampled below, in output order.
LSUN_class_list = (
    'bedroom bridge church_outdoor classroom conference_room '
    'dining_room kitchen living_room restaurant tower'
).split()

# Seed all RNGs before any DataLoader is created so its shuffle order,
# and therefore which images get exported, is reproducible.
set_random_seed(0)
-
# Draw a fixed number of shuffled images from each category's training split
# and stack them class by class into one tensor.
# Expand '~' explicitly: torchvision forwards this path to lmdb as-is.
LSUN_root = os.path.expanduser('~/data/lsun/')
num_image_per_class = 1000

total_test_image_all_class = []
for LSUN_class in LSUN_class_list:
    # torchvision's LSUN accepts a bare string only for a split name
    # ('train'/'val'/'test'); a '<category>_<split>' entry must be wrapped in
    # a list, otherwise the constructor raises ValueError.
    LSUN_set = datasets.LSUN(LSUN_root, classes=[LSUN_class + '_train'], transform=transform)
    LSUN_loader = DataLoader(LSUN_set, batch_size=100, shuffle=True, pin_memory=False)

    # Collect batches and concatenate once at the end. Calling torch.cat
    # per batch re-copied every image gathered so far (quadratic copying).
    batches = []
    num_collected = 0
    for test_image, _ in LSUN_loader:
        batches.append(test_image)
        num_collected += test_image.size(0)
        if num_collected >= num_image_per_class:
            break

    if not batches:
        raise RuntimeError(f'No images loaded for LSUN class {LSUN_class!r} from {LSUN_root}')

    # Trim so each class contributes exactly num_image_per_class images, even
    # if the batch size does not divide that count evenly.
    total_test_image = torch.cat(batches, dim=0)[:num_image_per_class]
    total_test_image_all_class.append(total_test_image)

total_test_image_all_class = torch.cat(total_test_image_all_class, dim=0)
-
print(f'Preprocessing time {time.time()-check}')

# Output directory for the exported PNGs, relative to the working directory.
# makedirs(exist_ok=True) avoids the exists()/mkdir() race.
os.makedirs('./LSUN_fix', exist_ok=True)

check = time.time()
# Save every image that was actually loaded. A hard-coded range(10000)
# raises IndexError whenever fewer images were collected.
for i, image in enumerate(total_test_image_all_class):
    save_image(image, f'LSUN_fix/correct_resize_{i}.png')
print(f'Saving time {time.time()-check}')
|