import torch.nn as nn
from models.resnet import ResNet18, ResNet34, ResNet50
from models.resnet_imagenet import resnet18, resnet50
import models.transform_layers as TL
from torchvision import transforms


def get_simclr_augmentation(P, image_size):
    """
    Creates the augmentation used to build positive views for training.

    Reads ``P.resize_factor``, ``P.resize_fix``, ``P.color_distort`` and
    ``P.dataset``.

    :param P: parsed arguments
    :param image_size: indexable whose first two entries are passed as the
        ``size`` of the random resized crop (presumably (height, width) --
        confirm against callers)
    :return: ``nn.Sequential`` of color jitter and random grayscale, followed
        by a random resized crop for every dataset except 'imagenet'
    """
    # Scale range sampled by the random resized crop.
    resize_scale = (P.resize_factor, 1.0)
    if P.resize_fix:
        # Pin the crop scale to a single value instead of sampling a range.
        resize_scale = (P.resize_factor, P.resize_factor)

    # A single factor s controls the strength of all color jitter components.
    s = P.color_distort
    color_jitter = TL.ColorJitterLayer(brightness=s * 0.8, contrast=s * 0.8,
                                       saturation=s * 0.8, hue=s * 0.2, p=0.8)
    color_gray = TL.RandomColorGrayLayer(p=0.2)
    resize_crop = TL.RandomResizedCropLayer(scale=resize_scale,
                                            size=(image_size[0], image_size[1]))

    if P.dataset == 'imagenet':
        # Using RandomResizedCrop at PIL transform, so no crop layer here.
        return nn.Sequential(color_jitter, color_gray)
    # All other datasets, including 'CNMC', also apply the random resized crop.
    return nn.Sequential(color_jitter, color_gray, resize_crop)


def _blur_kernel_size(res):
    """
    Returns an odd integer Gaussian-blur kernel size of ~10% of the resolution.

    torchvision's GaussianBlur rejects kernel sizes that are even or
    non-positive. Assumes the TL blur wrappers build on it -- confirm in
    models.transform_layers.

    :param res: resolution string such as '64px'
    :return: odd positive int (e.g. '64px' -> 7, '32px' -> 3)
    """
    kernel_size = int(int(res.replace('px', '')) * 0.1)
    if kernel_size % 2 == 0:
        kernel_size += 1
    return kernel_size


def get_shift_module(P, eval=False):
    """
    Creates the shift transformation used to generate negatives.

    :param P: parsed arguments; reads ``shift_trans_type``, ``mode``,
        ``batch_size`` and the options of the selected shift type
        (``noise_mean``, ``noise_std``, ``distortion_scale``,
        ``sharpness_factor``, ``res``, ``blur_sigma``)
    :param eval: whether it is an evaluation step or not; the batch-size check
        is skipped when True
    :return: tuple ``(shift_transform, K_shift)``. An unrecognised
        ``shift_trans_type`` falls back to ``nn.Identity()`` with
        ``K_shift == 1``.
    :raises AssertionError: when not evaluating, ``'sup'`` is not in
        ``P.mode``, and ``P.batch_size != int(128 / K_shift)``
    """
    shift_type = P.shift_trans_type
    # Every named shift type below has K_shift = 4; only the identity
    # fallback uses 1.
    K_shift = 4

    if shift_type == 'rotation':
        shift_transform = TL.Rotation()
    elif shift_type == 'cutperm':
        shift_transform = TL.CutPerm()
    elif shift_type == 'noise':
        shift_transform = TL.GaussNoise(mean=P.noise_mean, std=P.noise_std)
    elif shift_type == 'randpers':
        shift_transform = TL.RandPers(distortion_scale=P.distortion_scale, p=1)
    elif shift_type == 'sharp':
        shift_transform = TL.RandomAdjustSharpness(
            sharpness_factor=P.sharpness_factor, p=1)
    elif shift_type == 'blur':
        kernel_size = _blur_kernel_size(P.res)
        sigma = (0.1, float(P.blur_sigma))
        shift_transform = TL.GaussBlur(kernel_size=kernel_size, sigma=sigma)
    elif shift_type == 'blur_randpers':
        # Same odd integer kernel as 'blur'. This previously passed a raw float
        # such as 6.4, which truncates to an invalid even size.
        kernel_size = _blur_kernel_size(P.res)
        sigma = (0.1, float(P.blur_sigma))
        shift_transform = TL.BlurRandpers(
            kernel_size=kernel_size, sigma=sigma,
            distortion_scale=P.distortion_scale, p=1)
    elif shift_type == 'blur_sharp':
        kernel_size = _blur_kernel_size(P.res)
        sigma = (0.1, float(P.blur_sigma))
        shift_transform = TL.BlurSharpness(
            kernel_size=kernel_size, sigma=sigma,
            sharpness_factor=P.sharpness_factor, p=1)
    elif shift_type == 'randpers_sharp':
        shift_transform = TL.RandpersSharpness(
            distortion_scale=P.distortion_scale, p=1,
            sharpness_factor=P.sharpness_factor)
    elif shift_type == 'blur_randpers_sharp':
        kernel_size = _blur_kernel_size(P.res)
        sigma = (0.1, float(P.blur_sigma))
        shift_transform = TL.BlurRandpersSharpness(
            kernel_size=kernel_size, sigma=sigma,
            distortion_scale=P.distortion_scale, p=1,
            sharpness_factor=P.sharpness_factor)
    else:
        shift_transform = nn.Identity()
        K_shift = 1

    if not eval and 'sup' not in P.mode:
        # Training expects batch_size * K_shift to (approximately) equal 128.
        expected = int(128 / K_shift)
        assert P.batch_size == expected, (
            f"batch_size must be {expected} for shift_trans_type "
            f"'{shift_type}' (K_shift={K_shift}), got {P.batch_size}")

    return shift_transform, K_shift


def get_shift_classifer(model, K_shift):
    """
    Attaches a shift-classification head to ``model`` in place.

    The name's spelling ("classifer") is kept so existing callers keep working.

    :param model: network exposing ``last_dim``; gains attribute
        ``shift_cls_layer = nn.Linear(model.last_dim, K_shift)``
    :param K_shift: number of output units of the new head
    :return: the same ``model`` object
    """
    model.shift_cls_layer = nn.Linear(model.last_dim, K_shift)
    return model


def get_classifier(mode, n_classes=10):
    """
    Instantiates a backbone classifier by name.

    :param mode: one of 'resnet18', 'resnet34', 'resnet50',
        'resnet18_imagenet', 'resnet50_imagenet'
    :param n_classes: forwarded as ``num_classes`` to the constructor
    :return: the constructed classifier
    :raises NotImplementedError: if ``mode`` is not a known name
    """
    builders = {
        'resnet18': ResNet18,
        'resnet34': ResNet34,
        'resnet50': ResNet50,
        'resnet18_imagenet': resnet18,
        'resnet50_imagenet': resnet50,
    }
    builder = builders.get(mode)
    if builder is None:
        raise NotImplementedError(f"Unknown classifier mode: {mode!r}")
    return builder(num_classes=n_classes)