输入是one_hot编码, 1类OK + N类 NG 。 OK固定是one_hot编码 第一位
OK和NG之间分对更重要,因此增加 自定义损失,尝试提升性能
如果设置, weight[0] =0.0 , weight[1] = 1.0 。 loss有值,但是,应用到调整网络梯度时, 网络所有层的梯度都是0 。 即,如代码示意损失,已经丢失帝都回传信息。
问题:怎样修改,才能实现带梯度回传信息的损失计算?
loss_value = self.augment_loss.loss(y,logits)#此处loss_value有值
当进一步,计算 可train参数的梯度时
grads = tape.gradient(loss_value, self.model.trainable_weights) # grads 全是0