原博文链接地址:https://blog.csdn.net/Sebastien23/article/details/80574918
其中有不少代码完全看不太懂,想来这里求教下各位大神~~
class Sequence(nn.Module):
def __init__(self):
super(Sequence,self).__init__()
self.lstm1 = nn.LSTMCell(1,51)
self.lstm2 = nn.LSTMCell(51,51)
self.linear = nn.Linear(51,1)
#上面三行代码是设置网络结构吧?为什么用的是LSTMCell,而不是LSTM??
def forward(self,inputs,future= 0):
#这里的前向传播名称必须是forward,而不能随意更改??因为后面的模型调用过程中,并没有看到该方法的实现
outputs = []
h_t = torch.zeros(inputs.size(0),51)
c_t = torch.zeros(inputs.size(0),51)
h_t2 = torch.zeros(inputs.size(0),51)
c_t2 = torch.zeros(inputs.size(0),51)
#下面的代码中,LSTM的原理是要求三个输入:前一层的细胞状态、隐藏层状态和当前层的数据输入。这里却只有2个输入??
for i,input_t in enumerate(inputs.chunk(inputs.size(1),dim =1)):
h_t,c_t = self.lstm1(input_t,(h_t,c_t))
h_t2,c_t2 = self.lstm2(h_t,(h_t2,c_t2))
output = self.linear(h_t2)
outputs +=[output]
for i in range(future):
h_t,c_t = self.lstm1(output,(h_t,c_t))
h_t2,c_t2 = self.lstm2(h_t,(h_t2,c_t2))
output = self.linear(h_t2)
outputs +=[output]
#下面将所有的输出在第一维上相拼接,并剪除维度为2的数据??目的是什么?
outputs = torch.stack(outputs,1).squeeze(2)
return outputs