The code is:
The error message is:
The content of the train function is:
def train():
    net.train()
    conf_loss = 0
    loc_loss = 0
    print(" [*] Loading dataset...")
    batch_iterator = None
    trainset = CUB_loader(os.getcwd() + '/data/CUB_200_2011', split='train')
    trainloader = data.DataLoader(trainset, batch_size=2, shuffle=True,
                                  collate_fn=trainset.CUB_collate, num_workers=2)
    testset = CUB_loader(os.getcwd() + '/data/CUB_200_2011', split='test')
    testloader = data.DataLoader(testset, batch_size=2, shuffle=False,
                                 collate_fn=testset.CUB_collate, num_workers=2)
    test_sample, _ = next(iter(testloader))
    # pre-train the APN before the alternating phase begins
    apn_iter, apn_epoch, apn_steps = pretrainAPN(trainset, trainloader)
    cls_iter, cls_epoch, cls_steps = 0, 0, 1
    switch_step = 0
    old_cls_loss, new_cls_loss = 2, 1
    old_apn_loss, new_apn_loss = 2, 1
    iteration = 0  # counts iterations across both phases
    epoch_size = len(trainset) // 4  # NOTE: this appears to assume batch size 4, but the loaders use batch_size=2
    cls_tol = 0
    apn_tol = 0
    batch_iterator = iter(trainloader)
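    # Alternating optimization (RA-CNN style): the outer loop below switches
    # between (1) a classification phase that updates the CNN/classifier
    # branches through opt1 with multitask_loss, and (2) an APN phase that
    # updates the attention proposal network through opt2 with
    # pairwise_ranking_loss. A phase ends once its loss has been nearly
    # stationary for 10 consecutive steps (cls_tol / apn_tol) or its own
    # counter reaches a multiple of 5000.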
    while ((old_cls_loss - new_cls_loss) ** 2 > 1e-7) and ((old_apn_loss - new_apn_loss) ** 2 > 1e-7) and (
            iteration < 500000):
        # loop until neither type of loss is changing any more
        print(' [*] Switch optimize parameters to Class')
        while (cls_tol < 10) and (cls_iter % 5000 != 0):
            if (not batch_iterator) or (iteration % epoch_size == 0):
                batch_iterator = iter(trainloader)
            if cls_iter % epoch_size == 0:
                cls_epoch += 1
                if cls_epoch in decay_steps:
                    cls_steps += 1
                    adjust_learning_rate(opt1, 0.1, cls_steps, args.lr)
            old_cls_loss = new_cls_loss
            images, labels = next(batch_iterator)
            images, labels = Variable(images, requires_grad=True), Variable(labels)
            if args.cuda:
                images, labels = images.cuda(), labels.cuda()
            t0 = time.time()
            logits, _, _, _ = net(images)
            opt1.zero_grad()
            new_cls_losses = multitask_loss(logits, labels)
            new_cls_loss = sum(new_cls_losses)
            # new_cls_loss = new_cls_losses[0]
            new_cls_loss.backward()
            opt1.step()
            t1 = time.time()
            # count consecutive near-stationary steps; reset on any real change
            if (old_cls_loss - new_cls_loss) ** 2 < 1e-6:
                cls_tol += 1
            else:
                cls_tol = 0
            logger.scalar_summary('cls_loss', new_cls_loss.item(), iteration + 1)
            logger.scalar_summary('cls_loss1', new_cls_losses[0].item(), iteration + 1)
            logger.scalar_summary('cls_loss12', new_cls_losses[1].item(), iteration + 1)
            logger.scalar_summary('cls_loss123', new_cls_losses[2].item(), iteration + 1)
            iteration += 1
            cls_iter += 1
            if (cls_iter % 20) == 0:
                print(" [*] cls_epoch[%d], Iter %d || cls_iter %d || cls_loss: %.4f || Timer: %.4fsec" % (
                    cls_epoch, iteration, cls_iter, new_cls_loss.item(), (t1 - t0)))
        # after the classification phase, log the current ranking loss once
        images, labels = next(batch_iterator)
        if args.cuda:
            images, labels = images.cuda(), labels.cuda()
        logits, _, _, _ = net(images)  # net returns four values; unpacking only three raises ValueError
        preds = []
        for i in range(len(labels)):
            # per-scale confidence of the ground-truth class for sample i
            pred = [logit[i][labels[i]] for logit in logits]
            preds.append(pred)
        new_apn_loss = pairwise_ranking_loss(preds)
        logger.scalar_summary('rank_loss', new_apn_loss.item(), iteration + 1)
        iteration += 1
        # cls_iter += 1
        test(testloader, iteration)
        # continue
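        # The APN phase mirrors the classification phase, except that opt2
        # updates the attention proposal network with the ranking loss, which
        # pushes each finer scale to score the true class higher than the
        # scale before it.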
        print(' [*] Switch optimize parameters to APN')
        switch_step += 1
        while (apn_tol < 10) and (apn_iter % 5000 != 0):
            if (not batch_iterator) or (iteration % epoch_size == 0):
                batch_iterator = iter(trainloader)
            if apn_iter % epoch_size == 0:
                apn_epoch += 1
                if apn_epoch in decay_steps:
                    apn_steps += 1
                    adjust_learning_rate(opt2, 0.1, apn_steps, args.lr)
            old_apn_loss = new_apn_loss
            images, labels = next(batch_iterator)
            images, labels = Variable(images, requires_grad=True), Variable(labels)
            if args.cuda:
                images, labels = images.cuda(), labels.cuda()
            t0 = time.time()
            logits, _, _, _ = net(images)
            opt2.zero_grad()
            preds = []
            for i in range(len(labels)):
                pred = [logit[i][labels[i]] for logit in logits]
                preds.append(pred)
            new_apn_loss = pairwise_ranking_loss(preds)
            new_apn_loss.backward()
            opt2.step()
            t1 = time.time()
            if (old_apn_loss - new_apn_loss) ** 2 < 1e-6:
                apn_tol += 1
            else:
                apn_tol = 0
            logger.scalar_summary('rank_loss', new_apn_loss.item(), iteration + 1)
            iteration += 1
            apn_iter += 1
            if (apn_iter % 20) == 0:
                print(" [*] apn_epoch[%d], Iter %d || apn_iter %d || apn_loss: %.4f || Timer: %.4fsec" % (
                    apn_epoch, iteration, apn_iter, new_apn_loss.item(), (t1 - t0)))
        # after the APN phase, log the current classification loss once
        switch_step += 1
        images, labels = next(batch_iterator)
        if args.cuda:
            images, labels = images.cuda(), labels.cuda()
        logits, _, _, _ = net(images)  # same fix as above: net returns four values
        new_cls_losses = multitask_loss(logits, labels)
        new_cls_loss = sum(new_cls_losses)
        logger.scalar_summary('cls_loss', new_cls_loss.item(), iteration + 1)
        iteration += 1
        cls_iter += 1
        apn_iter += 1
        test(testloader, iteration)
        # visualize the attention crops on a fixed test batch, then checkpoint
        if args.cuda:
            test_sample = test_sample.cuda()  # NOTE: added; keep the sample on the same device as net
        _, _, _, crops = net(test_sample)
        x1, x2 = crops[0].data, crops[1].data
        # visualize cropped inputs (avg_loss was undefined here; the latest cls loss is used instead)
        save_img(x1, path=f'samples/iter_{iteration}@2x.jpg',
                 annotation=f'loss = {new_cls_loss.item():.7f}, step = {iteration}')
        save_img(x2, path=f'samples/iter_{iteration}@4x.jpg',
                 annotation=f'loss = {new_cls_loss.item():.7f}, step = {iteration}')
        # state_dict is a method: it must be called, otherwise the checkpoint
        # stores a bound method object instead of the weights
        torch.save(net.state_dict(), 'ckpt/RACNN_vgg_CUB200_iter%d.pth' % iteration)
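For reference, adjust_learning_rate and pairwise_ranking_loss are not shown in the post. Below are minimal sketches of what they could look like, assuming plain step decay for the learning rate and the inter-scale ranking loss from the RA-CNN paper; the margin value is only a placeholder, not taken from the original code:

import torch

def adjust_learning_rate(optimizer, gamma, steps, base_lr):
    # step decay: shrink the base rate by gamma once per accumulated decay step
    lr = base_lr * (gamma ** steps)
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr
    return lr

def pairwise_ranking_loss(preds, margin=0.05):
    # preds: one list per sample, each holding the ground-truth-class score at
    # every scale; penalize scales that do not beat the coarser scale by margin
    loss = 0
    for pred in preds:
        for s in range(len(pred) - 1):
            loss = loss + torch.clamp(pred[s] - pred[s + 1] + margin, min=0)
    return loss / len(preds)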