While reading the Informer source code, I came across this snippet in the ProbAttention class in attn.py:
def _prob_QK(self, Q, K, sample_k, n_top): # n_top: c*ln(L_q)
    # Q [B, H, L, D]
    B, H, L_K, E = K.shape
    _, _, L_Q, _ = Q.shape
    # calculate the sampled Q_K
    K_expand = K.unsqueeze(-3).expand(B, H, L_Q, L_K, E)
    index_sample = torch.randint(L_K, (L_Q, sample_k)) # real U = U_part(factor*ln(L_k))*L_q
    K_sample = K_expand[:, :, torch.arange(L_Q).unsqueeze(1), index_sample, :]
    Q_K_sample = torch.matmul(Q.unsqueeze(-2), K_sample.transpose(-2, -1)).squeeze()
What I don't understand is the sliced indexing of K_expand, i.e. the line K_sample = K_expand[:, :, torch.arange(L_Q).unsqueeze(1), index_sample, :]. Why does torch.arange(L_Q) need the extra dimension added by unsqueeze(1)? Experimenting, I found the following:
K_expand[:, :, torch.arange(L_Q).unsqueeze(1), index_sample, :] has shape (B, H, L_Q, sample_k, E)
K_expand[:, :, :, index_sample, :] has shape (B, H, L_Q, L_Q, sample_k, E)
K_expand[:, :, torch.arange(L_Q), index_sample, :] raises an error: indexing tensors could not be broadcast together with shapes [L_Q], [L_Q, sample_k]
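
For anyone who wants to poke at this, here is a minimal self-contained script that reproduces the three observations above; the concrete dimensions (B=2, H=3, L_Q=4, L_K=5, sample_k=2, E=6) are just arbitrary values I picked for the experiment:

import torch

# arbitrary small dimensions for the experiment
B, H, L_Q, L_K, sample_k, E = 2, 3, 4, 5, 2, 6

K = torch.randn(B, H, L_K, E)
K_expand = K.unsqueeze(-3).expand(B, H, L_Q, L_K, E)
index_sample = torch.randint(L_K, (L_Q, sample_k))

# case 1: index tensors of shapes (L_Q, 1) and (L_Q, sample_k)
out1 = K_expand[:, :, torch.arange(L_Q).unsqueeze(1), index_sample, :]
print(out1.shape)  # torch.Size([2, 3, 4, 2, 6]) = (B, H, L_Q, sample_k, E)

# case 2: dim 2 kept as a full slice, single index tensor at dim 3
out2 = K_expand[:, :, :, index_sample, :]
print(out2.shape)  # torch.Size([2, 3, 4, 4, 2, 6]) = (B, H, L_Q, L_Q, sample_k, E)

# case 3: index tensors of shapes (L_Q,) and (L_Q, sample_k) raise an
# IndexError, since (4,) and (4, 2) cannot be broadcast together
try:
    K_expand[:, :, torch.arange(L_Q), index_sample, :]
except IndexError as e:
    print(e)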
My guess is that PyTorch's tensor indexing is done through broadcasting, but the exact mechanism is still unclear to me. I'm quite confused and would really appreciate an explanation from anyone who understands this!
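
For what it's worth, here is a sketch of my current (unverified) reading: the two index tensors seem to be broadcast to a common shape (L_Q, sample_k), with entry (i, j) selecting K_expand[b, h, i, index_sample[i, j], :], so each query position i gets its own sample_k randomly chosen keys. The explicit loop below checks whether that interpretation matches what the fancy indexing actually returns; if my reading is wrong, the final comparison should print False:

import torch

B, H, L_Q, L_K, sample_k, E = 2, 3, 4, 5, 2, 6
K = torch.randn(B, H, L_K, E)
K_expand = K.unsqueeze(-3).expand(B, H, L_Q, L_K, E)
index_sample = torch.randint(L_K, (L_Q, sample_k))

K_sample = K_expand[:, :, torch.arange(L_Q).unsqueeze(1), index_sample, :]

# explicit loop implementing the suspected semantics: the (L_Q, 1) and
# (L_Q, sample_k) index tensors broadcast to (L_Q, sample_k), and entry
# (i, j) picks key row index_sample[i, j] for query position i
expected = torch.empty(B, H, L_Q, sample_k, E)
for i in range(L_Q):
    for j in range(sample_k):
        expected[:, :, i, j, :] = K_expand[:, :, i, index_sample[i, j], :]

print(torch.equal(K_sample, expected))  # True if the interpretation holds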
