阿麥Mai 2023-08-13 04:21 采纳率: 57.1%
浏览 5
已结题

基于GRU模型实现预测未来数据的问题

使用gru模型,想通过例如前几天的数据,预测后面几天的数据,为什么预测出来的值都是相同的?是不是预测代码写错了?

img

class RNNModel(nn.Module):
    """RNN模型"""

    def __init__(self, rnn_layer, output_size=1):
        super(RNNModel, self).__init__()
        self.rnn = rnn_layer  # RNN层
        self.dense = nn.Linear(rnn_layer.hidden_size, output_size)  # 全连接层
        self.state = None  # 隐藏状态

    def forward(self, x, state):
        Y, self.state = self.rnn(x, state)  # (batch_size, num_input, num_hidden)
        output = self.dense(Y.reshape(-1, Y.shape[-1]))  # (num_input*batch_size, num_output)
        return output, self.state


# 4.加载模型
model_path = "model/1.pt"
GRU_layer = nn.GRU(batch_first=True, input_size=dim_x, hidden_size=hidden_size, num_layers=num_layers)
net = model.RNNModel(rnn_layer=GRU_layer)
net.load_state_dict(torch.load(model_path))
net.to(device)


# 5.预测未来数据
net.eval()
state = None
with torch.no_grad():
    y_list, y_hat_list = [], []
    y_list.append(all_data * (max_data - min_data) + min_data)

    for i in range(1, n_train):  # 取n_train天数据进行预测
        inputs = test_x[i - 1, 0].unsqueeze(1).unsqueeze(1)
        output, state = net(inputs, state)
        pred = inputs.squeeze().cpu().numpy() * (max_data - min_data) + min_data  # 反归一化处理
        y_hat_list.append(pred)
        # print(f"第{i}天", pred.round(2))

    for i in range(n_train, predict_len + n_train + 1):  # 预测predict_len天数据
        output, state = net(inputs, state)  # 输入模型进行预测
        inputs = output.unsqueeze(1)  # 将预测结果作为下一次的输入
        pred = output.squeeze().cpu().numpy() * (max_data - min_data) + min_data  # 反归一化处理
        y_hat_list.append(pred)
        # print(f"第{i}天预测值:{pred.round(2)}")
  • 写回答

1条回答 默认 最新

  • CSDN-Ada助手 CSDN-AI 官方账号 2023-08-13 06:03
    关注

    【相关推荐】



    • 看下这篇博客,也许你就懂了,链接:本科生学深度学习-GRU最简单的讲解,伪代码阐述逻辑,实例展示效果
    • 除此之外, 这篇博客: GRU简介中的 1.GRU的输入输出结构 部分也许能够解决你的问题, 你可以仔细阅读以下内容或跳转源博客中阅读:

      GRU的输入输出结构与普通的RNN是一样的。

      输入:t时刻的输入xtx^{t}xt,和t-1时刻的隐藏层状态ht−1h^{t-1}ht1 ,这个隐藏层状态包含了之前节点的相关信息。
      输出:t时刻隐藏节点的输出yty^{t}yt和传递给下一个节点的隐状态hth^{t}ht
      在这里插入图片描述


    如果你已经解决了该问题, 非常希望你能够分享一下解决方案, 写成博客, 将相关链接放在评论区, 以帮助更多的人 ^-^
    评论

报告相同问题?

问题事件

  • 已结题 (查看结题原因) 9月27日
  • 创建了问题 8月13日

悬赏问题

  • ¥15 植物重测序snp数据Treemix分析出现问题!
  • ¥15 怎么让当前页面只能有一人在编辑
  • ¥15 UCOSⅢ,3.0.3升级为3.0.4后程序编译成功,但是运行后死在统计任务的地方
  • ¥15 python程序长时间运行卡死,付费求解决方案
  • ¥20 VM打开不了ubuntu虚拟机,如何解决?
  • ¥15 java请求一个返回流式数据的接口,如何将流式数据直接返回前端
  • ¥15 为什么连接不了啊,配置都没问题啊
  • ¥15 c语言做一个简单的计算器,大家来看看
  • ¥15 nuxtjs3+ts 报错,急呀!
  • ¥15 matlab矩阵复数本征值排序