_ ∏_∏ _ 2023-07-13 22:49 采纳率: 50%
浏览 54
已结题

gradients返回值一直是

gradients返回值一直是[None,]

  def net_init(self,state,modelIndex):
        c, h, w = self.input_dim
        self.predict = Sequential([
          tf.keras.layers.Conv2D(32,8,4,activation='relu',input_shape=(c, h, w),data_format="channels_first"),
          tf.keras.layers.Conv2D(64, 4, 2, activation='relu',padding="VALID"),
          tf.keras.layers.Conv2D(64, 3, 1, activation='relu'),
          tf.keras.layers.Flatten(),
          tf.keras.layers.Dense(512, input_shape=(3136,), activation='relu'),
          tf.keras.layers.Dense(self.output_dim, input_shape=(512,), activation=None)])

        self.target = Sequential([
            tf.keras.layers.Conv2D(32, 8, 4, activation='relu', input_shape=(4, 84, 84), data_format="channels_first"),
            tf.keras.layers.Conv2D(64, 4, 2, activation='relu', padding="VALID"),
            tf.keras.layers.Conv2D(64, 3, 1, activation='relu'),
            tf.keras.layers.Flatten(),
            tf.keras.layers.Dense(512, input_shape=(3136,), activation='relu'),
            tf.keras.layers.Dense(self.output_dim, input_shape=(512,), activation=None)])

        if modelIndex==1:
            # self.predict.summary()
            return self.predict(state)
        if modelIndex==2:
            self.target.summary()
            return self.target(state)


@tf.function
    def train_step(self, states, actions):
        with tf.GradientTape() as tape:
            tape.watch(states)
            loss = tf.keras.losses.huber(states, actions)

        gradients = tape.gradient(loss, self.predict.trainable_variables)
        # gradients = [tf.clip_by_norm(gradient, 10) for gradient in gradients]

        self.optimizer.apply_gradients(zip(gradients, self.predict.trainable_variables))
        return loss

求解答

完整代码:https://github.com/huxxshadow/test

  • 写回答

11条回答 默认 最新

  • CSDN专家-sinJack 2023-07-20 10:01
    关注
    获得3.20元问题酬金

    Debug调试看下每一步的值变化情况,方便排查问题

    评论

报告相同问题?

问题事件

  • 系统已结题 7月21日
  • 赞助了问题酬金20元 7月13日
  • 创建了问题 7月13日

悬赏问题

  • ¥15 HLs设计手写数字识别程序编译通不过
  • ¥15 Stata外部命令安装问题求帮助!
  • ¥15 从键盘随机输入A-H中的一串字符串,用七段数码管方法进行绘制。提交代码及运行截图。
  • ¥15 TYPCE母转母,插入认方向
  • ¥15 如何用python向钉钉机器人发送可以放大的图片?
  • ¥15 matlab(相关搜索:紧聚焦)
  • ¥15 基于51单片机的厨房煤气泄露检测报警系统设计
  • ¥15 路易威登官网 里边的参数逆向
  • ¥15 Arduino无法同时连接多个hx711模块,如何解决?
  • ¥50 需求一个up主付费课程