In der Masterarbeit „Anomalie-Detektion in Zellbildern zur Anwendung der Leukämieerkennung“ verwendete CSI-Methode.
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

common.py 6.1KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114
  1. from argparse import ArgumentParser
  2. def parse_args(default=False):
  3. """Command-line argument parser for training."""
  4. parser = ArgumentParser(description='Pytorch implementation of CSI')
  5. parser.add_argument('--dataset', help='Dataset',
  6. choices=['cifar10', 'cifar100', 'imagenet', 'CNMC', 'CNMC_grayscale'], type=str)
  7. parser.add_argument('--one_class_idx', help='None: multi-class, Not None: one-class',
  8. default=None, type=int)
  9. parser.add_argument('--model', help='Model',
  10. choices=['resnet18', 'resnet18_imagenet'], type=str)
  11. parser.add_argument('--mode', help='Training mode',
  12. default='simclr', type=str)
  13. parser.add_argument('--simclr_dim', help='Dimension of simclr layer',
  14. default=128, type=int)
  15. parser.add_argument('--shift_trans_type', help='shifting transformation type', default='none',
  16. choices=['rotation', 'cutperm', 'blur', 'randpers', 'sharp', 'blur_randpers',
  17. 'blur_sharp', 'randpers_sharp', 'blur_randpers_sharp', 'noise', 'none'], type=str)
  18. parser.add_argument("--local_rank", type=int,
  19. default=0, help='Local rank for distributed learning')
  20. parser.add_argument('--resume_path', help='Path to the resume checkpoint',
  21. default=None, type=str)
  22. parser.add_argument('--load_path', help='Path to the loading checkpoint',
  23. default=None, type=str)
  24. parser.add_argument("--no_strict", help='Do not strictly load state_dicts',
  25. action='store_true')
  26. parser.add_argument('--suffix', help='Suffix for the log dir',
  27. default=None, type=str)
  28. parser.add_argument('--error_step', help='Epoch steps to compute errors',
  29. default=5, type=int)
  30. parser.add_argument('--save_step', help='Epoch steps to save models',
  31. default=10, type=int)
  32. ##### Training Configurations #####
  33. parser.add_argument('--epochs', help='Epochs',
  34. default=1000, type=int)
  35. parser.add_argument('--optimizer', help='Optimizer',
  36. choices=['sgd', 'lars'],
  37. default='lars', type=str)
  38. parser.add_argument('--lr_scheduler', help='Learning rate scheduler',
  39. choices=['step_decay', 'cosine'],
  40. default='cosine', type=str)
  41. parser.add_argument('--warmup', help='Warm-up epochs',
  42. default=10, type=int)
  43. parser.add_argument('--lr_init', help='Initial learning rate',
  44. default=1e-1, type=float)
  45. parser.add_argument('--weight_decay', help='Weight decay',
  46. default=1e-6, type=float)
  47. parser.add_argument('--batch_size', help='Batch size',
  48. default=128, type=int)
  49. parser.add_argument('--test_batch_size', help='Batch size for test loader',
  50. default=100, type=int)
  51. parser.add_argument('--blur_sigma', help='Distortion grade',
  52. default=2.0, type=float)
  53. parser.add_argument('--color_distort', help='Color distortion grade',
  54. default=0.5, type=float)
  55. parser.add_argument('--distortion_scale', help='Perspective distortion grade',
  56. default=0.6, type=float)
  57. parser.add_argument('--sharpness_factor', help='Sharpening or blurring factor of image. '
  58. 'Can be any non negative number. 0 gives a blurred image, '
  59. '1 gives the original image while 2 increases the sharpness '
  60. 'by a factor of 2.',
  61. default=2, type=float)
  62. parser.add_argument('--noise_mean', help='mean',
  63. default=0, type=float)
  64. parser.add_argument('--noise_std', help='std',
  65. default=0.3, type=float)
  66. ##### Objective Configurations #####
  67. parser.add_argument('--sim_lambda', help='Weight for SimCLR loss',
  68. default=1.0, type=float)
  69. parser.add_argument('--temperature', help='Temperature for similarity',
  70. default=0.5, type=float)
  71. ##### Evaluation Configurations #####
  72. parser.add_argument("--ood_dataset", help='Datasets for OOD detection',
  73. default=None, nargs="*", type=str)
  74. parser.add_argument("--ood_score", help='score function for OOD detection',
  75. default=['norm_mean'], nargs="+", type=str)
  76. parser.add_argument("--ood_layer", help='layer for OOD scores',
  77. choices=['penultimate', 'simclr', 'shift'],
  78. default=['simclr', 'shift'], nargs="+", type=str)
  79. parser.add_argument("--ood_samples", help='number of samples to compute OOD score',
  80. default=1, type=int)
  81. parser.add_argument("--ood_batch_size", help='batch size to compute OOD score',
  82. default=100, type=int)
  83. parser.add_argument("--resize_factor", help='resize scale is sampled from [resize_factor, 1.0]',
  84. default=0.08, type=float)
  85. parser.add_argument("--resize_fix", help='resize scale is fixed to resize_factor (not (resize_factor, 1.0])',
  86. action='store_true')
  87. parser.add_argument("--print_score", help='print quantiles of ood score',
  88. action='store_true')
  89. parser.add_argument("--save_score", help='save ood score for plotting histogram',
  90. action='store_true')
  91. ##### Process configuration option #####
  92. parser.add_argument("--proc_step", help='choose process to initiate.',
  93. choices=['eval', 'train'],
  94. default=None, type=str)
  95. parser.add_argument("--res", help='resolution of dataset',
  96. default="32px", type=str)
  97. if default:
  98. return parser.parse_args('') # empty string
  99. else:
  100. return parser.parse_args()