In a knowledge-distillation setup for deep learning, which of the following ways of combining the loss terms is correct?
hard_loss = nn.CrossEntropyLoss()
soft_loss = nn.KLDivLoss(reduction="batchmean")
loss = hard_loss(student_out, label)
# Note: nn.KLDivLoss expects log-probabilities as its input, so the student
# logits should go through F.log_softmax, not F.softmax:
distillation_loss = soft_loss(F.log_softmax(student_out / T, dim=1), F.softmax(teacher_output / T, dim=1))
Option 1:
loss_total = loss * alpha + distillation_loss * (1 - alpha)
Option 2:
loss_total = loss * alpha + distillation_loss * (T * T * 2) * (1 - alpha)
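For reference, the standard formulation from Hinton et al. scales the soft (KL) term by T², because dividing the logits by T shrinks the gradients of that term by roughly 1/T², and the T² factor keeps its gradient magnitude comparable to the hard cross-entropy term. A factor of 2 on top of T² is not part of that formulation. Below is a minimal runnable sketch of this convention; the function name `distillation_loss_total` and the toy tensor shapes are illustrative assumptions, not code from the question:

```python
import torch
import torch.nn.functional as F

def distillation_loss_total(student_out, teacher_out, label, T=4.0, alpha=0.3):
    # Hard loss: ordinary cross-entropy against the ground-truth labels.
    hard = F.cross_entropy(student_out, label)
    # Soft loss: KL divergence between temperature-softened distributions.
    # F.kl_div expects log-probabilities for the input and probabilities
    # for the target, hence log_softmax on the student side.
    soft = F.kl_div(
        F.log_softmax(student_out / T, dim=1),
        F.softmax(teacher_out / T, dim=1),
        reduction="batchmean",
    )
    # Scale the soft term by T*T so its gradient magnitude stays
    # comparable to the hard term as T grows (Hinton et al.'s convention).
    return alpha * hard + (1 - alpha) * (T * T) * soft

# Hypothetical toy inputs just to show the call shape.
torch.manual_seed(0)
student_out = torch.randn(8, 10)   # batch of 8, 10 classes
teacher_out = torch.randn(8, 10)
label = torch.randint(0, 10, (8,))
total = distillation_loss_total(student_out, teacher_out, label)
```

With `alpha=0.3` the hard label signal contributes 30% of the total; a sanity check is that the soft term vanishes when student and teacher logits coincide, since the KL divergence of a distribution with itself is zero.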