I'd like to ask how relative positional encoding for a Transformer is implemented. I'm using PyTorch and my data is 1D. Has anyone here implemented this? I gave it a try myself, but the experimental results were not good, so I'd like to know whether my code contains a mistake, or whether there is a working implementation I could try.
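What my code below is trying to implement is, per head, key-side relative-position scoring in the style of Shaw et al. (2018), "Self-Attention with Relative Position Representations", without distance clipping:

    score(i, j) = (q_i · k_j + q_i · a_{i-j}) / sqrt(d_head)

where a_{i-j} is a row of a learned table with 2*Seq-1 entries, and the distance i-j is shifted by Seq-1 so it is always a valid non-negative index.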
import torch
import torch.nn as nn
# Build the (Seq, Seq) matrix of pairwise relative distances;
# the trainable embedding has not been applied at this point
def position_distance(Seq, device=None):
    positional_l = torch.arange(Seq, dtype=torch.long, device=device).view(-1, 1)  # query positions as a column
    positional_r = torch.arange(Seq, dtype=torch.long, device=device).view(1, -1)  # key positions as a row
    distance = positional_l - positional_r  # pairwise differences, in [-(Seq-1), Seq-1]
    distance = distance + Seq - 1  # shift so every index is non-negative, in [0, 2*Seq-2]
    return distance
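For intuition, this is the index matrix the function produces for a short sequence (a quick check, not part of the module; Seq=4 is just an illustrative value):

print(position_distance(4))
# tensor([[3, 2, 1, 0],
#         [4, 3, 2, 1],
#         [5, 4, 3, 2],
#         [6, 5, 4, 3]])
# Each entry indexes one of the 2*Seq-1 = 7 rows of the embedding table;
# entries on the same diagonal share the same relative distance.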
class Multihead_Attention(nn.Module):
    def __init__(self, dim, num_heads, Seq):
        super().__init__()
        # Q, K, V projection matrices
        self.q = nn.Linear(dim, dim, bias=False)
        self.k = nn.Linear(dim, dim, bias=False)
        self.v = nn.Linear(dim, dim, bias=False)
        self.num_heads = num_heads
        self.position_dim = dim // num_heads  # relative positions are computed per head, so dim is divided by num_heads
        # self.Seq_embedding = nn.Embedding(2*Seq-1, self.position_dim)
        self.Seq_embedding = nn.Parameter(torch.zeros(2 * Seq - 1, self.position_dim))  # zero init still trains (gradients flow through q), but a small random init is also common
    def forward(self, x):  # [batch, Seq, dim]
        # ************* multi-head attention *************
        batch_size, Seq, dim = x.shape
        # q, k, v -> [batch, head, seq, dim/head]
        q = self.q(x).reshape(batch_size, Seq, self.num_heads, -1).permute(0, 2, 1, 3)
        k = self.k(x).reshape(batch_size, Seq, self.num_heads, -1).permute(0, 2, 1, 3)
        v = self.v(x).reshape(batch_size, Seq, self.num_heads, -1).permute(0, 2, 1, 3)  # was self.k(x): v must use its own projection
        # look up the relative-position embeddings; build the index matrix on the same device as x
        distance = position_distance(Seq, device=x.device)
        distance = self.Seq_embedding[distance]  # -> [seq, seq, dim/head]
        # distance = self.Seq_embedding(distance)  # nn.Embedding variant -> [seq, seq, dim/head]
        distance = distance.transpose(1, 2)  # -> [seq, dim/head, seq]
        # multiply q with the relative-position embeddings
        q_distance = q.permute(2, 0, 1, 3).reshape(Seq, batch_size * self.num_heads, self.position_dim)  # -> [seq, batch*head, dim/head]
        QmD = (q_distance @ distance).reshape(Seq, batch_size, self.num_heads, Seq).permute(1, 2, 0, 3)  # -> [batch, head, seq, seq]
        # dot product to get the attention scores
        MultiHead_attn = ((q @ k.transpose(2, 3)) + QmD) * (self.position_dim ** -0.5)  # -> [batch, head, seq, seq]
        MultiHead_attn = MultiHead_attn.softmax(dim=-1)
        # weight the values by the attention scores and merge the heads -> [batch, Seq, dim]
        out = (MultiHead_attn @ v).permute(0, 2, 1, 3).reshape(batch_size, Seq, dim)
        return out  # note: a full Transformer block would usually also apply an output projection nn.Linear(dim, dim) here
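For reference, this is how I call it, with arbitrary sizes just for a shape check (batch 4, Seq 16, dim 64, 8 heads are illustrative values):

attn = Multihead_Attention(dim=64, num_heads=8, Seq=16)
x = torch.randn(4, 16, 64)  # [batch, Seq, dim]
out = attn(x)
print(out.shape)  # torch.Size([4, 16, 64])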