Neural network dimension error
While trying to fuse an attention mechanism with convolution modules, I ran into something rather strange. The key code is as follows:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from einops import rearrange

class Feature_Extractor_Enc(nn.Module):
    def __init__(self, emb_size, num_channel):
        super().__init__()
        self.num_channel = num_channel
        self.emb_size = emb_size
        self.temporal_spatial = nn.Conv2d   # conv arguments omitted in this excerpt
        self.spatial_temporal = nn.Conv2d   # conv arguments omitted in this excerpt
        self.projection = XXX               # placeholder, implementation omitted in this excerpt
        self.attention = TransformerBlock(emb_size=40)

    def forward(self, X):
        # X is a pair of tensors; both go through the same conv + attention path
        x, y = X[0], X[1]
        x = self.temporal_spatial(x) + self.spatial_temporal(x)
        x = self.projection(x)
        x = self.attention(x)
        y = self.temporal_spatial(y) + self.spatial_temporal(y)
        y = self.projection(y)
        y = self.attention(y)
        return (x, y)
class TransformerBlock(nn.Sequential):
    def __init__(self, emb_size, num_heads=10, drop_p=0.5):
        super().__init__(MultiHeadAttention(emb_size, num_heads, drop_p))
class MultiHeadAttention(nn.Module):
    def __init__(self, emb_size, num_heads, dropout):
        super().__init__()
        self.emb_size = emb_size
        self.num_heads = num_heads
        self.keys = nn.Linear(emb_size, emb_size)
        self.queries = nn.Linear(emb_size, emb_size)
        self.values = nn.Linear(emb_size, emb_size)
        self.att_drop = nn.Dropout(dropout)
        self.projection = nn.Linear(emb_size, emb_size)

    def forward(self, x: Tensor, mask: Tensor = None) -> Tensor:
        # split the embedding dim across heads: (b, n, e) -> (b, h, n, d)
        queries = rearrange(self.queries(x), "b n (h d) -> b h n d", h=self.num_heads)
        keys = rearrange(self.keys(x), "b n (h d) -> b h n d", h=self.num_heads)
        values = rearrange(self.values(x), "b n (h d) -> b h n d", h=self.num_heads)
        # scaled dot-product attention over the sequence axis
        energy = torch.einsum("bhqd, bhkd -> bhqk", queries, keys)
        scaling = self.emb_size ** (1 / 2)
        att = F.softmax(energy / scaling, dim=-1)
        att = self.att_drop(att)
        out = torch.einsum("bhal, bhlv -> bhav", att, values)
        # merge the heads back: (b, h, n, d) -> (b, n, e)
        out = rearrange(out, "b h n d -> b n (h d)")
        out = self.projection(out)
        return out
The error message says the input to MultiHeadAttention is (n, e)-shaped while (b, n, e) was expected. But when I print the input at x = self.attention(x), it is (b, n, e). What is going on here?
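One way to narrow this down (a debugging sketch only; model here stands for however the full network is assembled, which I haven't shown) is to hang a forward pre-hook on every MultiHeadAttention instance, so the shape is printed at the exact call that fails rather than at the one place I added a print:

def report_shape(module, inputs):
    # inputs is the tuple of positional args passed to forward()
    print(module.__class__.__name__, "got", tuple(inputs[0].shape))

for m in model.modules():
    if isinstance(m, MultiHeadAttention):
        m.register_forward_pre_hook(report_shape)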