snapdragon623 2023-05-07 18:25 采纳率: 33.3%
浏览 28
已结题

pytorch张量的索引的机制是什么?

看informer源代码时,attn.py里的ProbAttention有这么一段:

    def _prob_QK(self, Q, K, sample_k, n_top): # n_top: c*ln(L_q)
        # Q [B, H, L, D]
        B, H, L_K, E = K.shape
        _, _, L_Q, _ = Q.shape

        # calculate the sampled Q_K
        K_expand = K.unsqueeze(-3).expand(B, H, L_Q, L_K, E)
        index_sample = torch.randint(L_K, (L_Q, sample_k)) # real U = U_part(factor*ln(L_k))*L_q
        K_sample = K_expand[:, :, torch.arange(L_Q).unsqueeze(1), index_sample, :]
        Q_K_sample = torch.matmul(Q.unsqueeze(-2), K_sample.transpose(-2, -1)).squeeze()

其中对K_expand进行切片索引的操作,即K_sample = K_expand[:, :, torch.arange(L_Q).unsqueeze(1), index_sample, :]这一句,不明白为什么要对torch.arange(L_Q)进行维度扩展的操作。通过实验发现
K_expand[:, :, torch.arange(L_Q).unsqueeze(1), index_sample, :]的维度是(B,H,L_Q,sample_k,E)
K_expand[:, :, :, index_sample, :]的维度是(B,H,L_Q,L_Q,sample_k,E)
K_expand[:,:,torch.arange(L_Q),index_sample,:]会报错,出现indexing tensors could not be broadcast together with shapes [L_Q], [L_Q, sample_k]的问题
推测torch的张量索引是通过广播的机制来完成的,但是具体的实现机制依旧不清楚,十分困惑,希望有懂的人能解答!

  • 写回答

1条回答 默认 最新

  • 爱晚乏客游 2023-05-08 12:38
    关注

    报错是下面的torch.matmul这里乘法对应的维度不匹配报错无法广播(mxn的矩阵要和nxk的矩阵才能做乘法),而不是切片这里报错

    img

    至于你说的这里是用的是列表切片法,具体切片效果可以看这篇文章。总之这里的切片就是要保证下面的矩阵乘法维度数对应上。

    https://blog.csdn.net/qq_43923588/article/details/107974187

    本回答被题主选为最佳回答 , 对您是否有帮助呢?
    评论 编辑记录

报告相同问题?

问题事件

  • 系统已结题 5月16日
  • 已采纳回答 5月8日
  • 创建了问题 5月7日