shilei09 2021-02-10 15:57 采纳率: 50%
浏览 1179
已采纳

pytorch源码中“if torch._C._get_tracing_state():”怎么理解

 

 

 

    def _call_impl(self, *input, **kwargs):
        for hook in itertools.chain(
                _global_forward_pre_hooks.values(),
                self._forward_pre_hooks.values()):
            result = hook(self, input)
            if result is not None:
                if not isinstance(result, tuple):
                    result = (result,)
                input = result
        if torch._C._get_tracing_state():  # 我对这句话不理解,if中为什么没有使用输入的形参?
            result = self._slow_forward(*input, **kwargs)
        else:
            result = self.forward(*input, **kwargs)  # 错误提示跳转到了这里
  • 写回答

3条回答 默认 最新

  • PyBigStar 2021-02-12 09:51
    关注

    if torch._C._get_tracing_state():就是用来判断是否使用JIT来跟踪模型。像pytorch构建一个计算图就会用到一个中央的context去管理变量,而JIT跟踪模型也类似,比如用以下方式标记了这个module需要用JIT跟踪:

    import torch
    import torch.nn as nn
    import torch.nn.functional as F
    from torch.jit import ScriptModule, script_method, trace
    
    class MyScriptModule(ScriptModule):
    
        def __init__(self):
            super(MyScriptModule, self).__init__()
            # trace produces a ScriptModule's conv1 and conv2
            self.conv1 = trace(nn.Conv2d(1, 20, 5), torch.rand(1, 1, 16, 16))
            self.conv2 = trace(nn.Conv2d(20, 20, 5), torch.rand(1, 20, 16, 16))
    
        @script_method
        def forward(self, input):
          x = F.relu(self.conv1(input))
          x = F.relu(self.conv2(input))
          return x
    

    而torch跟踪代码后同样也会在c++后端中存入这个MyScriptModule是否需要跟踪的信息,因此只是需要调用_C的_get_tracing_state()就可以判断是否需要跟踪这个module,而不需要使用forward中传入的形参。

    本回答被题主选为最佳回答 , 对您是否有帮助呢?
    评论
查看更多回答(2条)

报告相同问题?

悬赏问题

  • ¥15 PADS Logic 原理图
  • ¥15 PADS Logic 图标
  • ¥15 电脑和power bi环境都是英文如何将日期层次结构转换成英文
  • ¥20 气象站点数据求取中~
  • ¥15 如何获取APP内弹出的网址链接
  • ¥15 wifi 图标不见了 不知道怎么办 上不了网 变成小地球了
  • ¥50 STM32单片机传感器读取错误
  • ¥15 (关键词-阻抗匹配,HFSS,RFID标签天线)
  • ¥15 机器人轨迹规划相关问题
  • ¥15 word样式右侧翻页键消失