In der Masterarbeit "Anomalie-Detektion in Zellbildern zur Anwendung der Leukämieerkennung" verwendete CSI-Methode.
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

base_model.py 1.3KB

2 years ago
Lines 1–48
from abc import *

import torch
import torch.nn as nn
  3. class BaseModel(nn.Module, metaclass=ABCMeta):
  4. def __init__(self, last_dim, num_classes=10, simclr_dim=128):
  5. super(BaseModel, self).__init__()
  6. self.linear = nn.Linear(last_dim, num_classes)
  7. self.simclr_layer = nn.Sequential(
  8. nn.Linear(last_dim, last_dim),
  9. nn.ReLU(),
  10. nn.Linear(last_dim, simclr_dim),
  11. )
  12. self.shift_cls_layer = nn.Linear(last_dim, 2)
  13. self.joint_distribution_layer = nn.Linear(last_dim, 4 * num_classes)
  14. @abstractmethod
  15. def penultimate(self, inputs, all_features=False):
  16. pass
  17. def forward(self, inputs, penultimate=False, simclr=False, shift=False, joint=False):
  18. _aux = {}
  19. _return_aux = False
  20. features = self.penultimate(inputs)
  21. output = self.linear(features)
  22. if penultimate:
  23. _return_aux = True
  24. _aux['penultimate'] = features
  25. if simclr:
  26. _return_aux = True
  27. _aux['simclr'] = self.simclr_layer(features)
  28. if shift:
  29. _return_aux = True
  30. _aux['shift'] = self.shift_cls_layer(features)
  31. if joint:
  32. _return_aux = True
  33. _aux['joint'] = self.joint_distribution_layer(features)
  34. if _return_aux:
  35. return output, _aux
  36. return output