shilei09 2021-02-10 15:57 采纳率: 50%
浏览 1139
已采纳

pytorch源码中“if torch._C._get_tracing_state():”怎么理解

 

 

 

    def _call_impl(self, *input, **kwargs):
        for hook in itertools.chain(
                _global_forward_pre_hooks.values(),
                self._forward_pre_hooks.values()):
            result = hook(self, input)
            if result is not None:
                if not isinstance(result, tuple):
                    result = (result,)
                input = result
        if torch._C._get_tracing_state():  # 我对这句话不理解,if中为什么没有使用输入的形参?
            result = self._slow_forward(*input, **kwargs)
        else:
            result = self.forward(*input, **kwargs)  # 错误提示跳转到了这里
  • 写回答

3条回答 默认 最新

  • PyBigStar 2021-02-12 09:51
    关注

    if torch._C._get_tracing_state():就是用来判断是否使用JIT来跟踪模型。像pytorch构建一个计算图就会用到一个中央的context去管理变量,而JIT跟踪模型也类似,比如用以下方式标记了这个module需要用JIT跟踪:

    import torch
    import torch.nn as nn
    import torch.nn.functional as F
    from torch.jit import ScriptModule, script_method, trace
    
    class MyScriptModule(ScriptModule):
    
        def __init__(self):
            super(MyScriptModule, self).__init__()
            # trace produces a ScriptModule's conv1 and conv2
            self.conv1 = trace(nn.Conv2d(1, 20, 5), torch.rand(1, 1, 16, 16))
            self.conv2 = trace(nn.Conv2d(20, 20, 5), torch.rand(1, 20, 16, 16))
    
        @script_method
        def forward(self, input):
          x = F.relu(self.conv1(input))
          x = F.relu(self.conv2(input))
          return x
    

    而torch跟踪代码后同样也会在c++后端中存入这个MyScriptModule是否需要跟踪的信息,因此只是需要调用_C的_get_tracing_state()就可以判断是否需要跟踪这个module,而不需要使用forward中传入的形参。

    本回答被题主选为最佳回答 , 对您是否有帮助呢?
    评论
查看更多回答(2条)

报告相同问题?

悬赏问题

  • ¥15 如何删除这个虚拟音频
  • ¥50 hyper默认的default switch
  • ¥15 网站打不开,提示502 Bad Gateway
  • ¥20 基于MATLAB的绝热压缩空气储能系统代码咨询
  • ¥15 R语言建立随机森林模型出现的问题
  • ¥20 unity内置语言切换的按钮设置
  • ¥15 中级微观经济学,生产可能性边界问题
  • ¥15 TCP传输时不同网卡传输用时差异过大
  • ¥15 请各位看看我写的属于什么算法,或者有更正确的写法?
  • ¥15 html5 qrcode 扫描器