cscs885 2022-07-20 15:03 采纳率: 100%
浏览 37
已结题

embedding 矩阵是根据什么来生成的呢

在学习Word2Vec的时候

img


会使用到一层embedding 层来使中心词的ont-hot 矩阵降维,但是我想知道 这个embedding layer里面的这个embedding 矩阵是根据什么来生成的呢? 有什么论文或者谁能解释一下原理么?

  • 写回答

1条回答 默认 最新

  • 林地宁宁 2022-07-20 15:17
    关注

    这个问题我以前也困扰过,研究半天发现结果其实特别简单,embedding 层就是一个查找表。这就是说,如果你有 10 个 token,也就是有 10 种 one-hot 编码,那么每一个 one-hot 都对应一个 embedding 结果,给他全部记录下来就好,之后靠着 BP 算法,能自动把这些 embedding 学习到。

    对应到 pytorch 的源码,更是简单,对应源码 https://github.com/pytorch/pytorch/blob/5b03ff0a09d43d721067e39da10aa23edc6997cd/aten/src/ATen/native/Embedding.cpp#L14-L29 中 14~29 行,你会发现他就一个 index_select 函数,说明 embedding 里面的矩阵就是一个查找表,根本连乘法运算都没有:

    Tensor embedding(const Tensor & weight, const Tensor & indices,
                     int64_t padding_idx, bool scale_grad_by_freq, bool sparse) {
      auto indices_arg = TensorArg(indices, "indices", 1);
      checkScalarType("embedding", indices_arg, kLong);
    
      // TODO: use tensor.index() after improving perf
      if (indices.dim() == 1) {
        return weight.index_select(0, indices);
      }
    
      auto size = indices.sizes().vec();
      for (auto d : weight.sizes().slice(1)) {
        size.push_back(d);
      }
      return weight.index_select(0, indices.reshape(-1)).view(size);
    }
    
    本回答被题主选为最佳回答 , 对您是否有帮助呢?
    评论

报告相同问题?

问题事件

  • 系统已结题 8月6日
  • 已采纳回答 7月29日
  • 创建了问题 7月20日

悬赏问题

  • ¥20 机器学习能否像多层线性模型一样处理嵌套数据
  • ¥20 西门子S7-Graph,S7-300,梯形图
  • ¥50 用易语言http 访问不了网页
  • ¥50 safari浏览器fetch提交数据后数据丢失问题
  • ¥15 matlab不知道怎么改,求解答!!
  • ¥15 永磁直线电机的电流环pi调不出来
  • ¥15 用stata实现聚类的代码
  • ¥15 请问paddlehub能支持移动端开发吗?在Android studio上该如何部署?
  • ¥20 docker里部署springboot项目,访问不到扬声器
  • ¥15 netty整合springboot之后自动重连失效