def setup(mode, P): fname = f'{P.dataset}_{P.model}_unsup_{mode}_{P.res}' if mode == 'simclr': from .simclr import train elif mode == 'simclr_CSI': from .simclr_CSI import train fname += f'_shift_{P.shift_trans_type}_resize_factor{P.resize_factor}_color_dist{P.color_distort}' if P.shift_trans_type == 'gauss': fname += f'_gauss_sigma{P.gauss_sigma}' elif P.shift_trans_type == 'randpers': fname += f'_distortion_scale{P.distortion_scale}' elif P.shift_trans_type == 'sharp': fname += f'_sharpness_factor{P.sharpness_factor}' elif P.shift_trans_type == 'sharp': fname += f'_nmean_{P.noise_mean}_nstd_{P.noise_std}' else: raise NotImplementedError() if P.one_class_idx is not None: fname += f'_one_class_{P.one_class_idx}' if P.suffix is not None: fname += f'_{P.suffix}' return train, fname def update_comp_loss(loss_dict, loss_in, loss_out, loss_diff, batch_size): loss_dict['pos'].update(loss_in, batch_size) loss_dict['neg'].update(loss_out, batch_size) loss_dict['diff'].update(loss_diff, batch_size) def summary_comp_loss(logger, tag, loss_dict, epoch): logger.scalar_summary(f'{tag}/pos', loss_dict['pos'].average, epoch) logger.scalar_summary(f'{tag}/neg', loss_dict['neg'].average, epoch) logger.scalar_summary(f'{tag}', loss_dict['diff'].average, epoch)