问题遇到的现象和发生背景
GRU+NN构建一个新的函数框架。因为torch.GRU的输入维数为3维,【序列长度,batch_size, 特征个数】。后接NN的时候,我需要的是【batch_size,序列长度】作为NN的输入结构。
问题相关代码,请勿粘贴截图
这是定义的函数
class ActorNet(nn.Module):
def __init__(self, s_dim, a_dim):
super(ActorNet, self).__init__()
self.gru = nn.GRU(input_size=1, hidden_size=1,
num_layers =3, bias=True,
batch_first=False,
dropout=0, bidirectional=False)
# 输入一个序列,长为4,batch为7,特征是1
self.fc1 = nn.Linear(s_dim, 30)
self.fc1.weight.data.normal_(0, 0.1) # initialization of FC1
self.out = nn.Linear(30, a_dim)
self.out.weight.data.normal_(0, 0.1) # initialization of OUT
def forward(self, x):
x = torch.tensor(s).to(torch.float32)
x = torch.unsqueeze(torch.transpose(x, 1,0),2)
# print("x.shape: (%s,%s) not matching GRU" % x.shape)
x1,h0 = self.gru(x) # x.shape= (4,7,1)
x1 = torch.transpose(torch.squeeze(x1),0,1)
x2 = self.fc1(x1)
x2 = F.relu(x2)
x3 = self.out(x2)
Action = torch.tanh(x3)
return Action
接下来我在notebook中检查这个函数的正确性。
from HRA import ActorNet # 导入该函数
S_dim = 4
A_dim = 1
s=np.random.normal(0,3,size=(7,4)) #这个表示数据输入的形状
AC = ActorNet(S_dim,A_dim)
x = torch.tensor(s).to(torch.float32)
x = torch.unsqueeze(torch.transpose(x, 1,0),2)
print(x.shape)
x1,h0 = AC.gru(x) # x1.shape= (4,7,1)
x1 = torch.transpose(torch.squeeze(x1),0,1)
x2 = AC.fc1(x1)
x2 = F.relu(x2)
x3 = AC.out(x2)
Action = torch.tanh(x3)
我根据AC的forward计算步骤,调用函数计算没有问题。但是,我单独调用的时候就有问题。
AC = ActorNet(S_dim,A_dim)
AC(s)
运行结果及报错内容
报错内容:
---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
<ipython-input-31-29f2b7d0d927> in <module>
2 #x = np.expand_dims(np.transpose(s,(1,0)),axis= 2)
3 #x = torch.Tensor(x)
----> 4 AC(s)
~\Anaconda3\lib\site-packages\torch\nn\modules\module.py in _call_impl(self, *input, **kwargs)
1100 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
1101 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1102 return forward_call(*input, **kwargs)
1103 # Do not call functions when jit is used
1104 full_backward_hooks, non_full_backward_hooks = [], []
~\Desktop\潘\GDDPG\re_model\HRA.py in forward(self, x)
26 x = torch.unsqueeze(torch.transpose(x, 1,0),2)
27 # print("x.shape: (%s,%s) not matching GRU" % x.shape)
---> 28 x1,h0 = self.gru(x) # x.shape= (4,7,1)
29 x1 = torch.transpose(torch.squeeze(x1),0,1)
30 x2 = self.fc1(x1)
~\Anaconda3\lib\site-packages\torch\nn\modules\module.py in _call_impl(self, *input, **kwargs)
1100 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
1101 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1102 return forward_call(*input, **kwargs)
1103 # Do not call functions when jit is used
1104 full_backward_hooks, non_full_backward_hooks = [], []
~\Anaconda3\lib\site-packages\torch\nn\modules\rnn.py in forward(self, input, hx)
848 if batch_sizes is None:
849 result = _VF.gru(input, hx, self._flat_weights, self.bias, self.num_layers,
--> 850 self.dropout, self.training, self.bidirectional, self.batch_first)
851 else:
852 result = _VF.gru(input, batch_sizes, hx, self._flat_weights, self.bias,
RuntimeError: expected scalar type Double but found Float