石大攻城狮 2023-02-18 12:48 采纳率: 100%
浏览 4
已结题

关于#人工智能#的问题:这个地方应该是query = y,value = x还是 query =wy ,key =wx, value = wx ,其中w为训练参数

multihead_attn = nn.MultiheadAttention(embed_dim, num_heads)
attn_output = multihead_attn(query, key, value)[0]

output: torch.Size([12, 64, 300])

batch_size 为 64,有 12 个词,每个词的向量是 300 维

我想知道这个query、key、value是经过线性变换前的还是已经变换后的。

比如原来的基础embding是x、y。

这个地方应该是query = y,key =x,value = x
还是 query =wy ,key =wx, value = wx ,其中w为训练参数。

  • 写回答

1条回答 默认 最新

  • CodeBytes 2023-02-18 13:17
    关注

    该回答引用ChatGPT

    根据 nn.MultiheadAttention 的文档,query, key, value 输入的形状应该是 (seq_len, batch_size, embed_dim),即每个时间步的输入向量形状是 (batch_size, embed_dim)。这里的 query, key, value 应该是经过线性变换后的向量。

    在 nn.MultiheadAttention 的初始化中,有三个线性层,分别对应 query, key, value 的线性变换,可以使用 nn.Linear 模块来实现,其中 in_features 表示输入向量的维度,out_features 表示输出向量的维度。比如:

    query_linear = nn.Linear(embed_dim, embed_dim)
    key_linear = nn.Linear(embed_dim, embed_dim)
    value_linear = nn.Linear(embed_dim, embed_dim)
    
    query = query_linear(y)
    key = key_linear(x)
    value = value_linear(x)
    
    

    这样,query, key, value 就是经过线性变换后的向量。

    本回答被题主选为最佳回答 , 对您是否有帮助呢?
    评论

报告相同问题?

问题事件

  • 系统已结题 2月26日
  • 已采纳回答 2月18日
  • 创建了问题 2月18日

悬赏问题

  • ¥17 pro*C预编译“闪回查询”报错SCN不能识别
  • ¥15 微信会员卡接入微信支付商户号收款
  • ¥15 如何获取烟草零售终端数据
  • ¥15 数学建模招标中位数问题
  • ¥15 phython路径名过长报错 不知道什么问题
  • ¥15 深度学习中模型转换该怎么实现
  • ¥15 HLs设计手写数字识别程序编译通不过
  • ¥15 Stata外部命令安装问题求帮助!
  • ¥15 从键盘随机输入A-H中的一串字符串,用七段数码管方法进行绘制。提交代码及运行截图。
  • ¥15 TYPCE母转母,插入认方向