NLP菜鸡 2021-07-01 14:42 采纳率: 0%
浏览 38

有人把GHM loss用在NLP领域里过吗

class GHM_Loss(nn.Module):
    def __init__(self, bins, alpha):
        super(GHM_Loss, self).__init__()
        self._bins = bins
        self._alpha = alpha
        self._last_bin_count = None

    def _g2bin(self, g):
        return torch.floor(g * (self._bins - 0.0001)).long()

    def _custom_loss(self, x, target, weight):
        raise NotImplementedError

    def _custom_loss_grad(self, x, target):
        raise NotImplementedError

    def forward(self, x, target):
        g = torch.abs(self._custom_loss_grad(x, target))
        bin_idx = self._g2bin(g)
        bin_count = torch.zeros((self._bins))
        for i in range(self._bins):
            bin_count[i] = (bin_idx == i).sum().item()

        N = x.size(0)

        nonempty_bins = (bin_count > 0).sum().item()
        gd = bin_count * nonempty_bins
        gd = torch.clamp(gd, min=0.0001)
        beta = N / gd
        return self._custom_loss(x, target, beta[bin_idx[:self._bins]])


class GHMC_Loss(GHM_Loss):
    def __init__(self, bins, alpha):
        super(GHMC_Loss, self).__init__(bins, alpha)

    def _custom_loss(self, x, target, weight):
        return torch.sum(
            (torch.nn.NLLLoss(reduce=False)(torch.log(x), target)).mul(weight.to('cpu').detach())) / torch.sum(
            weight.to('cpu').detach())

    def _custom_loss_grad(self, x, target):
        x = x.cpu().detach()
        target = target.cpu()
        return torch.tensor([x[i, target[i]] for i in range(target.shape[0])]) - target

这段GHM loss的代码我想把他用在NLP来解决样本不平衡问题。 之前用focal loss的代码能直接套进去,但这个套进去发现各自bug运行不起来,有大佬知道该怎么改吗

  • 写回答

1条回答 默认 最新

  • heart_6662 2022-12-25 23:34
    关注

    GHM loss (Gradient Harmonized Single-stage Detector loss)是一种常用的目标检测损失函数,它能够很好地平衡类别不平衡的情况下的训练效果。它是在单阶段目标检测模型中使用的。

    GHM loss 是一种图像分割和目标检测领域的损失函数,我觉得并不适用于 NLP (Natural Language Processing, 自然语言处理) 领域。 NLP 领域常用的损失函数有交叉熵损失、平均绝对误差损失、平均平方误差损失等

    评论

报告相同问题?

悬赏问题

  • ¥15 C#算法问题, 不知道怎么处理这个数据的转换
  • ¥15 YoloV5 第三方库的版本对照问题
  • ¥15 请完成下列相关问题!
  • ¥15 drone 推送镜像时候 purge: true 推送完毕后没有删除对应的镜像,手动拷贝到服务器执行结果正确在样才能让指令自动执行成功删除对应镜像,如何解决?
  • ¥15 求daily translation(DT)偏差订正方法的代码
  • ¥15 js调用html页面需要隐藏某个按钮
  • ¥15 ads仿真结果在圆图上是怎么读数的
  • ¥20 Cotex M3的调试和程序执行方式是什么样的?
  • ¥20 java项目连接sqlserver时报ssl相关错误
  • ¥15 一道python难题3