ZQCJ1 2023-03-04 11:08 采纳率: 83.3%
浏览 25
已结题

tensorflow.keras训练模型预测问题

请问在使用tensorflow.keras训练模型预测时val_loss变化非常奇怪
loss是正常下降,但val_loss一开始就很低,而且一直震荡
如下图

img

最后预测结果也很差,如下图

img

我训练的模型如下
Xtrain.shape, Xtest.shape, Ytrain.shape, Ytest.shape为(2036, 60, 6) (400, 60, 6) (2036, 60, 6) (400, 60, 6)

model = keras.models.Sequential()
model.add(keras.layers.LSTM(40, input_shape=(Xtrain.shape[1:]), return_sequences=True, ))
model.add(keras.layers.Dropout(0.1))
model.add(keras.layers.LSTM(30, return_sequences=True))  # model.add(keras.layers.Dropout(0.5))
model.add(keras.layers.Dropout(0.1))
model.add(keras.layers.LSTM(40, return_sequences=True))
model.add(keras.layers.Dropout(0.1))
model.add(keras.layers.LSTM(40, return_sequences=True))
model.add(keras.layers.BatchNormalization())  # 批标准化:对一小批数据(batch)做标准化处理(使数据符合均值为0,标准差为1分布)
model.add(keras.layers.TimeDistributed(keras.layers.Dense(Ytrain.shape[2])))
model.compile(optimizer=keras.optimizers.Adam(lr=0.0001, amsgrad=True), loss='mse')  # mae: mean_absolute_error
model.summary()
history = model.fit(
        Xtrain, Ytrain,
        validation_data=(Xtest, Ytest),
        batch_size=32,
        epochs=30,
        verbose=1)

预测部分代码如下

predict = model.predict(Xtest)
predict = scalar.inverse_transform(predict[0])
Ytesting = scalar.inverse_transform(Ytest[0])
for i in range(6):
    plt.subplot(2, 3, i + 1)
    plt.plot(predict[:, i], color='blue')
    plt.plot(Ytesting[:, i], color='red')
plt.show()

请问是模型结构问题还是模型参数问题啊,亦或者训练模型的数据集有问题啊

  • 写回答

2条回答 默认 最新

  • 小虎AI实验室 人工智能领域优质创作者 2023-03-04 12:20
    关注

    这种情况可能是由于模型在训练过程中出现了过拟合的现象。过拟合通常是指模型在训练集上表现很好,但在测试集上表现不佳的情况。在训练过程中,模型过度适应了训练集的噪声和特定的样本,从而导致了 val_loss 震荡。

    要解决这个问题,可以尝试以下几种方法:

    增加训练数据量:通过增加训练数据量,可以减少模型的过拟合现象。

    添加正则化项:在模型中添加正则化项,如 L1 或 L2 正则化,可以限制模型的复杂度,减少过拟合现象。

    使用早期停止技术:在训练过程中,可以通过监控验证集上的 loss 值来确定模型是否开始出现过拟合。一旦发现过拟合现象,就可以通过早期停止来防止模型继续训练。

    减小模型复杂度:通过减小模型的复杂度,如减少层数、神经元数等,可以降低过拟合现象的发生概率。

    希望以上方法可以帮助您解决问题。

    本回答被题主选为最佳回答 , 对您是否有帮助呢?
    评论
查看更多回答(1条)

报告相同问题?

问题事件

  • 系统已结题 3月13日
  • 已采纳回答 3月5日
  • 创建了问题 3月4日

悬赏问题

  • ¥20 机器学习能否像多层线性模型一样处理嵌套数据
  • ¥20 西门子S7-Graph,S7-300,梯形图
  • ¥50 用易语言http 访问不了网页
  • ¥50 safari浏览器fetch提交数据后数据丢失问题
  • ¥15 matlab不知道怎么改,求解答!!
  • ¥15 永磁直线电机的电流环pi调不出来
  • ¥15 用stata实现聚类的代码
  • ¥15 请问paddlehub能支持移动端开发吗?在Android studio上该如何部署?
  • ¥20 docker里部署springboot项目,访问不到扬声器
  • ¥15 netty整合springboot之后自动重连失效