_biubiubiu_ 2021-11-04 10:13 采纳率: 0%
浏览 33

tensorflow2+ 版本梯度带断流问题

大致思路如下:
模型A 预测输入x的标签y_pred
模型B 根据输入的y_pred和真实标签y输出一个数值loss_A作为模型A 的损失
根据loss_A计算模型A的梯度,并更新模型A
更新的A重新预测x的标签,为y_pred_new
此时计算y和y_pred_new的交叉熵损失loss,更新模型B,但是在loss对模型B的求梯度时,梯度全为none

请问该如何修改下面代码才能实现这个功能?

代码如下


ModelA(x)  #  param: theta
# 网络# 
return  y_pred


ModelB(y, y_pred)  # param: beta
# 网络 # 
return z  # 该代码中z作为modelA的loss更新A


with tf.GradientTape() as tape_1:
    with tf.GrandientTape() as tape:
        y_pred = ModelA(x)
        loss_A = ModelB(y, y_pred)
    grads = tape.gradient(loss_A, ModelA.trainable_variables)
    optimizer.apply_gradients(zip(grads, ModelA.trainable_variables)
    y_pred_new = ModelA(x)
    loss = categorical_crossentropy(y, y_pred_new)
grads_1 = tape_1.gradient(loss, ModelB.trainable_variables)  #  !!!!此处出现问题,梯度全为none
optimizer.apply_gradients(zip(grads_1, ModelB.trainable_variables)

公式过程大致如下

img

  • 写回答

1条回答 默认 最新

  • 影醉阏轩窗 2021-11-04 14:49
    关注

    不愿意花时间看你的代码,仅看问题描述比较简单,直接定义AB模型,手动更新和控制梯度即可。

    评论

报告相同问题?

问题事件

  • 创建了问题 11月4日

悬赏问题

  • ¥20 指导如何跑通以下两个Github代码
  • ¥15 大家知道这个后备文件怎么删吗,为啥这些文件我只看到一份,没有后备呀
  • ¥15 C++为什么这个代码没报错运行不出来啊
  • ¥15 一道ban了很多东西的pyjail题
  • ¥15 关于#r语言#的问题:如何将生成的四幅图排在一起,且对变量的赋值进行更改,让组合的图漂亮、美观@(相关搜索:森林图)
  • ¥15 C++识别堆叠物体异常
  • ¥15 微软硬件驱动认证账号申请
  • ¥15 GPT写作提示指令词
  • ¥20 根据动态演化博弈支付矩阵完成复制动态方程求解和演化相图分析等
  • ¥15 华为超融合部署环境下RedHat虚拟机分区扩容问题