class LSTMCell(nn.Module):
def __init__(self, input_size, hidden_size, cell_size, output_size):
super().__init__()
self.hidden_size = hidden_size # 隐含状态h的大小,也即LSTM单元隐含层神经元数量
self.cell_size = cell_size # 记忆单元c的大小
# 门
self.gate = nn.Linear(input_size+hidden_size, cell_size)
self.output = nn.Linear(hidden_size, output_size)
self.sigmoid = nn.Sigmoid()
self.tanh = nn.Tanh()
self.softmax = nn.LogSoftmax(dim=1)
def forward(self, input, hidden, cell):
# 连接输入x与h
combined = torch.cat((input, hidden), 1)
# 遗忘门
f_gate = self.sigmoid(self.gate(combined))
# 输入门
i_gate = self.sigmoid(self.gate(combined))
z_state = self.tanh(self.gate(combined))
# 输出门
o_gate = self.sigmoid(self.gate(combined))
# 更新记忆单元
cell = torch.add(torch.mul(cell, f_gate), torch.mul(z_state, i_gate))
# 更新隐藏状态h
hidden = torch.mul(self.tanh(cell), o_gate)
output = self.output(hidden)
output = self.softmax(output)
return output, hidden, cell
def initHidden(self):
return torch.zeros(1, self.hidden_size)
def initCell(self):
return torch.zeros(1, self.cell_size)
上述代码实现了LSTM单元,其中init方法中使用了self.gate,是不是存在三个门权值共享的问题?