if torch._C._get_tracing_state():就是用来判断是否使用JIT来跟踪模型。像pytorch构建一个计算图就会用到一个中央的context去管理变量,而JIT跟踪模型也类似,比如用以下方式标记了这个module需要用JIT跟踪:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.jit import ScriptModule, script_method, trace
class MyScriptModule(ScriptModule):
def __init__(self):
super(MyScriptModule, self).__init__()
# trace produces a ScriptModule's conv1 and conv2
self.conv1 = trace(nn.Conv2d(1, 20, 5), torch.rand(1, 1, 16, 16))
self.conv2 = trace(nn.Conv2d(20, 20, 5), torch.rand(1, 20, 16, 16))
@script_method
def forward(self, input):
x = F.relu(self.conv1(input))
x = F.relu(self.conv2(input))
return x
而torch跟踪代码后同样也会在c++后端中存入这个MyScriptModule是否需要跟踪的信息,因此只是需要调用_C的_get_tracing_state()就可以判断是否需要跟踪这个module,而不需要使用forward中传入的形参。