Allen_Smath  2024-11-28 12:12
Resolved

Dimension error in a neural network caused by module nesting

I am trying to fuse an attention mechanism with convolutional modules and ran into something rather strange. The key code is as follows:

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from einops import rearrange


class Feature_Extractor_Enc(nn.Module):
    def __init__(self, emb_size, num_channel):
        super().__init__()
        self.num_channel = num_channel
        self.emb_size = emb_size

        # Convolution and projection layers simplified here; their actual
        # constructor arguments are omitted from the post.
        self.temporal_spatial = nn.Conv2d
        self.spatial_temporal = nn.Conv2d
        self.projection = XXX

        self.attention = TransformerBlock(emb_size=40)

    def forward(self, X):
        x, y = X[0], X[1]
        x = self.temporal_spatial(x) + self.spatial_temporal(x)
        x = self.projection(x)        # expected to produce a (b, n, e) tensor
        x = self.attention(x)

        y = self.temporal_spatial(y) + self.spatial_temporal(y)
        y = self.projection(y)
        y = self.attention(y)
        return (x, y)


class TransformerBlock(nn.Sequential):
    def __init__(self,
                 emb_size,
                 num_heads=10,
                 drop_p=0.5):
        super().__init__(MultiHeadAttention(emb_size, num_heads, drop_p))


class MultiHeadAttention(nn.Module):
    def __init__(self, emb_size, num_heads, dropout):
        super().__init__()
        self.emb_size = emb_size
        self.num_heads = num_heads

        self.keys = nn.Linear(emb_size, emb_size)
        self.queries = nn.Linear(emb_size, emb_size)
        self.values = nn.Linear(emb_size, emb_size)

        self.att_drop = nn.Dropout(dropout)
        self.projection = nn.Linear(emb_size, emb_size)

    def forward(self, x: Tensor, mask: Tensor = None) -> Tensor:
        # Split the embedding into heads: (b, n, e) -> (b, h, n, d) with e = h * d
        queries = rearrange(self.queries(x), 'b n (h d) -> b h n d', h=self.num_heads)
        keys = rearrange(self.keys(x), 'b n (h d) -> b h n d', h=self.num_heads)
        values = rearrange(self.values(x), 'b n (h d) -> b h n d', h=self.num_heads)

        # Scaled dot-product attention scores, shape (b, h, n_query, n_key)
        energy = torch.einsum('bhqd, bhkd -> bhqk', queries, keys)
        scaling = self.emb_size ** (1 / 2)
        att = F.softmax(energy / scaling, dim=-1)
        att = self.att_drop(att)

        # Weighted sum over the values, then merge the heads back to (b, n, e)
        out = torch.einsum('bhal, bhlv -> bhav', att, values)
        out = rearrange(out, 'b h n d -> b n (h d)')
        out = self.projection(out)
        return out
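
In isolation, the MultiHeadAttention above does accept a three-dimensional (b, n, e) input. A minimal sanity check, with made-up batch size and sequence length (emb_size must be divisible by num_heads):

mha = MultiHeadAttention(emb_size=40, num_heads=10, dropout=0.5)
x = torch.randn(2, 61, 40)   # (batch, tokens, embedding) -- illustrative shapes
print(mha(x).shape)          # torch.Size([2, 61, 40])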

The error says that the input to MultiHeadAttention is (n, e)-dimensional while a (b, n, e)-dimensional tensor is expected. But when I print the input to x = self.attention(x), it is (b, n, e). What is going on here?
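
One hypothetical way to narrow this down (the print statements below are not in the original code) is to log both the shape and the concrete class of the attention module right before the call inside Feature_Extractor_Enc.forward:

print('input to attention:', x.shape)             # should show (b, n, e)
print('attention module:', type(self.attention))  # which TransformerBlock is this really?
x = self.attention(x)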

1 answer

  • Allen_Smath 2024-11-28 16:29

    Case closed: further down in the file I had defined another TransformerBlock with exactly the same name. The names collided, and the call was routed to that other definition instead.
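
    In Python a class statement is just an assignment to a name, so a second class named TransformerBlock defined later in the same file rebinds that name; the lookup inside Feature_Extractor_Enc.__init__ happens at instantiation time and therefore picks up the later definition. A minimal sketch of the effect (the class bodies are made up purely to show the shadowing):

    import torch
    import torch.nn as nn

    class TransformerBlock(nn.Module):
        # the version you meant to use
        def forward(self, x):
            return x + 1

    class TransformerBlock(nn.Module):
        # a later definition with the same name shadows the first one
        def forward(self, x):
            return x - 1

    block = TransformerBlock()        # resolves to the second definition
    print(block(torch.zeros(1)))      # tensor([-1.]), not tensor([1.])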

    This answer was accepted by the asker as the best answer.

Question timeline

  • Question closed by the system on December 6
  • Answer accepted on November 28
  • Question created on November 28