multihead_attn = nn.MultiheadAttention(embed_dim, num_heads)
attn_output = multihead_attn(query, key, value)[0]
output: torch.Size([12, 64, 300])
batch_size 为 64,有 12 个词,每个词的向量是 300 维
我想知道这个query、key、value是经过线性变换前的还是已经变换后的。
比如原来的基础embding是x、y。
这个地方应该是query = y,key =x,value = x
还是 query =wy ,key =wx, value = wx ,其中w为训练参数。