123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130 |
- import time
-
- import torch.optim
- import torch.optim.lr_scheduler as lr_scheduler
-
- import models.transform_layers as TL
- from utils.utils import AverageMeter, normalize
-
# Pick the compute device once at import time: CUDA when available, else CPU.
_device_name = "cuda" if torch.cuda.is_available() else "cpu"
device = torch.device(_device_name)

# Horizontal-flip layer from models.transform_layers, placed on the chosen
# device so it can be applied directly to batches already moved there.
hflip = TL.HorizontalFlipLayer()
hflip = hflip.to(device)
-
-
def _make_head_optimizer(head, weight_decay, milestones):
    """Return the (SGD optimizer, MultiStepLR scheduler) pair for one linear head."""
    head_optim = torch.optim.SGD(head.parameters(), lr=1e-1, weight_decay=weight_decay)
    head_sched = lr_scheduler.MultiStepLR(head_optim, gamma=0.1, milestones=milestones)
    return head_optim, head_sched


def train(P, epoch, model, criterion, optimizer, scheduler, loader, logger=None,
          simclr_aug=None, linear=None, linear_optim=None):
    """Train three linear heads for one epoch on top of a frozen backbone.

    The backbone ``model`` is kept in eval mode and its penultimate features
    are detached, so only the heads are updated:

    * ``linear`` -- class prediction, trained on the un-rotated samples only;
    * ``model.shift_cls_layer`` -- predicts which of the 4 rotations was applied;
    * ``model.joint_distribution_layer`` -- predicts the joint
      (class, rotation) label ``label + P.n_classes * rotation``.

    On ``epoch == 1`` the optimizers and schedulers for the three heads are
    created and stored on ``P`` (``linear_optim``, ``rotation_linear_optim``,
    ``joint_linear_optim``, ``linear_scheduler``, ``rot_scheduler``,
    ``joint_scheduler``); later epochs reuse them from ``P``.

    Args:
        P: Argument namespace. Reads ``multi_gpu``, ``epochs``,
            ``weight_decay``, ``n_gpus``, ``dataset`` and ``n_classes``;
            the optimizers/schedulers listed above are written to it.
        epoch: Current epoch number; ``1`` triggers optimizer creation.
        model: Backbone; called as ``model(images, penultimate=True)`` and
            expected to return ``(_, aux)`` with ``aux['penultimate']``.
        criterion: Loss callable applied as ``criterion(logits, targets)``.
        optimizer: Unused by this function.
        scheduler: Unused by this function.
        loader: Iterable yielding ``(images, labels)`` batches.
        logger: Optional logger providing ``log`` and ``scalar_summary``;
            when ``None`` messages go to ``print``.
        simclr_aug: Callable augmentation applied to the rotated batch.
        linear: Linear classification head evaluated on penultimate features.
        linear_optim: Ignored; the optimizer stored on ``P`` is always used.
    """
    # Unwrap the DataParallel-style container to reach the auxiliary heads.
    backbone = model.module if P.multi_gpu else model
    rotation_linear = backbone.shift_cls_layer
    joint_linear = backbone.joint_distribution_layer

    if epoch == 1:
        # Define the optimizers once and keep them on P so that they persist
        # across the per-epoch calls of this function.
        milestones = [int(0.6 * P.epochs), int(0.75 * P.epochs), int(0.9 * P.epochs)]

        P.linear_optim, P.linear_scheduler = _make_head_optimizer(
            linear, P.weight_decay, milestones)
        P.rotation_linear_optim, P.rot_scheduler = _make_head_optimizer(
            rotation_linear, P.weight_decay, milestones)
        P.joint_linear_optim, P.joint_scheduler = _make_head_optimizer(
            joint_linear, P.weight_decay, milestones)

    log_ = print if logger is None else logger.log

    batch_time = AverageMeter()
    data_time = AverageMeter()

    # NOTE: the joint loss is optimized below but, as in the original code,
    # only the classification and rotation losses are tracked and reported.
    losses = dict()
    losses['cls'] = AverageMeter()
    losses['rot'] = AverageMeter()

    # The backbone stays frozen in eval mode for the whole epoch; nothing in
    # the loop switches it back, so this is set once up front.
    model.eval()

    check = time.time()
    for n, (images, labels) in enumerate(loader):
        # Batch index scaled by the GPU count; only drives the logging cadence.
        count = n * P.n_gpus

        data_time.update(time.time() - check)
        check = time.time()

        if P.dataset != 'imagenet':
            batch_size = images.size(0)
            images = images.to(device)
            # The label construction below assumes hflip keeps the batch at B.
            images = hflip(images)
        else:
            # For imagenet the loader yields an indexable container of image
            # batches; only its first entry is used.
            batch_size = images[0].size(0)
            images = images[0].to(device)

        labels = labels.to(device)
        # Stack the 0/90/180/270 degree rotations along the batch axis: B -> 4B.
        images = torch.cat([torch.rot90(images, rot, (2, 3)) for rot in range(4)])
        rot_labels = torch.cat([torch.ones_like(labels) * k for k in range(4)], 0)
        joint_labels = torch.cat([labels + P.n_classes * i for i in range(4)], dim=0)

        images = simclr_aug(images)  # simclr augmentation

        # The features are detached right away, so no autograd graph is
        # needed for the backbone; skipping it saves the activation memory
        # of the 4B forward pass.
        with torch.no_grad():
            _, outputs_aux = model(images, penultimate=True)
        penultimate = outputs_aux['penultimate'].detach()

        # Only the first B rows (0 degree rotation) feed the class head.
        outputs = linear(penultimate[0:batch_size])
        outputs_rot = rotation_linear(penultimate)
        outputs_joint = joint_linear(penultimate)

        loss_ce = criterion(outputs, labels)
        loss_rot = criterion(outputs_rot, rot_labels)
        loss_joint = criterion(outputs_joint, joint_labels)

        # Each head sits on detached features, so the three backward passes
        # are independent of one another.
        ### CE loss ###
        P.linear_optim.zero_grad()
        loss_ce.backward()
        P.linear_optim.step()

        ### Rot loss ###
        P.rotation_linear_optim.zero_grad()
        loss_rot.backward()
        P.rotation_linear_optim.step()

        ### Joint loss ###
        P.joint_linear_optim.zero_grad()
        loss_joint.backward()
        P.joint_linear_optim.step()

        ### optimizer learning rate ###
        lr = P.linear_optim.param_groups[0]['lr']

        batch_time.update(time.time() - check)

        ### Log losses ###
        losses['cls'].update(loss_ce.item(), batch_size)
        losses['rot'].update(loss_rot.item(), batch_size)

        if count % 50 == 0:
            log_('[Epoch %3d; %3d] [Time %.3f] [Data %.3f] [LR %.5f]\n'
                 '[LossC %f] [LossR %f]' %
                 (epoch, count, batch_time.value, data_time.value, lr,
                  losses['cls'].value, losses['rot'].value))

        # Restart the clock so the next data_time covers only batch loading.
        check = time.time()

    # Milestones are expressed in epochs, so the schedules advance once per call.
    P.linear_scheduler.step()
    P.rot_scheduler.step()
    P.joint_scheduler.step()

    log_('[DONE] [Time %.3f] [Data %.3f] [LossC %f] [LossR %f]' %
         (batch_time.average, data_time.average,
          losses['cls'].average, losses['rot'].average))

    if logger is not None:
        logger.scalar_summary('train/loss_cls', losses['cls'].average, epoch)
        logger.scalar_summary('train/loss_rot', losses['rot'].average, epoch)
        logger.scalar_summary('train/batch_time', batch_time.average, epoch)
|