我的网络输入是和标签都是单通道的灰度图,但是输入图片中有一些点,比如像素值=0或者255的无效点,(这些无效点在gt中有的有对应的值,有的也为0)我想在训练过程中把这些点屏蔽掉以免影响网络的精度,该如何操作?
是否有办法将一个张量中的指定数据比如为0或者255的数据取消其梯度,禁止其反向传播?
我的网络输入是和标签都是单通道的灰度图,但是输入图片中有一些点,比如像素值=0或者255的无效点,(这些无效点在gt中有的有对应的值,有的也为0)我想在训练过程中把这些点屏蔽掉以免影响网络的精度,该如何操作?
是否有办法将一个张量中的指定数据比如为0或者255的数据取消其梯度,禁止其反向传播?
1. pytorch应该有某些跟踪机制可以完成你的要求,这我不是很了解。
2. 常规做法,定义一个mask用来记录Index,让其输出对应位置为0,也就是反向梯度为0.