Asked by 从零开始写代码 on 2025-06-02 18:55

How do I eliminate a PyTorch TracerWarning?


How can I get rid of the following warnings?

TracerWarning: torch.tensor results are registered as constants in the trace. You can safely ignore this warning if you use this function to create tensors out of constant variables that would be the same every time you call this function. In any other case, this might cause the trace to be incorrect.
  sqrt_HW = torch.sqrt(torch.tensor(

UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).
  sqrt_HW = torch.sqrt(torch.tensor(

Here is the source code:
token_transformer.py


    def forward(self, x):
        if not self.expand_disabled:
            x_emb = self.fc1(x)
        else:
            x_emb = x
        x_attn = x_emb[:, :, 0:self.attn_dim]
        x_conv = x_emb[:, :, self.attn_dim:]
        B, HW, C = x_conv.shape
        # tensor operations after my modification
        HW = x_conv.size(1)  # get HW (note: size() returns a Python int)

        # warning location
        sqrt_HW = torch.sqrt(torch.tensor(
            HW, dtype=torch.float32, device=x.device))
        H = W = torch.round(sqrt_HW).to(torch.int64)


        x_conv = x_conv.transpose(1, 2).reshape(B, C, H, W)  # use tensor-valued H/W

        x_attn = self.attn(self.norm1(x_attn))
        # x_attn = self.attn(self.norm1(x_attn.transpose(1,2)).transpose(1,2))

        x_conv = self.conv_branch(x_conv)
        B, C, H, W = x_conv.shape
        x_conv = x_conv.reshape(B, C, H*W).transpose(1, 2)

        # x = self.split_ratio * x_attn + (1 - self.split_ratio) * x_conv
        x, w_attn, w_conv = self.fuse(x_attn, x_conv)
        x = x + self.drop_path(self.mlp(self.norm2(x)))
        return x, w_attn, w_conv

asf_former.py

    def forward(self, x):
        w_attn_depth = []
        w_conv_depth = []

        # step0: soft split
        x = self.soft_split0(x).transpose(1, 2)

        # iteration1: re-structurization/reconstruction
        x, w_attn, w_conv = self.attention1(x)
        w_attn_depth.append(w_attn.mean())
        w_conv_depth.append(w_conv.mean())

        # handle input tensors of varying shape
        B, new_HW, C = x.shape


        # warning location
        sqrt_HW = torch.sqrt(torch.tensor(
            new_HW, dtype=torch.float32, device=x.device))
        H = W = torch.round(sqrt_HW).to(torch.int64)


        x = x.transpose(1, 2).reshape(B, C, H, W)
        # iteration1: soft split
        x = self.soft_split1(x).transpose(1, 2)

        # iteration2: re-structurization/reconstruction
        x, w_attn, w_conv = self.attention2(x)
        w_attn_depth.append(w_attn.mean())
        w_conv_depth.append(w_conv.mean())

        B, new_HW, C = x.shape

        # warning location
        sqrt_HW = torch.sqrt(torch.tensor(
            new_HW, dtype=torch.float32, device=x.device))
        H = W = torch.round(sqrt_HW).to(torch.int64)

        x = x.transpose(1, 2).reshape(B, C, H, W)
        # iteration2: soft split
        x = self.soft_split2(x).transpose(1, 2)

        # final tokens
        x = self.project(x)

        return x, w_attn_depth, w_conv_depth

transformer_block.py

    def forward(self, x):
        if not self.expand_disabled:
            x_emb = self.fc1(x)
        else:
            x_emb = x
        x_attn = x_emb[:, :, 0:self.attn_dim]
        x_cls = x_emb[:, 0, self.attn_dim:].unsqueeze(1)
        x_conv = x_emb[:, 1:, self.attn_dim:]
        B, HW, C = x_conv.shape

        # warning location
        HW_tensor = torch.tensor(
            HW, dtype=torch.float32, device=x_conv.device)  # explicitly pin the device
        sqrt_HW = torch.sqrt(HW_tensor)
        H = W = torch.round(sqrt_HW).to(dtype=torch.int64)  # stay in tensor ops


        x_conv = x_conv.transpose(1, 2).reshape(
            B, C, H, W)
        x_attn = self.attn(self.norm1(x_attn))
        # x_attn = self.attn(self.norm1(x_attn.transpose(1,2)).transpose(1,2))

        x_conv = self.conv_branch(x_conv)
        B, C, H, W = x_conv.shape
        x_conv = x_conv.reshape(B, C, H*W).transpose(1, 2)
        x_conv = torch.cat((x_cls, x_conv), dim=1)

        # x = x + self.drop_path(self.fc2(self.split_ratio * x_attn + (1 - self.split_ratio) * x_conv))
        x_fuse, w_attn, w_conv = self.fuse(x_attn, x_conv)
        x = x + self.drop_path(self.fc2(x_fuse))
        x = x + self.drop_path(self.mlp(self.norm2(x)))
        return x, w_attn, w_conv

I did try switching to sourceTensor.clone().detach() earlier, as the UserWarning suggests.


But then it raised:

C:\Users\范先生\Desktop\大三\图像处理_2419\农作物病虫害检测\Leaf_Disease\models\token_transformer.py:152: TracerWarning: torch.tensor results are registered as constants in the trace. You can safely ignore this warning if you use this function to create tensors out of constant variables that would be the same every time you call this function. In any other case, this might cause the trace to be incorrect.
  sqrt_HW = torch.sqrt(torch.tensor(HW, dtype=torch.float32))
C:\Users\范先生\Desktop\大三\图像处理_2419\农作物病虫害检测\Leaf_Disease\models\token_transformer.py:152: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).
  sqrt_HW = torch.sqrt(torch.tensor(HW, dtype=torch.float32))
Traceback (most recent call last):
  File "G:\anaconda\envs\pytorch\lib\site-packages\torch\jit\_trace.py", line 477, in run_mod_and_filter_tensor_outputs
    outs = wrap_retval(mod(*_clone_inputs(inputs)))
  File "G:\anaconda\envs\pytorch\lib\site-packages\torch\nn\modules\module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "G:\anaconda\envs\pytorch\lib\site-packages\torch\nn\modules\module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
  File "C:\Users\范先生\Desktop\大三\图像处理_2419\农作物病虫害检测\Leaf_Disease\tools\common_tools_torchmetrics_old.py", line 86, in forward
    asf_output = self.asf_former(x)
  File "G:\anaconda\envs\pytorch\lib\site-packages\torch\nn\modules\module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "G:\anaconda\envs\pytorch\lib\site-packages\torch\nn\modules\module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
  File "C:\Users\范先生\Desktop\大三\图像处理_2419\农作物病虫害检测\Leaf_Disease\models\asf_former.py", line 273, in forward
    x, w_attn_depth, w_conv_depth, w_attn_class, w_conv_class = self.forward_features(
  File "C:\Users\范先生\Desktop\大三\图像处理_2419\农作物病虫害检测\Leaf_Disease\models\asf_former.py", line 238, in forward_features
    x, w_attn_depth, w_conv_depth = self.tokens_to_token(x)
  File "G:\anaconda\envs\pytorch\lib\site-packages\torch\nn\modules\module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "G:\anaconda\envs\pytorch\lib\site-packages\torch\nn\modules\module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
  File "C:\Users\范先生\Desktop\大三\图像处理_2419\农作物病虫害检测\Leaf_Disease\models\asf_former.py", line 125, in forward
    sqrt_HW = torch.sqrt(new_HW.clone().detach().to(torch.float32))
AttributeError: 'int' object has no attribute 'clone'

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "C:\Users\范先生\Desktop\大三\图像处理_2419\农作物病虫害检测\Leaf_Disease\train_CustomNet.py", line 159, in <module>
    writer.add_graph(custom_model, input_to_model=torch.rand(
  File "G:\anaconda\envs\pytorch\lib\site-packages\torch\utils\tensorboard\writer.py", line 841, in add_graph
    graph(model, input_to_model, verbose, use_strict_trace)
  File "G:\anaconda\envs\pytorch\lib\site-packages\torch\utils\tensorboard\_pytorch_graph.py", line 331, in graph
    trace = torch.jit.trace(model, args, strict=use_strict_trace)
  File "G:\anaconda\envs\pytorch\lib\site-packages\torch\jit\_trace.py", line 1002, in trace
    traced_func = _trace_impl(
  File "G:\anaconda\envs\pytorch\lib\site-packages\torch\jit\_trace.py", line 698, in _trace_impl
    return trace_module(
  File "G:\anaconda\envs\pytorch\lib\site-packages\torch\jit\_trace.py", line 1306, in trace_module
    _check_trace(
  File "G:\anaconda\envs\pytorch\lib\site-packages\torch\utils\_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
  File "G:\anaconda\envs\pytorch\lib\site-packages\torch\jit\_trace.py", line 583, in _check_trace
    fn_outs = run_mod_and_filter_tensor_outputs(func, inputs, "Python function")
  File "G:\anaconda\envs\pytorch\lib\site-packages\torch\jit\_trace.py", line 483, in run_mod_and_filter_tensor_outputs
    raise TracingCheckError(
torch.jit._trace.TracingCheckError: Tracing failed sanity checks!
encountered an exception while running the Python function with test inputs.
Exception:
    'int' object has no attribute 'clone'  

So I reverted the change.


The first tracing warning still appears, and I don't know how to resolve it.


6 answers

  • 宝码香车 answered 2025-06-02 19:03

    This answer was written with reference to DeepSeek.

    Solution: eliminating the PyTorch TracerWarning and UserWarning

    The warnings come from calling torch.tensor() on a plain Python int inside the model's forward pass. Under PyTorch JIT tracing, such a tensor is recorded into the graph as a constant, which can make the trace incorrect whenever the value differs between calls. The complete fix follows.
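
    To see concretely what the tracer is warning about, here is a minimal standalone sketch (a hypothetical toy function, not code from the asker's model) that reproduces the warning:

    import torch

    def to_square(x):
        HW = x.size(1)  # a plain Python int during tracing
        side = torch.round(
            torch.sqrt(torch.tensor(HW, dtype=torch.float32))).to(torch.int64)
        return x.transpose(1, 2).reshape(x.size(0), x.size(2), side, side)

    # Tracing emits the TracerWarning: torch.tensor(HW) is recorded as the
    # constant 196, so the traced graph hard-codes side == 14 and would fail
    # (or reshape wrongly) for inputs with a different sequence length.
    traced = torch.jit.trace(to_square, torch.rand(1, 196, 8))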

    Root cause analysis

    1. TracerWarning: torch.tensor() results are registered as constants during tracing
    2. UserWarning: sourceTensor.clone().detach() is recommended over torch.tensor(sourceTensor)
    3. Follow-up error: calling .clone().detach() on an int raises AttributeError, since Python ints have no such methods (see the one-liner below)
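
    The follow-up error is easy to reproduce, because dimensions unpacked from .shape are plain Python ints, not tensors:

    import torch

    B, HW, C = torch.rand(1, 196, 8).shape
    print(type(HW))  # <class 'int'>
    HW.clone()       # AttributeError: 'int' object has no attribute 'clone'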

    The fix (3 files to modify)

    Compute directly on the integers with Python math, so no intermediate tensor is ever created:

    # token_transformer.py
    import math  # add this import


    def forward(self, x):
        # ... existing code ...
        B, HW, C = x_conv.shape
        # === fix ===
        H = W = int(round(math.sqrt(HW)))  # plain Python math
        # ===========
        x_conv = x_conv.transpose(1, 2).reshape(B, C, H, W)
        # ... rest of the code ...

    # asf_former.py
    import math  # add this import


    def forward(self, x):
        # ... existing code ...
        B, new_HW, C = x.shape
        # === fix (same change in both places) ===
        H = W = int(round(math.sqrt(new_HW)))  # plain Python math
        # ========================================
        x = x.transpose(1, 2).reshape(B, C, H, W)
        # ... rest of the code ...

    # transformer_block.py
    import math  # add this import


    def forward(self, x):
        # ... existing code ...
        B, HW, C = x_conv.shape
        # === fix ===
        H = W = int(round(math.sqrt(HW)))  # plain Python math
        # ===========
        x_conv = x_conv.transpose(1, 2).reshape(B, C, H, W)
        # ... rest of the code ...
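
    A quick self-contained check (a sketch using a hypothetical helper that mirrors the fixed logic) that tracing no longer records a TracerWarning:

    import math
    import warnings
    import torch

    def fixed(x):
        B, HW, C = x.shape
        H = W = int(round(math.sqrt(HW)))  # plain Python math, no tensor created
        return x.transpose(1, 2).reshape(B, C, H, W)

    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        torch.jit.trace(fixed, torch.rand(1, 196, 8))
    assert not any(issubclass(w.category, torch.jit.TracerWarning) for w in caught)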
    

    Why the change works

    1. math.sqrt() replaces torch.sqrt(): the square root is computed directly on the integer, so no tensor is created
    2. Integer pipeline: take HW as an int → square root → round to nearest → cast to int → use directly in reshape
    3. Effect:
      • the tensor-creation call the tracer complained about is gone
      • plain integer arithmetic is not registered as a traced variable
      • the numerical result is identical to the original tensor version (checked below)
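
    A worked check of that last point, for the typical token count HW = 196:

    import math
    import torch

    HW = 196
    side_math = int(round(math.sqrt(HW)))  # 14
    side_tensor = torch.round(
        torch.sqrt(torch.tensor(HW, dtype=torch.float32))).to(torch.int64).item()  # 14
    assert side_math == side_tensor == 14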

    Sanity checks

    1. Make sure every input size HW is a perfect square (e.g. 196 = 14 × 14)
    2. If a non-square input can occur, guard against it (math.isqrt requires Python 3.8+):
    if math.isqrt(HW) ** 2 != HW:
        raise ValueError(f"HW={HW} is not a perfect square")
    

    Alternative (not recommended)

    If you must keep the tensor operations, you can switch to the following (it may still warn):

    # only valid when HW is already a tensor (not an int)
    sqrt_HW = torch.sqrt(HW.clone().detach().to(torch.float32))
    H = W = torch.round(sqrt_HW).to(torch.int64)


    Note: this variant only works when HW is itself a tensor. In your code HW is an int unpacked from .shape, so it does not apply.
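
    If the goal is to keep H and W genuinely shape-dependent at runtime rather than just silencing the warning, torch.jit.script is the usual alternative to tracing, since scripted code keeps size computations symbolic. A minimal sketch under that assumption (to_square_scripted is a hypothetical helper, not part of the asker's model):

    import math
    import torch

    @torch.jit.script
    def to_square_scripted(x: torch.Tensor) -> torch.Tensor:
        B = x.size(0)
        HW = x.size(1)
        C = x.size(2)
        side = int(math.sqrt(HW) + 0.5)  # round to nearest; stays dynamic when scripted
        return x.transpose(1, 2).reshape(B, C, side, side)

    Note that TensorBoard's add_graph still traces the model internally (as the traceback above shows), so for add_graph the integer-math fix is the right tool; scripting helps when you export with torch.jit.script yourself.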

    This solution removes the tensor-creation calls entirely, which resolves both warnings and the follow-up error while leaving the original logic unchanged. The actual edit is small: add import math and replace the computation at the three sites.

    This answer was accepted by the asker as the best answer.


Question timeline

  • Closed by the system on Jun 17
  • Answer accepted on Jun 9
  • Question created on Jun 2