logistic_ 2023-05-11 11:01 采纳率: 0%
浏览 74

用c3str替换v8中的c2f模块,报错ZeroDivisionError: integer division or modulo by zero,求改正!

最近想将yolov8中backbone部分的c2f模块替换成c3str模块(输入输出相同,直接替换应该没问题的),需要加入WindowAttention(nn.Module),代码如下:

class WindowAttention(nn.Module):

    def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.):

        super().__init__()
        self.dim = dim
        self.window_size = window_size  # Wh, Ww
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = qk_scale or head_dim ** -0.5

        # define a parameter table of relative position bias
        self.relative_position_bias_table = nn.Parameter(
            torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads))  # 2*Wh-1 * 2*Ww-1, nH

        # get pair-wise relative position index for each token inside the window
        coords_h = torch.arange(self.window_size[0])
        coords_w = torch.arange(self.window_size[1])
        coords = torch.stack(torch.meshgrid([coords_h, coords_w]))  # 2, Wh, Ww
        coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww
        relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, Wh*Ww, Wh*Ww
        relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2
        relative_coords[:, :, 0] += self.window_size[0] - 1  # shift to start from 0
        relative_coords[:, :, 1] += self.window_size[1] - 1
        relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
        relative_position_index = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww
        self.register_buffer("relative_position_index", relative_position_index)

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

        nn.init.normal_(self.relative_position_bias_table, std=.02)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x, mask=None):

        B_, N, C = x.shape
        qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]  # make torchscript happy (cannot use tensor as tuple)

        q = q * self.scale
        attn = (q @ k.transpose(-2, -1))

        relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
            self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1)  # Wh*Ww,Wh*Ww,nH
        relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()  # nH, Wh*Ww, Wh*Ww
        attn = attn + relative_position_bias.unsqueeze(0)

        if mask is not None:
            nW = mask.shape[0]
            attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
            attn = attn.view(-1, self.num_heads, N, N)
            attn = self.softmax(attn)
        else:
            attn = self.softmax(attn)

        attn = self.attn_drop(attn)

        # print(attn.dtype, v.dtype)
        try:
            x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
        except:
            #print(attn.dtype, v.dtype)
            x = (attn.half() @ v).transpose(1, 2).reshape(B_, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x

但是运行时报错:ZeroDivisionError: integer division or modulo by zero,

img

虽然知道问题是因为这里num_heads参数为0,但不知道为啥为0,这里应该不用修改才对,不知道有没有了解的朋友能看一下

  • 写回答

2条回答 默认 最新

  • 爱晚乏客游 2023-05-11 12:57
    关注

    c3str是你自己实现的?我没有找到这个模块。
    而你说的输入输出相同,是有打印输出看他们对应的shape一样吗?还是代码参数一样就认为输入输出相同?
    如果你换回c2f模块这里会报错吗?

    评论 编辑记录

报告相同问题?

问题事件

  • 创建了问题 5月11日

悬赏问题

  • ¥15 请各位帮我看看是哪里出了问题
  • ¥15 vs2019的js智能提示
  • ¥15 关于#开发语言#的问题:FDTD建模问题图中代码没有报错,但是模型却变透明了
  • ¥15 uniapp的h5项目写一个抽奖动画
  • ¥15 TeleScan不能修改bar
  • ¥100 请问我基于逐飞库写的这个有关于mp u6050传感器的函数,为什么输出的值是固定的?
  • ¥15 hadoop中启动hive报错如下怎么解决
  • ¥15 如何优化QWebEngineView 加载url的速度
  • ¥15 关于#hadoop#的问题,请各位专家解答!
  • ¥15 如何批量抓取网站信息