m0_49298253 2023-05-26 20:55 采纳率: 0%
浏览 15

请问下大家可变形注意力怎么可视化热图

N, Len_q, _ = query.shape
N, Len_in, _ = input_flatten.shape
assert (input_spatial_shapes[:, 0] * input_spatial_shapes[:, 1]).sum() == Len_in

    value = self.value_proj(input_flatten)
    if input_padding_mask is not None:
        value = value.masked_fill(input_padding_mask[..., None], float(0))
    value = value.view(N, Len_in, self.n_heads, self.d_model // self.n_heads)
    sampling_offsets = self.sampling_offsets(query).view(N, Len_q, self.n_heads, self.n_levels, self.n_points, 2)
    attention_weights = self.attention_weights(query).view(N, Len_q, self.n_heads, self.n_levels * self.n_points)
    attention_weights = F.softmax(attention_weights, -1).view(N, Len_q, self.n_heads, self.n_levels, self.n_points)
    # N, Len_q, n_heads, n_levels, n_points, 2
    if reference_points.shape[-1] == 2:
        offset_normalizer = torch.stack([input_spatial_shapes[..., 1], input_spatial_shapes[..., 0]], -1)
        sampling_locations = reference_points[:, :, None, :, None, :] \
                             + sampling_offsets / offset_normalizer[None, None, None, :, None, :]
    elif reference_points.shape[-1] == 4:
        sampling_locations = reference_points[:, :, None, :, None, :2] \
                             + sampling_offsets / self.n_points * reference_points[:, :, None, :, None, 2:] * 0.5
    else:
        raise ValueError(
            'Last dim of reference_points must be 2 or 4, but get {} instead.'.format(reference_points.shape[-1]))
    output = MSDeformAttnFunction.apply(
        value, input_spatial_shapes, input_level_start_index, sampling_locations, attention_weights, self.im2col_step)
    output = self.output_proj(output)
    return output

如上是可变形注意力的前传函数,请问下怎么可视化attention_weights形成如下效果,这个和detr好像不同

img

  • 写回答

1条回答 默认 最新

      报告相同问题?

      相关推荐 更多相似问题

      问题事件

      • 修改了问题 5月26日
      • 创建了问题 5月26日

      悬赏问题

      • ¥15 消除字符串,求最短字符串长度
      • ¥20 有人做基于集员滤波的异常值处理相关的内容吗?(语言-matlab)
      • ¥30 matlab编程,用chatGPT帮助,但给出的code总是报错。
      • ¥15 离线安装VS2017出现报错
      • ¥50 opengl2怎么将梯形的纹理映射在矩形上面不变形
      • ¥15 起终点不同的tsp旅行商问题
      • ¥15 博途V16变频器CU320-2pn版本为2.34的gsd文件
      • ¥15 Nginx服务器配置django的channels实现即时聊天
      • ¥50 esp32作为主站基于modbus读取从站mcu的数据。
      • ¥15 【提问】VBA实现跨表格查找满足多条件的数据并提取