class GRUNet(nn.Module):
    """Stacked GRU followed by a linear head applied to the last time step.

    Args:
        input_size: Number of features per time step (last dim of ``x``).
        hidden_size: Number of features in each GRU hidden state.
        n_layers: Number of stacked GRU layers.
        output_size: Number of output features produced by the linear head.

    NOTE(review): some model-summary tools (e.g. ``torchsummary``) only count
    a sub-module's ``weight``/``bias`` attributes. ``nn.GRU`` registers its
    tensors as ``weight_ih_l{k}``, ``weight_hh_l{k}``, ``bias_ih_l{k}`` and
    ``bias_hh_l{k}`` instead, so such tools may report 0 parameters for the
    GRU layer even though it has them. Count them directly with
    ``sum(p.numel() for p in model.parameters())``.
    """

    def __init__(self, input_size, hidden_size, n_layers, output_size):
        super().__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.n_layers = n_layers
        self.output_size = output_size
        # batch_first=True: input and output tensors are (batch, seq, feature).
        self.gru = nn.GRU(input_size, hidden_size, n_layers, batch_first=True)
        # Kept as a Sequential (not a bare Linear) so the parameter names stay
        # "fc1.0.weight"/"fc1.0.bias"; presumably relied on by saved checkpoints.
        self.fc1 = nn.Sequential(
            nn.Linear(hidden_size, output_size),
        )

    def forward(self, x):
        """Map ``x`` of shape (batch, seq, input_size) to (batch, output_size).

        Only the GRU output at the final time step is fed to the linear head.
        """
        # None as the initial hidden state means the GRU starts from all zeros
        # (nn.GRU takes the input sequence and an optional initial hidden state).
        r_out, _h_n = self.gru(x, None)
        # r_out is (batch, seq, hidden_size); keep only the last time step.
        return self.fc1(r_out[:, -1, :])
# 8 input features, hidden size 20, 3 stacked GRU layers, 1 output feature.
grumodel = GRUNet(8, 20, 3, 1).to(device)
# (96, 8) is presumably (seq_len, input_size) per sample; summary() is expected
# to prepend the batch dimension itself -- confirm against the summary tool used.
summary(grumodel, (96, 8))
# NOTE(review): summary() may list the GRU layer with 0 params, because tools
# such as torchsummary only inspect a module's `weight`/`bias` attributes, while
# nn.GRU names its tensors weight_ih_l{k}/weight_hh_l{k}/bias_ih_l{k}/bias_hh_l{k}.
# Cross-check with the true total taken from the registered parameters.
print("Total params:", sum(p.numel() for p in grumodel.parameters()))
为什么 summary() 统计出的 GRU 层参数数量为 0？