weixin_39185140 2022-08-16 17:25 采纳率: 0%
浏览 97

相对位置编码的Pytorch实现,1d数据

我想询问一下,就是说对于Transformer的相对位置编码是怎么实现的,我是使用Pytorch的,然后处理的数据是1d的,想问问有没有实现过的,我也是尝试了一下,但是感觉实验结果不理想,所以想问问我的是否有错,或者有没有成品给我尝试一下。

import torch
import torch.nn as nn

# 获得相对位置矩阵,这时候还没有乘于可训练参数
def position_distance( Seq ):
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    positional_l = torch.arange(Seq, dtype=torch.long, device=device).view(-1, 1)      # 获得长度
    positional_r = torch.arange(Seq, dtype=torch.long, device=device).view(1, -1)      # 获得长度
    distance = positional_l - positional_r      # 相减获得相互距离
    distance = distance + Seq - 1       # 让值都保持为正数
    return distance


class Multihead_Attention(nn.Module):
    def __init__(self, dim, num_heads, Seq):
        super().__init__()

        # Q, K, V 转换矩阵
        self.q = nn.Linear(dim, dim, bias=False)
        self.k = nn.Linear(dim, dim, bias=False)
        self.v = nn.Linear(dim, dim, bias=False)
        self.num_heads = num_heads

        self.position_dim = dim // num_heads     # 因为是计算每个头的相对位置,所以dim要除于head
        # self.Seq_embedding = nn.Embedding(2*Seq-1, self.position_dim)
        self.Seq_embedding = nn.Parameter(torch.zeros((2*Seq-1), self.position_dim))

    def forward(self, x):  # [batch, Seq, dim]

        # *************多头注意力机制*************

        batch_size, Seq, dim = x.shape

        # q k v -> [batch, head, seq, dim]
        q = self.q(x).reshape(batch_size, Seq, self.num_heads, -1).permute(0, 2, 1, 3)
        k = self.k(x).reshape(batch_size, Seq, self.num_heads, -1).permute(0, 2, 1, 3)
        v = self.k(x).reshape(batch_size, Seq, self.num_heads, -1).permute(0, 2, 1, 3)

        # 计算相对位置距离
        distance = position_distance(Seq)
        distance = self.Seq_embedding[distance]     # ->[seq, seq, dim/head]
        # distance = self.Seq_embedding(distance)  # ->[seq, seq, dim/head]
        distance = distance.transpose(1, 2)     # ->[seq, dim/head, seq]

        # 计算q和distance相乘
        q_distance = q.permute(2, 0, 1, 3).reshape(Seq, batch_size*self.num_heads, self.position_dim)   # ->[seq, batch*head, dim]
        QmD = (q_distance @ distance).reshape(Seq, batch_size, self.num_heads, Seq).permute(1, 2, 0, 3)    # ->[dim, head, seq, seq]

        # 点积得到attention score
        MultiHead_attn = ((q@k.transpose(2, 3)) + QmD) * (self.position_dim ** -0.5)      # -> [batch, head, seq, seq]
        MultiHead_attn = MultiHead_attn.softmax(dim=-1)

        # 乘上attention score并输出  -> [batch, dim, Seq]
        MultiHead_attn = (MultiHead_attn @ v).permute(0, 2, 1, 3).reshape(batch_size, Seq, dim)

        return MultiHead_attn



  • 写回答

2条回答 默认 最新

  • kakaccys 2022-08-16 18:00
    关注

    楼主,相对位置实现,你可以参考这个网址:
    https://blog.csdn.net/cyz0202/article/details/124929307

    评论

报告相同问题?

问题事件

  • 创建了问题 8月16日

悬赏问题

  • ¥15 写一个方法checkPerson,入参实体类Person,出参布尔值
  • ¥15 我想咨询一下路面纹理三维点云数据处理的一些问题,上传的坐标文件里是怎么对无序点进行编号的,以及xy坐标在处理的时候是进行整体模型分片处理的吗
  • ¥15 CSAPPattacklab
  • ¥15 一直显示正在等待HID—ISP
  • ¥15 Python turtle 画图
  • ¥15 关于大棚监测的pcb板设计
  • ¥15 stm32开发clion时遇到的编译问题
  • ¥15 lna设计 源简并电感型共源放大器
  • ¥15 如何用Labview在myRIO上做LCD显示?(语言-开发语言)
  • ¥15 Vue3地图和异步函数使用