love tor 2022-08-10 19:30 采纳率: 100%
浏览 49
已结题

MaskedLinear(nn.Linear)含义以及使用

MaskedLinear(nn.Linear)含义以及使用
#代码
class MaskedLinear(nn.Linear):
def init(self, in_features, out_features, relation_file, func, bias=True):
super(MaskedLinear, self).init(in_features, out_features, bias)

    mask = self.readRelationFromFile(relation_file, func)
    self.register_buffer('mask', mask)
    self.iter = 0

def forward(self, input):
    masked_weight = self.weight * self.mask
    return F.linear(input, masked_weight, self.bias)

def readRelationFromFile(self, relation_file, func):
    mask = []
    with open(relation_file, 'r') as f:
        for line in f:
            l = [int(x) for x in line.strip().split(',')[OUT_nodes[func]:]]
            assert len(l) == OUT_nodes[func]*2
            for item in l:
                assert item == 1 or item == 0  # relation 只能为0或者1
            mask.append(l)
    return Variable(torch.Tensor(mask))
  • 写回答

1条回答 默认 最新

  • 林地宁宁 2022-08-10 19:55
    关注

    mask就是“遮掩”,简单来说就是让一些weight强制变为0,不参与训练过程。具体效用跟你的使用环境有关。

    本回答被题主选为最佳回答 , 对您是否有帮助呢?
    评论

报告相同问题?

问题事件

  • 系统已结题 8月29日
  • 已采纳回答 8月21日
  • 创建了问题 8月10日

悬赏问题

  • ¥15 树莓派5怎么用camera module 3啊
  • ¥20 java在应用程序里获取不到扬声器设备
  • ¥15 echarts动画效果的问题,请帮我添加一个动画。不要机器人回答。
  • ¥15 Attention is all you need 的代码运行
  • ¥15 一个服务器已经有一个系统了如果用usb再装一个系统,原来的系统会被覆盖掉吗
  • ¥15 使用esm_msa1_t12_100M_UR50S蛋白质语言模型进行零样本预测时,终端显示出了sequence handled的进度条,但是并不出结果就自动终止回到命令提示行了是怎么回事:
  • ¥15 前置放大电路与功率放大电路相连放大倍数出现问题
  • ¥80 部署运行web自动化项目
  • ¥15 腾讯云如何建立同一个项目中物模型之间的联系
  • ¥30 VMware 云桌面水印如何添加