菜卵一枚 2021-05-10 10:46 采纳率: 0%
浏览 395

模型训练中,model size和inference time有联系吗?

在我们的固有观念里面认为model size越大inference time随之变大

但在近期做试验的过程中发现U2-Net†仅有4MB,但inference time有371

而U-NET有7MB,inference time为58

 

batch size同为12,同GPU,服务器未运行其他任何程序。

 

所以两者是否真的有联系?

  • 写回答

2条回答 默认 最新

  • GitCode 官方 企业官方账号 2021-05-10 17:47
    关注

    其实算inference time比较复杂吧,这里简单提两个要注意的点: asynchronous execution and GPU warm up

    具体细节可以看下这篇文章: https://towardsdatascience.com/the-correct-way-to-measure-inference-time-of-deep-neural-networks-304a54e5187f

    下面上一份pytorch算inference time的代码:

    import torch
    import numpy as np
    import torchvision.models as models
    
    model = models.vgg16()
    device = torch.device("cuda")
    model.to(device)
    model.eval()
    dummy_input = torch.randn(1, 3, 224, 224, dtype = torch.float).to(device)
    starter, ender = torch.cuda.Event(enable_timing = True), torch.cuda.Event(enable_timing = True)
    repetitions = 300
    timings = np.zeros((repetitions, 1))
    #GPU-WARM-UP
    for _ in range(10):
       _ = model(dummy_input)
    # MEASURE PERFORMANCE
    with torch.no_grad():
      for rep in range(repetitions):
         starter.record()
         _ = model(dummy_input)
         ender.record()
         # WAIT FOR GPU SYNC
         torch.cuda.synchronize()
         curr_time = starter.elapsed_time(ender)
         timings[rep] = curr_time
    mean_syn = np.sum(timings) / repetitions
    std_syn = np.std(timings)
    print(mean_syn)
    print(std_syn)
    评论

报告相同问题?

悬赏问题

  • ¥15 Python输入字符串转化为列表排序具体见图,严格按照输入
  • ¥20 XP系统在重新启动后进不去桌面,一直黑屏。
  • ¥15 opencv图像处理,需要四个处理结果图
  • ¥15 无线移动边缘计算系统中的系统模型
  • ¥15 深度学习中的画图问题
  • ¥15 java报错:使用mybatis plus查询一个只返回一条数据的sql,却报错返回了1000多条
  • ¥15 Python报错怎么解决
  • ¥15 simulink如何调用DLL文件
  • ¥15 关于用pyqt6的项目开发该怎么把前段后端和业务层分离
  • ¥30 线性代数的问题,我真的忘了线代的知识了