研研研冲冲冲 2023-04-16 16:42 采纳率: 59.1%
浏览 12

video swin transformer的全连接层在哪个代码里面

video swin transformer(mmaction)的全连接层在哪个代码里面

  • 写回答

1条回答 默认 最新

  • CSDN-Ada助手 CSDN-AI 官方账号 2023-04-16 22:49
    关注
    • 这篇博客: 【代码解析】mmaction2: Video Swin Transformer中的 1.2 解析 部分也许能够解决你的问题, 你可以仔细阅读以下内容或跳转源博客中阅读:
      • SwinTransformer3D

        • patch_embed: PatchEmbed3D
          将输入三维信号切分成多个3d-patch,patch_size默认(2,4,4),对每个patch使用3d-conv进行特征提取并降采样
          • padding:对无法被patch_size整除维度进行填零padding
          • self.proj = conv3d(3, 96, kernel_size = patch_size, stride=patch_size):对输入特征进行三维卷积,即对每个patch_size大小窗口的输入进行特征提取,每个patch_size输出一个96维特征
          • norm(optional): fllatten + transpose + layer_norm(对channel维度进行norm,即对每个patch_size的96维特征进行归一化)+transpose
      • pos_drop: nn.Drop

      • self.layers : depths [2, 2, 6, 2] 多个BasicLayer进行串联

        • BasicLayer 进一步对上层输出信号切分成多个3d-window,window_size默认(8,7,7),对patch和patch之间的特征关联进行信息提取
          • get_window_size((D,H,W), window_size=(8,7,7), shift_size=(4,3,3))
          • rearrange(x, 'b c d h w -> b d h w c')
          • self.attn_mask = compute_mask(Dp, Hp, Wp, window_size, shift_size, x.device) 根据输入尺度和window_size生成transformer中的mask,对非自身window的特征关联信息进行抑制
            在这里插入图片描述
        • nn.ModuleList(SwinTransformerBlock3D(for i in range(depth)])多个SwinTransformerBlock3D进行串联 (B,D,H,W,C)
          在这里插入图片描述
          • nn.LayerNorm
          • F.pad
          • torch.roll(optional)
          • x_windows = window_partition: shape (B*nW, Wd*Wh*Ww, C) window切分
          • attn_windows = self.attn(x_windows, mask=attn_mask): WindowAttention3D 对window内部进行self-attention特征提取, shape (B*nW, Wd*Wh*Ww, C)
            • nn.Linear(dim, dim * 3, bias=qkv_bias) 将输入升维三倍
            • q, k, v = qkv[0], qkv[1], qkv[2] 提取K,Q,V特征
            1. q * self.scale = head_dim ** -0.5根据head_num进行缩放,防止multi-head大小对信号量影响过大
            2. attn = q @ k.transpose(-2, -1) 内积
            • attn + relative_position_bias: relative_position_bias_table 加入位置编码(防止特征顺序对transformer模块失效,不参与学习)
            • attn.view(B_ // nW, nW, self.num_heads, N, N) + mask 加入关联特征激活/抑制mask,这里mask就是之前提取的self.attn_mask
            • self.softmax(attn) + self.attn_drop(attn) Transformer标准模块
            • x = (attn @ v) Transformer标准模块
            • self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) Transformer标准模块
            • x = shortcut + self.drop_path(x) FFN模块
      • downsample: PatchMerging 对输出特征进行重排,H和W变为1/2(不对D进行降采样),channel会变成4倍在这里插入图片描述

        • 对H和W进行间隔采样
        • norm: nn.LayerNorm
        • nn.Linear(4 * dim, 2 * dim) channel降维
      • rearrange(x, 'b d h w c -> b c d h w')

      • rearrange + norm + rearrange

      Swin-trans参数膨胀
      inflate_weights

      • patch_embed 中的conv3d选择直接膨胀初始化conv2d
      • relative_position_bias_table 两种:膨胀初始化、中心初始化
    评论

报告相同问题?

问题事件

  • 创建了问题 4月16日

悬赏问题

  • ¥15 有偿求苍穹外卖环境配置
  • ¥15 代码在keil5里变成了这样怎么办啊,文件图像也变了,
  • ¥20 Ue4.26打包win64bit报错,如何解决?(语言-c++)
  • ¥15 clousx6整点报时指令怎么写
  • ¥30 远程帮我安装软件及库文件
  • ¥15 关于#自动化#的问题:如何通过电脑控制多相机同步拍照或摄影(相机或者摄影模组数量大于60),并将所有采集的照片或视频以一定编码规则存放至规定电脑文件夹内
  • ¥20 深信服vpn-2050这台设备如何配置才能成功联网?
  • ¥15 Arduino的wifi连接,如何关闭低功耗模式?
  • ¥15 Android studio 无法定位adb是什么问题?
  • ¥15 C#连接不上服务器,