Sawakita1122 2020-03-27 16:24 采纳率: 50%
浏览 2994
已采纳

关于keras 对模型进行训练 train_on_batch参数和模型输出的关系

在用keras+gym测试policy gradient进行小车杆平衡时模型搭建如下:

        inputs = Input(shape=(4,),name='ob_inputs')
        x = Dense(16,activation='relu')(inputs)
        x = Dense(16,activation='relu')(x)
        x = Dense(1,activation='sigmoid')(x)
        model = Model(inputs=inputs,outputs = x)

这里输出层是一个神经元,输出一个[0,1]之间的数,表示小车动作的概率
但是在代码训练过程中,模型的训练代码为:

                X = np.array(states)
                y = np.array(list(zip(actions,discount_rewards)))
                loss = self.model.train_on_batch(X,y)

这里的target data(y)是一个2维的列表数组,第一列是对应执行的动作,第二列是折扣奖励,那么在训练的时候,神经网络的输出数据和target data的维度不一致,是如何计算loss的呢?会自动去拟合y的第一列数据吗?

  • 写回答

1条回答 默认 最新

  • 关注
    本回答被题主选为最佳回答 , 对您是否有帮助呢?
    评论

报告相同问题?

悬赏问题

  • ¥30 Matlab打开默认名称带有/的光谱数据
  • ¥50 easyExcel模板 动态单元格合并列
  • ¥15 res.rows如何取值使用
  • ¥15 在odoo17开发环境中,怎么实现库存管理系统,或独立模块设计与AGV小车对接?开发方面应如何设计和开发?请详细解释MES或WMS在与AGV小车对接时需完成的设计和开发
  • ¥15 CSP算法实现EEG特征提取,哪一步错了?
  • ¥15 游戏盾如何溯源服务器真实ip?需要30个字。后面的字是凑数的
  • ¥15 vue3前端取消收藏的不会引用collectId
  • ¥15 delphi7 HMAC_SHA256方式加密
  • ¥15 关于#qt#的问题:我想实现qcustomplot完成坐标轴
  • ¥15 下列c语言代码为何输出了多余的空格