python代码中 for 循环异常中断
if args.deploy == 'finetune':
print("Start selecting the best lr...")
best_acc = 0
for lr in [0, 0.0001, 0.0005, 0.001]:
model.lr = lr
test_stats = evaluate(data_loader_val, model, criterion, device, seed=1234, ep=5) #test_stats:测试结果统计信息
acc = test_stats['acc1']
print(f"*lr = {lr}: acc1 = {acc}")
if acc > best_acc:
best_acc = acc
best_lr = lr
model.lr = best_lr
print(f"### Selected lr = {best_lr}")
在这个代码中第一个学习率0 一切正常,到了第二个学习率 就直接总时间变成0了,没有进行任何操作,直接就回到了打印阶段
def evaluate(data_loaders, model, criterion, device, seed=None, ep=None):
if isinstance(data_loaders, dict):
test_stats_lst = {}
test_stats_glb = {}
for j, (source, data_loader) in enumerate(data_loaders.items()):
print(f'* Evaluating {source}:')
seed_j = seed + j if seed else None
test_stats = _evaluate(data_loader, model, criterion, device, seed_j)
test_stats_lst[source] = test_stats
test_stats_glb[source] = test_stats['acc1']
# apart from individual's acc1, accumulate metrics over all domains to compute mean
for k in test_stats_lst[source].keys():
test_stats_glb[k] = torch.tensor([test_stats[k] for test_stats in test_stats_lst.values()]).mean().item()
return test_stats_glb
elif isinstance(data_loaders, torch.utils.data.DataLoader): # when args.eval = True
return _evaluate(data_loaders, model, criterion, device, seed, ep)
else:
warnings.warn(f'The structure of {data_loaders} is not recognizable.')
return _evaluate(data_loaders, model, criterion, device, seed)
```python
def _evaluate(data_loader, model, criterion, device, seed=None, ep=None):
metric_logger = utils.MetricLogger(delimiter=" ")
metric_logger.add_meter('n_ways', utils.SmoothedValue(window_size=1, fmt='{value:d}'))
metric_logger.add_meter('n_imgs', utils.SmoothedValue(window_size=1, fmt='{value:d}'))
metric_logger.add_meter('acc1', utils.SmoothedValue(window_size=len(data_loader.dataset)))
metric_logger.add_meter('acc5', utils.SmoothedValue(window_size=len(data_loader.dataset)))
header = 'Test:'
# switch to evaluation mode
model.eval()
if seed is not None:
data_loader.generator.manual_seed(seed)
for ii, batch in enumerate(metric_logger.log_every(data_loader, 10, header)):
if ep is not None:
if ii > ep:
break
batch = to_device(batch, device)
SupportTensor, SupportLabel, x, y = batch
# compute output
with torch.cuda.amp.autocast():
output = model(SupportTensor, SupportLabel, x) #logits
output = output.view(x.shape[0] * x.shape[1], -1)
y = y.view(-1)
loss = criterion(output, y)
acc1, acc5 = accuracy(output, y, topk=(1, 5))
batch_size = x.shape[0]
metric_logger.update(loss=loss.item())
metric_logger.meters['acc1'].update(acc1.item(), n=batch_size)
metric_logger.meters['acc5'].update(acc5.item(), n=batch_size)
metric_logger.update(n_ways=SupportLabel.max()+1)
metric_logger.update(n_imgs=SupportTensor.shape[1] + x.shape[1])
# gather the stats from all processes
metric_logger.synchronize_between_processes()
print('* Acc@1 {top1.global_avg:.3f} Acc@5 {top5.global_avg:.3f} loss {losses.global_avg:.3f}'
.format(top1=metric_logger.acc1, top5=metric_logger.acc5, losses=metric_logger.loss))
ret_dict = {k: meter.global_avg for k, meter in metric_logger.meters.items()}
ret_dict['acc_std'] = metric_logger.meters['acc1'].std
return ret_dict