123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135 |
- import torch.nn as nn
-
- from models.resnet import ResNet18, ResNet34, ResNet50
- from models.resnet_imagenet import resnet18, resnet50
- import models.transform_layers as TL
- from torchvision import transforms
-
-
def get_simclr_augmentation(P, image_size):
    """
    Builds the batch-level augmentation used to create positive views for training.

    :param P: parsed arguments; reads ``resize_factor``, ``resize_fix``,
        ``color_distort`` and ``dataset``
    :param image_size: image size; only ``image_size[0]`` and ``image_size[1]``
        are used, as the crop output size (presumably (H, W) — confirm with caller)
    :return: ``nn.Sequential`` of color jitter and random grayscale, followed by a
        random resized crop for every dataset except ``'imagenet'``
    """
    # Crop scale range: [resize_factor, 1.0], or pinned to resize_factor when resize_fix is set.
    resize_scale = (P.resize_factor, 1.0)
    if P.resize_fix:
        resize_scale = (P.resize_factor, P.resize_factor)

    # Color augmentation; all jitter strengths scale with color_distort.
    s = P.color_distort
    color_jitter = TL.ColorJitterLayer(brightness=s * 0.8, contrast=s * 0.8,
                                       saturation=s * 0.8, hue=s * 0.2, p=0.8)
    color_gray = TL.RandomColorGrayLayer(p=0.2)
    resize_crop = TL.RandomResizedCropLayer(scale=resize_scale,
                                            size=(image_size[0], image_size[1]))

    if P.dataset == 'imagenet':
        # ImageNet applies RandomResizedCrop in its PIL transform, so no crop here.
        return nn.Sequential(
            color_jitter,
            color_gray,
        )

    # All other datasets (including 'CNMC') share the same pipeline.
    return nn.Sequential(
        color_jitter,
        color_gray,
        resize_crop,
    )
-
-
def _blur_kernel_and_sigma(P):
    """
    Returns ``(kernel_size, sigma)`` for the Gaussian-blur based shift transforms.

    The kernel size is 10% of the resolution parsed from ``P.res`` (e.g. ``'224px'``
    gives 22), bumped to the next odd integer. Gaussian-blur kernels are expected
    to be odd ints; presumably the TL blur layers wrap torchvision's, which reject
    float or even sizes.

    :param P: parsed arguments; reads ``res`` and ``blur_sigma``
    :return: tuple ``(kernel_size, sigma)``, where ``sigma == (0.1, float(P.blur_sigma))``
    """
    kernel_size = int(int(P.res.replace('px', '')) * 0.1)
    if kernel_size % 2 == 0:
        kernel_size += 1
    sigma = (0.1, float(P.blur_sigma))
    return kernel_size, sigma


def get_shift_module(P, eval=False):
    """
    Creates the shift transformation (negative) selected by ``P.shift_trans_type``.

    :param P: parsed arguments; always reads ``shift_trans_type``. When ``eval``
        is False it also reads ``mode`` and ``batch_size``. Other fields
        (``noise_mean``, ``res``, ``blur_sigma``, ...) are read only by the
        transform type that needs them.
    :param eval: whether it is an evaluation step; if True the batch-size check
        is skipped
    :return: tuple ``(shift_transform, K_shift)``. Every recognised type gives
        ``K_shift == 4``; any other value falls back to ``nn.Identity()`` with
        ``K_shift == 1``.
    :raises AssertionError: when ``eval`` is False, ``'sup'`` is not in ``P.mode``
        and ``P.batch_size != 128 // K_shift``
    """
    shift_type = P.shift_trans_type
    K_shift = 4  # every recognised shift type uses 4; overridden for the fallback

    if shift_type == 'rotation':
        shift_transform = TL.Rotation()
    elif shift_type == 'cutperm':
        shift_transform = TL.CutPerm()
    elif shift_type == 'noise':
        shift_transform = TL.GaussNoise(mean=P.noise_mean, std=P.noise_std)
    elif shift_type == 'randpers':
        shift_transform = TL.RandPers(distortion_scale=P.distortion_scale, p=1)
    elif shift_type == 'sharp':
        shift_transform = TL.RandomAdjustSharpness(sharpness_factor=P.sharpness_factor, p=1)
    elif shift_type == 'blur':
        kernel_size, sigma = _blur_kernel_and_sigma(P)
        shift_transform = TL.GaussBlur(kernel_size=kernel_size, sigma=sigma)
    elif shift_type == 'blur_randpers':
        # Previously passed a raw float kernel size here; now an odd int as for 'blur'.
        kernel_size, sigma = _blur_kernel_and_sigma(P)
        shift_transform = TL.BlurRandpers(kernel_size=kernel_size, sigma=sigma,
                                          distortion_scale=P.distortion_scale, p=1)
    elif shift_type == 'blur_sharp':
        kernel_size, sigma = _blur_kernel_and_sigma(P)
        shift_transform = TL.BlurSharpness(kernel_size=kernel_size, sigma=sigma,
                                           sharpness_factor=P.sharpness_factor, p=1)
    elif shift_type == 'randpers_sharp':
        shift_transform = TL.RandpersSharpness(distortion_scale=P.distortion_scale, p=1,
                                               sharpness_factor=P.sharpness_factor)
    elif shift_type == 'blur_randpers_sharp':
        kernel_size, sigma = _blur_kernel_and_sigma(P)
        shift_transform = TL.BlurRandpersSharpness(kernel_size=kernel_size, sigma=sigma,
                                                   distortion_scale=P.distortion_scale, p=1,
                                                   sharpness_factor=P.sharpness_factor)
    else:
        shift_transform = nn.Identity()
        K_shift = 1

    if not eval and 'sup' not in P.mode:
        # Enforces batch_size * K_shift == 128 for non-supervised training modes.
        assert P.batch_size == 128 // K_shift, (
            f'batch_size must be {128 // K_shift} for shift_trans_type={shift_type!r} '
            f'(K_shift={K_shift}), got {P.batch_size}'
        )

    return shift_transform, K_shift
-
-
def get_shift_classifer(model, K_shift):
    """
    Attaches a linear shift-prediction head to ``model`` in place.

    :param model: module exposing ``last_dim`` (the head's input size)
    :param K_shift: number of shift classes (the head's output size)
    :return: the same ``model`` object, now with a ``shift_cls_layer`` attribute
    """
    head = nn.Linear(model.last_dim, K_shift)
    setattr(model, 'shift_cls_layer', head)
    return model
-
-
def get_classifier(mode, n_classes=10):
    """
    Instantiates the backbone classifier named by ``mode``.

    :param mode: one of ``'resnet18'``, ``'resnet34'``, ``'resnet50'``,
        ``'resnet18_imagenet'``, ``'resnet50_imagenet'``
    :param n_classes: passed to the constructor as ``num_classes``
    :return: the constructed model
    :raises NotImplementedError: for any other ``mode``
    """
    # Lambdas defer name lookup, so only the selected constructor is touched.
    builders = {
        'resnet18': lambda: ResNet18(num_classes=n_classes),
        'resnet34': lambda: ResNet34(num_classes=n_classes),
        'resnet50': lambda: ResNet50(num_classes=n_classes),
        'resnet18_imagenet': lambda: resnet18(num_classes=n_classes),
        'resnet50_imagenet': lambda: resnet50(num_classes=n_classes),
    }
    build = builders.get(mode)
    if build is None:
        raise NotImplementedError()
    return build()
|