panbaoran913 2022-03-18 21:47 采纳率: 71.4%
浏览 38
已结题

【torch】在函数内部运行的时候与取函数的内部运算测试的时候,为啥结果不一样

问题遇到的现象和发生背景

GRU+NN构建一个新的函数框架。因为torch.GRU的输入维数为3维,【序列长度,batch_size, 特征个数】。后接NN的时候,我需要的是【batch_size,序列长度】作为NN的输入结构。

问题相关代码,请勿粘贴截图

这是定义的函数

class ActorNet(nn.Module):
    def __init__(self, s_dim, a_dim):
        super(ActorNet, self).__init__()
        self.gru = nn.GRU(input_size=1, hidden_size=1,
                    num_layers =3, bias=True,
                    batch_first=False,
                    dropout=0, bidirectional=False)
        # 输入一个序列,长为4,batch为7,特征是1
        self.fc1 = nn.Linear(s_dim, 30)
        self.fc1.weight.data.normal_(0, 0.1)  # initialization of FC1
        self.out = nn.Linear(30, a_dim)
        self.out.weight.data.normal_(0, 0.1)  # initialization of OUT

    def forward(self, x):   
        x = torch.tensor(s).to(torch.float32)
        x = torch.unsqueeze(torch.transpose(x, 1,0),2)
        # print("x.shape: (%s,%s) not matching GRU" % x.shape)
        x1,h0 = self.gru(x) # x.shape= (4,7,1)
        x1 = torch.transpose(torch.squeeze(x1),0,1)
        x2 = self.fc1(x1)
        x2 = F.relu(x2)
        x3 = self.out(x2)
        Action = torch.tanh(x3)
        return Action

接下来我在notebook中检查这个函数的正确性。

from HRA import ActorNet # 导入该函数
S_dim = 4
A_dim = 1
s=np.random.normal(0,3,size=(7,4))   #这个表示数据输入的形状
AC = ActorNet(S_dim,A_dim) 
x = torch.tensor(s).to(torch.float32)
x = torch.unsqueeze(torch.transpose(x, 1,0),2)
print(x.shape)
x1,h0 = AC.gru(x) # x1.shape= (4,7,1)
x1 = torch.transpose(torch.squeeze(x1),0,1)
x2 = AC.fc1(x1)
x2 = F.relu(x2)
x3 = AC.out(x2)
Action = torch.tanh(x3) 

我根据AC的forward计算步骤,调用函数计算没有问题。但是,我单独调用的时候就有问题。

AC = ActorNet(S_dim,A_dim)
AC(s)
运行结果及报错内容

报错内容:

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-31-29f2b7d0d927> in <module>
      2 #x = np.expand_dims(np.transpose(s,(1,0)),axis= 2)
      3 #x = torch.Tensor(x)
----> 4 AC(s)

~\Anaconda3\lib\site-packages\torch\nn\modules\module.py in _call_impl(self, *input, **kwargs)
   1100         if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1101                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1102             return forward_call(*input, **kwargs)
   1103         # Do not call functions when jit is used
   1104         full_backward_hooks, non_full_backward_hooks = [], []

~\Desktop\潘\GDDPG\re_model\HRA.py in forward(self, x)
     26         x = torch.unsqueeze(torch.transpose(x, 1,0),2)
     27         # print("x.shape: (%s,%s) not matching GRU" % x.shape)
---> 28         x1,h0 = self.gru(x) # x.shape= (4,7,1)
     29         x1 = torch.transpose(torch.squeeze(x1),0,1)
     30         x2 = self.fc1(x1)

~\Anaconda3\lib\site-packages\torch\nn\modules\module.py in _call_impl(self, *input, **kwargs)
   1100         if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1101                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1102             return forward_call(*input, **kwargs)
   1103         # Do not call functions when jit is used
   1104         full_backward_hooks, non_full_backward_hooks = [], []

~\Anaconda3\lib\site-packages\torch\nn\modules\rnn.py in forward(self, input, hx)
    848         if batch_sizes is None:
    849             result = _VF.gru(input, hx, self._flat_weights, self.bias, self.num_layers,
--> 850                              self.dropout, self.training, self.bidirectional, self.batch_first)
    851         else:
    852             result = _VF.gru(input, batch_sizes, hx, self._flat_weights, self.bias,

RuntimeError: expected scalar type Double but found Float
我的解答思路和尝试过的方法
我想要达到的结果
  • 写回答

0条回答 默认 最新

    报告相同问题?

    问题事件

    • 系统已结题 3月26日
    • 创建了问题 3月18日

    悬赏问题

    • ¥15 is not in the mmseg::model registry。报错,模型注册表找不到自定义模块。
    • ¥15 安装quartus II18.1时弹出此error,怎么解决?
    • ¥15 keil官网下载psn序列号在哪
    • ¥15 想用adb命令做一个通话软件,播放录音
    • ¥30 Pytorch深度学习服务器跑不通问题解决?
    • ¥15 部分客户订单定位有误的问题
    • ¥15 如何在maya程序中利用python编写领子和褶裥的模型的方法
    • ¥15 Bug traq 数据包 大概什么价
    • ¥15 在anaconda上pytorch和paddle paddle下载报错
    • ¥25 自动填写QQ腾讯文档收集表