shilei09 2021-02-10 15:57 采纳率: 50%
浏览 1151
已采纳

pytorch源码中“if torch._C._get_tracing_state():”怎么理解

 

 

 

    def _call_impl(self, *input, **kwargs):
        for hook in itertools.chain(
                _global_forward_pre_hooks.values(),
                self._forward_pre_hooks.values()):
            result = hook(self, input)
            if result is not None:
                if not isinstance(result, tuple):
                    result = (result,)
                input = result
        if torch._C._get_tracing_state():  # 我对这句话不理解,if中为什么没有使用输入的形参?
            result = self._slow_forward(*input, **kwargs)
        else:
            result = self.forward(*input, **kwargs)  # 错误提示跳转到了这里
  • 写回答

3条回答 默认 最新

  • PyBigStar 2021-02-12 09:51
    关注

    if torch._C._get_tracing_state():就是用来判断是否使用JIT来跟踪模型。像pytorch构建一个计算图就会用到一个中央的context去管理变量,而JIT跟踪模型也类似,比如用以下方式标记了这个module需要用JIT跟踪:

    import torch
    import torch.nn as nn
    import torch.nn.functional as F
    from torch.jit import ScriptModule, script_method, trace
    
    class MyScriptModule(ScriptModule):
    
        def __init__(self):
            super(MyScriptModule, self).__init__()
            # trace produces a ScriptModule's conv1 and conv2
            self.conv1 = trace(nn.Conv2d(1, 20, 5), torch.rand(1, 1, 16, 16))
            self.conv2 = trace(nn.Conv2d(20, 20, 5), torch.rand(1, 20, 16, 16))
    
        @script_method
        def forward(self, input):
          x = F.relu(self.conv1(input))
          x = F.relu(self.conv2(input))
          return x
    

    而torch跟踪代码后同样也会在c++后端中存入这个MyScriptModule是否需要跟踪的信息,因此只是需要调用_C的_get_tracing_state()就可以判断是否需要跟踪这个module,而不需要使用forward中传入的形参。

    本回答被题主选为最佳回答 , 对您是否有帮助呢?
    评论
查看更多回答(2条)

报告相同问题?

悬赏问题

  • ¥15 metadata提取的PDF元数据,如何转换为一个Excel
  • ¥15 关于arduino编程toCharArray()函数的使用
  • ¥100 vc++混合CEF采用CLR方式编译报错
  • ¥15 coze 的插件输入飞书多维表格 app_token 后一直显示错误,如何解决?
  • ¥15 vite+vue3+plyr播放本地public文件夹下视频无法加载
  • ¥15 c#逐行读取txt文本,但是每一行里面数据之间空格数量不同
  • ¥50 如何openEuler 22.03上安装配置drbd
  • ¥20 ING91680C BLE5.3 芯片怎么实现串口收发数据
  • ¥15 无线连接树莓派,无法执行update,如何解决?(相关搜索:软件下载)
  • ¥15 Windows11, backspace, enter, space键失灵