def focal_loss(alpha=0.25, gamma=2.):
    """Build a binary focal-loss function for imbalanced positive/negative samples.

    Args:
        alpha: Weight applied to the positive-class term; the negative-class
            term is weighted by ``1 - alpha``.
        gamma: Focusing exponent applied to the modulating factor of each term.

    Returns:
        A function ``focal_loss_calc(y_true, y_pred)`` that returns the
        element-wise (unreduced) focal loss tensor.
    """
    def focal_loss_calc(y_true, y_pred):
        # Keep predictions strictly inside (0, 1). Without this, a saturated
        # output of exactly 0.0 or 1.0 makes K.log() return -inf, and the loss
        # turns into inf/NaN late in training once the network is confident.
        eps = K.epsilon()
        y_pred = K.clip(y_pred, eps, 1. - eps)

        # Where the label is 1 keep the prediction; elsewhere substitute 1 so
        # that (1 - p)^gamma * log(p) evaluates to exactly 0 for that element.
        positive = tf.where(tf.equal(y_true, 1), y_pred, tf.ones_like(y_pred))
        # Where the label is 0 keep the prediction; elsewhere substitute 0 so
        # that p^gamma * log(1 - p) evaluates to exactly 0 for that element.
        negative = tf.where(tf.equal(y_true, 0), y_pred, tf.zeros_like(y_pred))

        positive_term = alpha * K.pow(1. - positive, gamma) * K.log(positive)
        negative_term = (1 - alpha) * K.pow(negative, gamma) * K.log(1. - negative)
        return -(positive_term + negative_term)
    return focal_loss_calc
# Compile with dice_focal_loss as the training objective; mean_iou, dice_loss
# and a default-parameter focal loss are reported as metrics.
self.keras_model.compile(
    optimizer=optimizer,
    loss=dice_focal_loss,
    metrics=[mean_iou, dice_loss, focal_loss()],
)
上面的 focal loss 开始还是挺正常的,随着训练过程逐渐减小到 0.025 左右,然后就突然变成 inf。何解?