cscs885 2022-07-20 07:03 采纳率: 100%
浏览 42
已结题

embedding 矩阵是根据什么来生成的呢

在学习Word2Vec的时候

img


会使用到一层embedding 层来使中心词的ont-hot 矩阵降维,但是我想知道 这个embedding layer里面的这个embedding 矩阵是根据什么来生成的呢? 有什么论文或者谁能解释一下原理么?

  • 写回答

1条回答 默认 最新

  • 林地宁宁 2022-07-20 07:17
    关注

    这个问题我以前也困扰过,研究半天发现结果其实特别简单,embedding 层就是一个查找表。这就是说,如果你有 10 个 token,也就是有 10 种 one-hot 编码,那么每一个 one-hot 都对应一个 embedding 结果,给他全部记录下来就好,之后靠着 BP 算法,能自动把这些 embedding 学习到。

    对应到 pytorch 的源码,更是简单,对应源码 https://github.com/pytorch/pytorch/blob/5b03ff0a09d43d721067e39da10aa23edc6997cd/aten/src/ATen/native/Embedding.cpp#L14-L29 中 14~29 行,你会发现他就一个 index_select 函数,说明 embedding 里面的矩阵就是一个查找表,根本连乘法运算都没有:

    Tensor embedding(const Tensor & weight, const Tensor & indices,
                     int64_t padding_idx, bool scale_grad_by_freq, bool sparse) {
      auto indices_arg = TensorArg(indices, "indices", 1);
      checkScalarType("embedding", indices_arg, kLong);
    
      // TODO: use tensor.index() after improving perf
      if (indices.dim() == 1) {
        return weight.index_select(0, indices);
      }
    
      auto size = indices.sizes().vec();
      for (auto d : weight.sizes().slice(1)) {
        size.push_back(d);
      }
      return weight.index_select(0, indices.reshape(-1)).view(size);
    }
    
    本回答被题主选为最佳回答 , 对您是否有帮助呢?
    评论
    cscs885 2022-07-29 08:02

    感谢! 那这样我还有一个问题,就是这个“embedding层的查找表”又是如何初始化生成的呢? 因为最开始总得有一个初始化的表才能用BP算法更新这个表吧

    回复
    林地宁宁 回复 cscs885 2022-07-29 08:42

    这个可以看一看 embedding 层的 python 包装,其中的 .weight 就是查找表,事实上就是通过一个中心为 0,方差为 1 的正态分布随机取样的。还望采纳答案。

    1
    回复
    林地宁宁 回复 cscs885 2022-07-29 08:45

    可以看 embedding 的实现,我怀疑初始化的过程不同框架可能不一样,在 pytorch 中就是简单的用中心为 0,方差为 1 的正态分布随机采样得到的初始参数。答案还望采纳。

    1
    回复
编辑
预览

报告相同问题?

问题事件

  • 系统已结题 8月5日
  • 已采纳回答 7月29日
  • 创建了问题 7月20日

悬赏问题

  • ¥15 c++二叉树三种遍历问题
  • ¥15 下面三个文件分别是OFDM波形的数据,我的思路公式和我写的成像算法代码,有没有人能帮我改一改,如何解决?
  • ¥15 Ubuntu打开gazebo模型调不出来,如何解决?
  • ¥100 有chang请一位会arm和dsp的朋友解读一个工程
  • ¥50 求代做一个阿里云百炼的小实验
  • ¥20 DNS服务器所在的国家不同与你的IP地址所在国家
  • ¥15 查询优化:A表100000行,B表2000 行,内存页大小只有20页,运行时3页,设计两个表等值连接的最简单的算法
  • ¥15 led数码显示控制(标签-流程图)
  • ¥20 为什么在复位后出现错误帧
  • ¥15 结果有了,想问一下这个具体怎么输入
手机看
程序员都在用的中文IT技术交流社区

程序员都在用的中文IT技术交流社区

专业的中文 IT 技术社区,与千万技术人共成长

专业的中文 IT 技术社区,与千万技术人共成长

关注【CSDN】视频号,行业资讯、技术分享精彩不断,直播好礼送不停!

关注【CSDN】视频号,行业资讯、技术分享精彩不断,直播好礼送不停!

客服 返回
顶部