CSI/main.py

38 lines
1.5 KiB
Python
Raw Permalink Normal View History

2022-04-29 19:26:47 +02:00
import os
import subprocess
import sys
from os import system
from sys import argv

import torch

from datasets.postprocess_data import postprocess_data
from datasets.prepare_data import prep, resize
# NOTE(review): despite its name, DATA_BASE_DIR points at a file (main.py), not a
# directory, and nothing visible in this file reads it — confirm the intended value.
DATA_BASE_DIR = r'/home/feoktistovar67431/CSI/CSI_local/main.py'
# Root of the local CSI checkout; used by main() when building the path of a
# script to launch. Note that it already ends with a path separator.
BASE_DIR = '/home/feoktistovar67431/CSI/CSI_local/'
def _run_script(script, forwarded_args):
    """Run ``script`` (resolved against ``BASE_DIR``) in a child Python process.

    Args:
        script: File name of the script, relative to ``BASE_DIR``.
        forwarded_args: Command-line arguments passed verbatim to the script.
    """
    command = [sys.executable, os.path.join(BASE_DIR, script), *forwarded_args]
    # List form (no shell): forwarded arguments are passed exactly as given instead
    # of being re-split on whitespace or expanded by the shell. The child's exit
    # status is not checked, matching the original os.system() call.
    subprocess.run(command, check=False)


def main(args=None):
    """Dispatch the processing step named by ``--proc_step``.

    Supported steps:
      * ``eval``      -- run ``eval.py`` from ``BASE_DIR`` with the current interpreter.
      * ``train``     -- run ``train.py`` from ``BASE_DIR`` with the current interpreter.
      * ``plot``      -- call ``plot_data()``.
      * ``post_proc`` -- call ``postprocess_data`` on a fixed list of run logs.

    For ``eval`` and ``train``, every argument after the program name
    (``--proc_step`` and its value included) is forwarded unchanged to the script.
    If ``--proc_step`` is absent, nothing is done.

    Args:
        args: Full argument vector with the program name first.
            Defaults to ``sys.argv``.

    Raises:
        SystemExit: if ``--proc_step`` has no value or names an unknown step.
    """
    args = argv if args is None else args
    if '--proc_step' not in args:
        return

    # Only the first --proc_step is honoured; the step runs exactly once.
    flag_index = args.index('--proc_step')
    if flag_index + 1 >= len(args):
        raise SystemExit('--proc_step requires a value: eval, train, plot or post_proc')
    proc_step = args[flag_index + 1]
    forwarded_args = args[1:]

    if proc_step == 'eval':
        _run_script('eval.py', forwarded_args)
    elif proc_step == 'train':
        # NOTE(review): this step previously launched eval.py, assumed to be a
        # copy-paste slip — confirm train.py exists under BASE_DIR.
        _run_script('train.py', forwarded_args)
    elif proc_step == 'plot':
        # NOTE(review): plot_data is neither defined nor imported in this file, so
        # this raises NameError unless it is provided elsewhere — confirm its origin.
        plot_data()
    elif proc_step == 'post_proc':
        # NOTE(review): these are Windows-style paths with a leading backslash,
        # while BASE_DIR is a POSIX path — confirm how postprocess_data resolves them.
        postprocess_data(
            [
                r'\CNMC_resnet18_unsup_simclr_CSI_shift_cutperm4_one_class_0\log.txt',
                r'\CNMC_resnet18_unsup_simclr_CSI_shift_cutperm4_one_class_0_64px\log.txt',
                r'\CNMC_resnet18_unsup_simclr_CSI_shift_cutperm16_one_class_0_32px\log.txt',
                r'\CNMC_resnet18_unsup_simclr_CSI_shift_cutperm_one_class_0_64px_batch64\log.txt',
                r'\CNMC_resnet18_unsup_simclr_CSI_shift_rotation_one_class_0\log.txt',
                r"\CNMC_resnet18_unsup_simclr_CSI_shift_gauss_one_class_0_32px\log.txt",
                # r'\cifar10_resnet18_unsup_simclr_CSI_shift_rotation_one_class_1\log.txt'
            ]
        )
    else:
        raise SystemExit(f'Unknown --proc_step value: {proc_step!r}')
# Script entry point: dispatch the requested --proc_step.
if __name__ == '__main__':
    main()