WilL846 2023-10-31 23:43 采纳率: 75%
浏览 12
已结题

pytorch grid_sample问题

有一个尺寸为N x C x H x W的张量inputs,一个尺寸为 N x H x W x 2张量sample,将inputs通过F.grid_sample(inputs, samples, padding_mode = "zeros")的方式采样得到输出outputs,outputs中一些元素由于采样超出inputs边界被设为零,现有一个尺寸为N x C x H x W的张量x,我希望将张量x中与outputs中被设为零的元素相同位置的元素也设为零,该怎么做,也就是说我希望通过sample获得一个mask来将x中对应位置的元素变成零

  • 写回答

2条回答 默认 最新

  • 木头人123。 2023-11-01 09:20
    关注

    你可以使用如下步骤来达到你的目标:

    1. 创建一个全1张量,它的尺寸与inputs相同。
    2. 使用F.grid_sample()函数,将这个全1张量与samples一起采样得到一个mask。
    3. 将得到的mask张量与x进行逐元素乘法。

    在PyTorch中,你可以通过如下代码实现这个过程:

    import torch
    import torch.nn.functional as F
    
    N, C, H, W = inputs.size()
    
    # Step 1: 创建一个全1张量
    ones = torch.ones_like(inputs)
    
    # Step 2: 采样得到mask
    mask = F.grid_sample(ones, samples, padding_mode="zeros")
    
    # Step 3: 将mask与x进行逐元素乘法
    x_masked = x * mask
    

    在上述代码中,F.grid_sample()函数将全1张量onessamples一起采样,得到一个新的mask张量。这个mask张量的元素值与outputs中的元素一一对应,如果outputs中的某个元素为0(表示采样超出边界),那么mask张量中的对应元素也为0。然后,我们将mask张量与x进行逐元素乘法,得到的x_masked就是你想要的结果。

    本回答被题主选为最佳回答 , 对您是否有帮助呢?
    评论
查看更多回答(1条)

报告相同问题?

问题事件

  • 系统已结题 11月9日
  • 已采纳回答 11月1日
  • 创建了问题 10月31日

悬赏问题

  • ¥15 matlab数据降噪处理,提高数据的可信度,确保峰值信号的不损失?
  • ¥15 怎么看我在bios每次修改的日志
  • ¥15 python+mysql图书管理系统
  • ¥15 Questasim Error: (vcom-13)
  • ¥15 船舶旋回实验matlab
  • ¥30 SQL 数组,游标,递归覆盖原值
  • ¥15 为什么我的数据接收的那么慢呀有没有完整的 hal 库并 代码呀有的话能不能发我一份并且我用 printf 函数显示处理之后的数据,用 debug 就不能运行了呢
  • ¥20 gitlab 中文路径,无法下载
  • ¥15 用动态规划算法均分纸牌
  • ¥30 udp socket,bind 0.0.0.0 ,如何自动选取用户访问的服务器IP来回复数据