_biubiubiu_ 2021-11-05 17:02 采纳率: 0%
浏览 16
已结题

tensorflow2.x 梯度带求导断流问题

大致思路如下:
模型A 预测输入 x 的标签 y_pred
模型B 根据输入的 y_pred 和真实标签y输出一个数值 z 作为模型A 的损失
根据z计算模型A的梯度,并更新模型A
更新的A重新预测 x 的标签记为 y_pred_new
此时计算 y_pred 和 y_pred_new 的交叉熵损失loss,更新模型B,但是在loss对模型B的求梯度时,梯度全为none

代码如下

ModelA(x)  #  param: θ , 预测输入 x 的标签 y_pred
# 网络# 
return  y_pred
 
ModelB(y, y_pred)  # param: beta,根据输入的 y_pred 和真实标签y输出一个数值 z 作为模型A 的损失
# 网络 # 
return z  
 
with tf.GradientTape() as tape_1:
    with tf.GrandientTape() as tape:
        y_pred = ModelA(x)                                   
        z = ModelB(y, y_pred)                                
    grads = tape.gradient(z, ModelA.trainable_variables)
    optimizer.apply_gradients(zip(grads, ModelA.trainable_variables)   # 根据z计算模型A的梯度,并更新模型A 
    y_pred_new = ModelA(x)                                  # 更新的A重新预测 x 的标签记为y_pred_new
    loss = categorical_crossentropy(y, y_pred_new)        # 计算 y_pred 和 y_pred_new 的交叉熵损失loss,用以更新模型B
grads_1 = tape_1.gradient(loss, ModelB.trainable_variables)   #  !!!!此处出现问题,梯度全为none
optimizer.apply_gradients(zip(grads_1, ModelB.trainable_variables)

公式过程如下图
推测问题在于③位置上的对θ更新时求导,因为grads=tape.gradient()求出来是tensor,相当于βx变成了tensor,不是variable了,导致⑤位置求导的时候无法对β求导,我该如何解决这个问题?

img

  • 写回答

0条回答 默认 最新

    报告相同问题?

    问题事件

    • 系统已结题 11月13日
    • 创建了问题 11月5日

    悬赏问题

    • ¥20 指导如何跑通以下两个Github代码
    • ¥15 大家知道这个后备文件怎么删吗,为啥这些文件我只看到一份,没有后备呀
    • ¥15 C++为什么这个代码没报错运行不出来啊
    • ¥15 一道ban了很多东西的pyjail题
    • ¥15 关于#r语言#的问题:如何将生成的四幅图排在一起,且对变量的赋值进行更改,让组合的图漂亮、美观@(相关搜索:森林图)
    • ¥15 C++识别堆叠物体异常
    • ¥15 微软硬件驱动认证账号申请
    • ¥15 GPT写作提示指令词
    • ¥20 根据动态演化博弈支付矩阵完成复制动态方程求解和演化相图分析等
    • ¥15 华为超融合部署环境下RedHat虚拟机分区扩容问题