In der Masterarbeit "Anomalie-Detektion in Zellbildern zur Anwendung der Leukämieerkennung" verwendete CSI-Methode.
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

__init__.py 1.4KB

2 years ago
123456789101112131415161718192021222324252627282930313233343536373839
  1. def setup(mode, P):
  2. fname = f'{P.dataset}_{P.model}_unsup_{mode}_{P.res}'
  3. if mode == 'simclr':
  4. from .simclr import train
  5. elif mode == 'simclr_CSI':
  6. from .simclr_CSI import train
  7. fname += f'_shift_{P.shift_trans_type}_resize_factor{P.resize_factor}_color_dist{P.color_distort}'
  8. if P.shift_trans_type == 'gauss':
  9. fname += f'_gauss_sigma{P.gauss_sigma}'
  10. elif P.shift_trans_type == 'randpers':
  11. fname += f'_distortion_scale{P.distortion_scale}'
  12. elif P.shift_trans_type == 'sharp':
  13. fname += f'_sharpness_factor{P.sharpness_factor}'
  14. elif P.shift_trans_type == 'sharp':
  15. fname += f'_nmean_{P.noise_mean}_nstd_{P.noise_std}'
  16. else:
  17. raise NotImplementedError()
  18. if P.one_class_idx is not None:
  19. fname += f'_one_class_{P.one_class_idx}'
  20. if P.suffix is not None:
  21. fname += f'_{P.suffix}'
  22. return train, fname
  23. def update_comp_loss(loss_dict, loss_in, loss_out, loss_diff, batch_size):
  24. loss_dict['pos'].update(loss_in, batch_size)
  25. loss_dict['neg'].update(loss_out, batch_size)
  26. loss_dict['diff'].update(loss_diff, batch_size)
  27. def summary_comp_loss(logger, tag, loss_dict, epoch):
  28. logger.scalar_summary(f'{tag}/pos', loss_dict['pos'].average, epoch)
  29. logger.scalar_summary(f'{tag}/neg', loss_dict['neg'].average, epoch)
  30. logger.scalar_summary(f'{tag}', loss_dict['diff'].average, epoch)