DD3H 2023-06-10 14:41 采纳率: 0%
浏览 147
已结题

python代码中 for 循环异常中断

python代码中 for 循环异常中断

if args.deploy == 'finetune':
            print("Start selecting the best lr...")
            best_acc = 0
            for lr in [0, 0.0001, 0.0005, 0.001]:
                model.lr = lr
                test_stats = evaluate(data_loader_val, model, criterion, device, seed=1234, ep=5) #test_stats:测试结果统计信息
                acc = test_stats['acc1']
                print(f"*lr = {lr}: acc1 = {acc}")
                if acc > best_acc:
                    best_acc = acc
                    best_lr = lr
            model.lr = best_lr
            print(f"### Selected lr = {best_lr}")

在这个代码中第一个学习率0 一切正常,到了第二个学习率 就直接总时间变成0了,没有进行任何操作,直接就回到了打印阶段

img

def evaluate(data_loaders, model, criterion, device, seed=None, ep=None):
    if isinstance(data_loaders, dict):
        test_stats_lst = {}
        test_stats_glb = {}

        for j, (source, data_loader) in enumerate(data_loaders.items()):
            print(f'* Evaluating {source}:')
            seed_j = seed + j if seed else None
            test_stats = _evaluate(data_loader, model, criterion, device, seed_j)
            test_stats_lst[source] = test_stats
            test_stats_glb[source] = test_stats['acc1']

        # apart from individual's acc1, accumulate metrics over all domains to compute mean
        for k in test_stats_lst[source].keys():
            test_stats_glb[k] = torch.tensor([test_stats[k] for test_stats in test_stats_lst.values()]).mean().item()

        return test_stats_glb
    elif isinstance(data_loaders, torch.utils.data.DataLoader): # when args.eval = True
        return _evaluate(data_loaders, model, criterion, device, seed, ep)
    else:
        warnings.warn(f'The structure of {data_loaders} is not recognizable.')
        return _evaluate(data_loaders, model, criterion, device, seed)


```python
def _evaluate(data_loader, model, criterion, device, seed=None, ep=None):
    metric_logger = utils.MetricLogger(delimiter="  ")
    metric_logger.add_meter('n_ways', utils.SmoothedValue(window_size=1, fmt='{value:d}'))
    metric_logger.add_meter('n_imgs', utils.SmoothedValue(window_size=1, fmt='{value:d}'))
    metric_logger.add_meter('acc1', utils.SmoothedValue(window_size=len(data_loader.dataset)))
    metric_logger.add_meter('acc5', utils.SmoothedValue(window_size=len(data_loader.dataset)))


    header = 'Test:'

    # switch to evaluation mode
    model.eval()

    if seed is not None:
        data_loader.generator.manual_seed(seed)

    for ii, batch in enumerate(metric_logger.log_every(data_loader, 10, header)):
        if ep is not None:
            if ii > ep:
                break

        batch = to_device(batch, device)
        SupportTensor, SupportLabel, x, y = batch

        # compute output
        with torch.cuda.amp.autocast():
            output = model(SupportTensor, SupportLabel, x) #logits

        output = output.view(x.shape[0] * x.shape[1], -1)
        y = y.view(-1)
        loss = criterion(output, y)
        acc1, acc5 = accuracy(output, y, topk=(1, 5))

        batch_size = x.shape[0]
        metric_logger.update(loss=loss.item())
        metric_logger.meters['acc1'].update(acc1.item(), n=batch_size)
        metric_logger.meters['acc5'].update(acc5.item(), n=batch_size)
        metric_logger.update(n_ways=SupportLabel.max()+1)
        metric_logger.update(n_imgs=SupportTensor.shape[1] + x.shape[1])

    # gather the stats from all processes
    metric_logger.synchronize_between_processes()
    print('* Acc@1 {top1.global_avg:.3f} Acc@5 {top5.global_avg:.3f} loss {losses.global_avg:.3f}'
          .format(top1=metric_logger.acc1, top5=metric_logger.acc5, losses=metric_logger.loss))

    ret_dict = {k: meter.global_avg for k, meter in metric_logger.meters.items()}
    ret_dict['acc_std'] = metric_logger.meters['acc1'].std

    return ret_dict

  • 写回答

11条回答 默认 最新

  • 四川底层昏昏pig 2023-06-10 18:21
    关注
    获得0.30元问题酬金

    print('* Acc@1 {top1.global_avg:.3f} Acc@5 {top5.global_avg:.3f} loss {losses.global_avg:.3f}' 这一行代码中访问了 loss 属性,但是 MetricLogger 类中并没有定义对应的属性。
    你可以在 update 方法中记录 loss 的逻辑改为添加一个新的 meter,然后在 getattr 方法中增加对应的属性访问器。

    评论

报告相同问题?

问题事件

  • 系统已结题 6月18日
  • 修改了问题 6月11日
  • 修改了问题 6月11日
  • 修改了问题 6月11日
  • 展开全部

悬赏问题

  • ¥15 myeclipse 代码调试
  • ¥15 HttpListener() 类监听服务问题
  • ¥15 画个全连接层,帮我看一下我之前的有没有画对
  • ¥15 springboot项目本地运行正常,部署服务器报错。
  • ¥15 Amd显卡覆盖后就用不了blender的cycles GPU渲染了
  • ¥20 帮我解答一下哪一块是单片机将数据传输到显示屏上显示出来的那部分代码,之后帮我解释这个部分的每一句的意思
  • ¥15 matlab 神经网络集成/融合?
  • ¥15 简化数学表达式,修改为简单求和形式
  • ¥20 光热发电的能量汇聚 求解
  • ¥15 windows连接xbox主机