想问一下这段代码中的print的*号是什么意思哇,print那一行代码可以解释一下嘛?没看懂【0】的是啥意思?求解答~
def my_init(m):
if type(m) == nn.Linear:
print("Init", *[(name, param.shape)
for name, param in m.named_parameters()][0])
nn.init.uniform_(m.weight, -10, 10)
m.weight.data *= m.weight.data.abs() >= 5
net = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 3))
X = torch.rand(size=(2, 4))
net(X)
net.apply(my_init)
net[0].weight[:2]