love tor 2022-08-10 19:30 采纳率: 100%
浏览 49
已结题

MaskedLinear(nn.Linear)含义以及使用

MaskedLinear(nn.Linear)含义以及使用
#代码
class MaskedLinear(nn.Linear):
def init(self, in_features, out_features, relation_file, func, bias=True):
super(MaskedLinear, self).init(in_features, out_features, bias)

    mask = self.readRelationFromFile(relation_file, func)
    self.register_buffer('mask', mask)
    self.iter = 0

def forward(self, input):
    masked_weight = self.weight * self.mask
    return F.linear(input, masked_weight, self.bias)

def readRelationFromFile(self, relation_file, func):
    mask = []
    with open(relation_file, 'r') as f:
        for line in f:
            l = [int(x) for x in line.strip().split(',')[OUT_nodes[func]:]]
            assert len(l) == OUT_nodes[func]*2
            for item in l:
                assert item == 1 or item == 0  # relation 只能为0或者1
            mask.append(l)
    return Variable(torch.Tensor(mask))
  • 写回答

1条回答 默认 最新

  • 林地宁宁 2022-08-10 19:55
    关注

    mask就是“遮掩”,简单来说就是让一些weight强制变为0,不参与训练过程。具体效用跟你的使用环境有关。

    本回答被题主选为最佳回答 , 对您是否有帮助呢?
    评论

报告相同问题?

问题事件

  • 系统已结题 8月29日
  • 已采纳回答 8月21日
  • 创建了问题 8月10日

悬赏问题

  • ¥20 怎么用dlib库的算法识别小麦病虫害
  • ¥15 华为ensp模拟器中S5700交换机在配置过程中老是反复重启
  • ¥15 java写代码遇到问题,求帮助
  • ¥15 uniapp uview http 如何实现统一的请求异常信息提示?
  • ¥15 有了解d3和topogram.js库的吗?有偿请教
  • ¥100 任意维数的K均值聚类
  • ¥15 stamps做sbas-insar,时序沉降图怎么画
  • ¥15 买了个传感器,根据商家发的代码和步骤使用但是代码报错了不会改,有没有人可以看看
  • ¥15 关于#Java#的问题,如何解决?
  • ¥15 加热介质是液体,换热器壳侧导热系数和总的导热系数怎么算