58 lines
1.7 KiB
Python
58 lines
1.7 KiB
Python
from utils.utils import Logger
|
|
from utils.utils import save_checkpoint
|
|
from utils.utils import save_linear_checkpoint
|
|
|
|
from common.train import *
|
|
from evals import test_classifier
|
|
|
|
# Choose the mode-specific training module. Any mode string that contains the
# substring 'sup' goes to training.sup; every other mode goes to training.unsup.
# NOTE(review): this is a substring test, so a mode name such as 'unsup_x'
# would also be routed to training.sup — confirm the mode names in use.
if 'sup' not in P.mode:
    from training.unsup import setup
else:
    from training.sup import setup

# `setup` returns the per-epoch train function plus a name that is later
# handed to Logger (presumably the run / log-directory name — verify).
train, fname = setup(P.mode, P)
|
|
|
|
# Create the run logger. The `ask` flag is switched off when `resume` is true.
# `resume` is not defined in this chunk; presumably it comes from
# `from common.train import *` — confirm. Only the local rank is passed through
# here; how Logger uses it is defined in utils.utils.
logger = Logger(fname, ask=not resume, local_rank=P.local_rank)

# Record the option object and the model at the start of the run.
for _entry in (P, model):
    logger.log(_entry)
|
|
|
|
# In multi-GPU runs the model is wrapped, so the `linear` sub-module is reached
# through `.module`; otherwise it is an attribute of the model itself.
linear = model.module.linear if P.multi_gpu else model.linear

# A dedicated Adam optimizer over the `linear` sub-module's parameters only.
# The learning rate and betas are fixed here; only weight decay follows P.
linear_optim = torch.optim.Adam(
    linear.parameters(),
    lr=1e-3,
    betas=(.9, .999),
    weight_decay=P.weight_decay,
)
|
|
|
|
# Run experiments: one iteration per epoch, from start_epoch through P.epochs
# inclusive.
# NOTE(review): start_epoch, best, criterion, optimizer, scheduler_warmup,
# train_loader, test_loader, train_sampler and simclr_aug are not defined in
# this chunk; presumably they come from `from common.train import *`.
for epoch in range(start_epoch, P.epochs + 1):
    logger.log_dirname(f"Epoch {epoch}")
    model.train()

    # Multi-GPU runs tell the sampler which epoch is starting (presumably a
    # distributed sampler that reshuffles per epoch — confirm in common.train).
    if P.multi_gpu:
        train_sampler.set_epoch(epoch)

    # Extra keyword arguments forwarded to the mode-specific train function.
    kwargs = {
        'linear': linear,
        'linear_optim': linear_optim,
        'simclr_aug': simclr_aug,
    }
    train(P, epoch, model, criterion, optimizer, scheduler_warmup, train_loader,
          logger=logger, **kwargs)

    model.eval()

    # Every P.save_step epochs, the rank-0 process writes the model/optimizer
    # checkpoint and the `linear` optimizer state into the logger's directory.
    if epoch % P.save_step == 0 and P.local_rank == 0:
        # Save the underlying module's state rather than the multi-GPU wrapper's.
        save_states = (model.module if P.multi_gpu else model).state_dict()
        save_checkpoint(epoch, save_states, optimizer.state_dict(), logger.logdir)
        save_linear_checkpoint(linear_optim.state_dict(), logger.logdir)

    # For modes whose name contains 'sup', every P.error_step epochs: measure
    # the test error and keep the lowest value seen so far in `best`.
    if epoch % P.error_step == 0 and ('sup' in P.mode):
        error = test_classifier(P, model, test_loader, epoch, logger=logger)

        is_best = (best > error)
        if is_best:
            best = error

        logger.scalar_summary('eval/best_error', best, epoch)
        logger.log('[Epoch %3d] [Test %5.2f] [Best %5.2f]' % (epoch, error, best))
|