MaskedLinear(nn.Linear) 的含义与用法
# 代码
class MaskedLinear(nn.Linear):
    """An ``nn.Linear`` layer whose weight is multiplied element-wise by a fixed 0/1 mask.

    The mask is read once from ``relation_file`` at construction time and
    registered as a buffer. It is therefore saved in ``state_dict`` and moved
    by ``.to(device)`` together with the layer, but it is not a trainable
    parameter. Weight entries where the mask is 0 contribute nothing to the
    output and receive zero gradient through the masked product.

    For ``forward`` to work, the mask must broadcast against ``self.weight``,
    whose shape is ``(out_features, in_features)``.
    NOTE(review): the file presumably has ``out_features`` rows of
    ``in_features`` values. This is not checked here; a mismatch surfaces as
    a broadcasting error in ``forward``.

    Args:
        in_features: size of each input sample (forwarded to ``nn.Linear``).
        out_features: size of each output sample (forwarded to ``nn.Linear``).
        relation_file: path to a comma-separated text file; each non-blank
            line yields one mask row (see ``readRelationFromFile``).
        func: key into ``OUT_nodes`` (defined outside this class); selects
            which columns of each line form the mask row.
        bias: whether the layer has an additive bias (forwarded to ``nn.Linear``).
    """

    def __init__(self, in_features, out_features, relation_file, func, bias: bool = True):
        super().__init__(in_features, out_features, bias)
        mask = self.readRelationFromFile(relation_file, func)
        # Buffer rather than Parameter: it follows the module across devices
        # and into state_dict, but the optimizer never sees it.
        self.register_buffer("mask", mask)
        # NOTE(review): not read by any method in this class; kept for
        # external code that may use it as a counter.
        self.iter = 0

    def forward(self, input):
        """Apply the linear map using ``weight * mask`` in place of ``weight``."""
        # The mask is re-applied on every call instead of zeroing ``weight``
        # once. Masked entries then stay zero in the effective weight whatever
        # the optimizer does to the underlying parameter.
        return F.linear(input, self.weight * self.mask, self.bias)

    def readRelationFromFile(self, relation_file, func):
        """Parse the 0/1 relation mask selected by ``func`` from ``relation_file``.

        Each non-blank line is split on commas. The first ``OUT_nodes[func]``
        fields are skipped. The remaining fields must number exactly
        ``2 * OUT_nodes[func]``, must each be 0 or 1, and become one mask row.

        Returns:
            A 2-D tensor of the default floating dtype, one row per non-blank line.

        Raises:
            ValueError: if a field is not an integer, a row has the wrong
                number of fields, a value is not 0 or 1, or the file contains
                no rows.
        """
        offset = OUT_nodes[func]
        expected_len = offset * 2
        rows = []
        with open(relation_file, "r", encoding="utf-8") as fh:
            for lineno, raw in enumerate(fh, start=1):
                text = raw.strip()
                if not text:
                    # Skip blank lines (e.g. a trailing empty line) instead of
                    # failing on int('').
                    continue
                row = [int(field) for field in text.split(",")[offset:]]
                # Explicit checks instead of assert, so validation survives `python -O`.
                if len(row) != expected_len:
                    raise ValueError(
                        f"{relation_file}:{lineno}: expected {expected_len} "
                        f"relation values after column {offset}, got {len(row)}"
                    )
                if any(value not in (0, 1) for value in row):
                    raise ValueError(
                        f"{relation_file}:{lineno}: relation values must be 0 or 1"
                    )
                rows.append(row)
        if not rows:
            # An empty mask would only fail later, inside forward, with a
            # confusing broadcasting error; fail here instead.
            raise ValueError(f"relation file {relation_file!r} contains no mask rows")
        # Same dtype as the former torch.Tensor(...) (the default float dtype),
        # without the deprecated Variable wrapper.
        return torch.tensor(rows, dtype=torch.get_default_dtype())