shilei09 2021-02-10 15:57 采纳率: 50%
浏览 1130
已采纳

pytorch源码中“if torch._C._get_tracing_state():”怎么理解

 

 

 

    def _call_impl(self, *input, **kwargs):
        for hook in itertools.chain(
                _global_forward_pre_hooks.values(),
                self._forward_pre_hooks.values()):
            result = hook(self, input)
            if result is not None:
                if not isinstance(result, tuple):
                    result = (result,)
                input = result
        if torch._C._get_tracing_state():  # 我对这句话不理解,if中为什么没有使用输入的形参?
            result = self._slow_forward(*input, **kwargs)
        else:
            result = self.forward(*input, **kwargs)  # 错误提示跳转到了这里
  • 写回答

3条回答 默认 最新

  • PyBigStar 2021-02-12 09:51
    关注

    if torch._C._get_tracing_state():就是用来判断是否使用JIT来跟踪模型。像pytorch构建一个计算图就会用到一个中央的context去管理变量,而JIT跟踪模型也类似,比如用以下方式标记了这个module需要用JIT跟踪:

    import torch
    import torch.nn as nn
    import torch.nn.functional as F
    from torch.jit import ScriptModule, script_method, trace
    
    class MyScriptModule(ScriptModule):
    
        def __init__(self):
            super(MyScriptModule, self).__init__()
            # trace produces a ScriptModule's conv1 and conv2
            self.conv1 = trace(nn.Conv2d(1, 20, 5), torch.rand(1, 1, 16, 16))
            self.conv2 = trace(nn.Conv2d(20, 20, 5), torch.rand(1, 20, 16, 16))
    
        @script_method
        def forward(self, input):
          x = F.relu(self.conv1(input))
          x = F.relu(self.conv2(input))
          return x
    

    而torch跟踪代码后同样也会在c++后端中存入这个MyScriptModule是否需要跟踪的信息,因此只是需要调用_C的_get_tracing_state()就可以判断是否需要跟踪这个module,而不需要使用forward中传入的形参。

    本回答被题主选为最佳回答 , 对您是否有帮助呢?
    评论
查看更多回答(2条)

报告相同问题?

悬赏问题

  • ¥20 测距传感器数据手册i2c
  • ¥15 RPA正常跑,cmd输入cookies跑不出来
  • ¥15 求帮我调试一下freefem代码
  • ¥15 matlab代码解决,怎么运行
  • ¥15 R语言Rstudio突然无法启动
  • ¥15 关于#matlab#的问题:提取2个图像的变量作为另外一个图像像元的移动量,计算新的位置创建新的图像并提取第二个图像的变量到新的图像
  • ¥15 改算法,照着压缩包里边,参考其他代码封装的格式 写到main函数里
  • ¥15 用windows做服务的同志有吗
  • ¥60 求一个简单的网页(标签-安全|关键词-上传)
  • ¥35 lstm时间序列共享单车预测,loss值优化,参数优化算法