谦虚且进步 2023-03-18 14:37 采纳率: 90%
浏览 17
已结题

关于#人工智能#的问题:我现在知道dense attention是一种简单的注意力机制打分函数,请问还有没有其他打分函数的得思路及代码实现

我现在知道dense attention是一种简单的注意力机制打分函数,请问还有没有其他打分函数的得思路及代码实现

  • 写回答

2条回答 默认 最新

  • 追cium 2023-03-18 14:43
    关注

    参考GPT和自己的思路:

    当涉及到注意力机制打分函数时,有许多不同的方法可以实现。以下是一些流行的打分函数及其代码实现:

    1. 单层感知器:
    class MLPAttention(torch.nn.Module):
        def __init__(self, input_size, fc_size):
            super(MLPAttention, self).__init__()
            self.fc = torch.nn.Linear(input_size, fc_size)
            self.attn = torch.nn.Linear(fc_size, 1)
    
        def forward(self, encoder_states):
            energy = torch.tanh(self.fc(encoder_states))
            attention_scores = self.attn(energy)
            attention_weights = torch.softmax(attention_scores, dim=1)
            context_vector = attention_weights * encoder_states
            return context_vector, attention_weights
    
    1. 缩放点积注意力:
    class ScaledDotProductAttention(torch.nn.Module):
        def __init__(self):
            super(ScaledDotProductAttention, self).__init__()
    
        def forward(self, query, key, value, mask=None):
            scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(query.size(-1))
            if mask is not None:
                scores = scores.masked_fill(mask == 0, -1e9)
            softmax_scores = torch.softmax(scores, dim=-1)
            output = torch.matmul(softmax_scores, value)
            return output, softmax_scores
    
    1. 双向注意力:
    class BidirectionalAttention(torch.nn.Module):
        def __init__(self, input_size, att_size):
            super(BidirectionalAttention, self).__init__()
            self.W_s = torch.nn.Linear(input_size, att_size)
            self.W_t = torch.nn.Linear(input_size, att_size)
            self.W_a = torch.nn.Linear(att_size, 1)
    
        def forward(self, s, t):
            s_proj = self.W_s(s)
            t_proj = self.W_t(t)
            scores = self.W_a(torch.tanh(s_proj + t_proj.unsqueeze(1))).squeeze(2)
            alpha = torch.softmax(scores, dim=-1)
            t_attend = torch.bmm(alpha.unsqueeze(1), t).squeeze(1)
            s_attend = torch.bmm(alpha.unsqueeze(2), s.unsqueeze(1)).squeeze(2)
            return s_attend, t_attend
    

    这只是一些打分函数的例子,当然还有很多其他的打分函数。选择合适的打分函数取决于问题的具体情况。

    本回答被题主选为最佳回答 , 对您是否有帮助呢?
    评论
查看更多回答(1条)

报告相同问题?

问题事件

  • 系统已结题 5月23日
  • 已采纳回答 5月15日
  • 创建了问题 3月18日

悬赏问题

  • ¥15 无法输出helloworld
  • ¥15 高通uboot 打印ubi init err 22
  • ¥20 PDF元数据中的XMP媒体管理属性
  • ¥15 R语言中lasso回归报错
  • ¥15 网站突然不能访问了,上午还好好的
  • ¥15 有没有dl可以帮弄”我去图书馆”秒选道具和积分
  • ¥15 semrush,SEO,内嵌网站,api
  • ¥15 Stata:为什么reghdfe后的因变量没有被发现识别啊
  • ¥15 振荡电路,ADS仿真
  • ¥15 关于#c语言#的问题,请各位专家解答!