|
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114 |
- from argparse import ArgumentParser
-
-
def parse_args(default=False, *, argv=None):
    """Build the CSI command-line parser and parse arguments for training.

    Args:
        default: When True, ignore the real command line and return a
            namespace populated purely with the declared defaults.
        argv: Optional explicit list of argument strings to parse instead of
            ``sys.argv[1:]``. Ignored when ``default`` is True. Leaving it as
            ``None`` keeps the original behaviour of reading the process
            command line.

    Returns:
        argparse.Namespace holding every option declared below.

    Raises:
        SystemExit: raised by argparse on invalid options or values
            (e.g. a value outside an option's ``choices``).
    """

    parser = ArgumentParser(description='Pytorch implementation of CSI')

    parser.add_argument('--dataset', help='Dataset',
                        choices=['cifar10', 'cifar100', 'imagenet', 'CNMC', 'CNMC_grayscale'], type=str)
    parser.add_argument('--one_class_idx', help='None: multi-class, Not None: one-class',
                        default=None, type=int)
    parser.add_argument('--model', help='Model',
                        choices=['resnet18', 'resnet18_imagenet'], type=str)
    parser.add_argument('--mode', help='Training mode',
                        default='simclr', type=str)
    parser.add_argument('--simclr_dim', help='Dimension of simclr layer',
                        default=128, type=int)

    parser.add_argument('--shift_trans_type', help='shifting transformation type', default='none',
                        choices=['rotation', 'cutperm', 'blur', 'randpers', 'sharp', 'blur_randpers',
                                 'blur_sharp', 'randpers_sharp', 'blur_randpers_sharp', 'noise', 'none'],
                        type=str)

    parser.add_argument("--local_rank", type=int,
                        default=0, help='Local rank for distributed learning')
    parser.add_argument('--resume_path', help='Path to the resume checkpoint',
                        default=None, type=str)
    parser.add_argument('--load_path', help='Path to the loading checkpoint',
                        default=None, type=str)
    parser.add_argument("--no_strict", help='Do not strictly load state_dicts',
                        action='store_true')
    parser.add_argument('--suffix', help='Suffix for the log dir',
                        default=None, type=str)
    parser.add_argument('--error_step', help='Epoch steps to compute errors',
                        default=5, type=int)
    parser.add_argument('--save_step', help='Epoch steps to save models',
                        default=10, type=int)

    ##### Training Configurations #####
    parser.add_argument('--epochs', help='Epochs',
                        default=1000, type=int)
    parser.add_argument('--optimizer', help='Optimizer',
                        choices=['sgd', 'lars'],
                        default='lars', type=str)
    parser.add_argument('--lr_scheduler', help='Learning rate scheduler',
                        choices=['step_decay', 'cosine'],
                        default='cosine', type=str)
    parser.add_argument('--warmup', help='Warm-up epochs',
                        default=10, type=int)
    parser.add_argument('--lr_init', help='Initial learning rate',
                        default=1e-1, type=float)
    parser.add_argument('--weight_decay', help='Weight decay',
                        default=1e-6, type=float)
    parser.add_argument('--batch_size', help='Batch size',
                        default=128, type=int)
    parser.add_argument('--test_batch_size', help='Batch size for test loader',
                        default=100, type=int)

    parser.add_argument('--blur_sigma', help='Distortion grade',
                        default=2.0, type=float)
    parser.add_argument('--color_distort', help='Color distortion grade',
                        default=0.5, type=float)
    parser.add_argument('--distortion_scale', help='Perspective distortion grade',
                        default=0.6, type=float)
    # argparse applies `type` only to *string* defaults, so a bare int default
    # would stay an int while a user-supplied value becomes a float. Use float
    # literals so the attribute type is the same either way.
    parser.add_argument('--sharpness_factor', help='Sharpening or blurring factor of image. '
                                                   'Can be any non negative number. 0 gives a blurred image, '
                                                   '1 gives the original image while 2 increases the sharpness '
                                                   'by a factor of 2.',
                        default=2.0, type=float)
    parser.add_argument('--noise_mean', help='mean',
                        default=0.0, type=float)
    parser.add_argument('--noise_std', help='std',
                        default=0.3, type=float)

    ##### Objective Configurations #####
    parser.add_argument('--sim_lambda', help='Weight for SimCLR loss',
                        default=1.0, type=float)
    parser.add_argument('--temperature', help='Temperature for similarity',
                        default=0.5, type=float)

    ##### Evaluation Configurations #####
    parser.add_argument("--ood_dataset", help='Datasets for OOD detection',
                        default=None, nargs="*", type=str)
    parser.add_argument("--ood_score", help='score function for OOD detection',
                        default=['norm_mean'], nargs="+", type=str)
    # `choices` is checked per element for user-supplied values; the list
    # default below is not validated by argparse.
    parser.add_argument("--ood_layer", help='layer for OOD scores',
                        choices=['penultimate', 'simclr', 'shift'],
                        default=['simclr', 'shift'], nargs="+", type=str)
    parser.add_argument("--ood_samples", help='number of samples to compute OOD score',
                        default=1, type=int)
    parser.add_argument("--ood_batch_size", help='batch size to compute OOD score',
                        default=100, type=int)
    parser.add_argument("--resize_factor", help='resize scale is sampled from [resize_factor, 1.0]',
                        default=0.08, type=float)
    parser.add_argument("--resize_fix", help='resize scale is fixed to resize_factor (not (resize_factor, 1.0])',
                        action='store_true')

    parser.add_argument("--print_score", help='print quantiles of ood score',
                        action='store_true')
    parser.add_argument("--save_score", help='save ood score for plotting histogram',
                        action='store_true')

    ##### Process configuration option #####
    parser.add_argument("--proc_step", help='choose process to initiate.',
                        choices=['eval', 'train'],
                        default=None, type=str)
    parser.add_argument("--res", help='resolution of dataset',
                        default="32px", type=str)

    if default:
        # An explicit empty list means "no command-line arguments", yielding
        # a namespace of pure defaults regardless of sys.argv.
        return parser.parse_args([])
    # argv=None makes argparse fall back to sys.argv[1:] (original behaviour).
    return parser.parse_args(argv)
|