weixin_53329734 2021-09-12 16:12 采纳率: 78.2%
浏览 170
已结题

'FocalLoss' object has no attribute 'backward'

代码段为:

img


错误提示为:

img


关于focal loss的内容为
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable

class FocalLoss(nn.Module):

def __init__(self, class_num=200, alpha=None, gamma=2, size_average=True):
    super(FocalLoss, self).__init__()
    if alpha is None:
        self.alpha = Variable(torch.ones(class_num, 1))
    else:
        if isinstance(alpha, Variable):
            self.alpha = alpha
        else:
            self.alpha = Variable(alpha)
    self.gamma = gamma
    self.class_num = class_num
    self.size_average = size_average

def forward(self, inputs, targets):
    N = inputs.size(0)
    C = inputs.size(1)
    P = F.softmax(inputs)

    class_mask = inputs.data.new(N, C).fill_(0)
    class_mask = Variable(class_mask)
    ids = targets.view(-1, 1)
    class_mask.scatter_(1, ids.data, 1.)
    #print(class_mask)


    if inputs.is_cuda and not self.alpha.is_cuda:
        self.alpha = self.alpha.cuda()
    alpha = self.alpha[ids.data.view(-1)]

    probs = (P*class_mask).sum(1).view(-1,1)

    log_p = probs.log()
    #print('probs size= {}'.format(probs.size()))
    #print(probs)

    batch_loss = -alpha*(torch.pow((1-probs), self.gamma))*log_p
    #print('-----bacth_loss------')
    #print(batch_loss)


    if self.size_average:
        loss = batch_loss.mean()
    else:
        loss = batch_loss.sum()
    return loss

如何让focal loss支持back word呢

  • 写回答

0条回答 默认 最新

    报告相同问题?

    问题事件

    • 系统已结题 9月20日
    • 创建了问题 9月12日

    悬赏问题

    • ¥15 vhdl+MODELSIM
    • ¥20 simulink中怎么使用solve函数?
    • ¥30 dspbuilder中使用signalcompiler时报错Error during compilation: Fitter failed,求解决办法
    • ¥15 gwas 分析-数据质控之过滤稀有突变中出现的问题
    • ¥15 没有注册类 (异常来自 HRESULT: 0x80040154 (REGDB_E_CLASSNOTREG))
    • ¥15 知识蒸馏实战博客问题
    • ¥15 用PLC设计纸袋糊底机送料系统
    • ¥15 simulink仿真中dtc控制永磁同步电机如何控制开关频率
    • ¥15 用C语言输入方程怎么
    • ¥15 网站显示不安全连接问题