mrdongcy 2022-09-17 15:09 采纳率: 0%
浏览 47

torch.optim.Adam不能获取Parameter参数

使用torch.nn.Parameter定义参数后,使用torch.optim.Adam(centerloss.parameters())出现:
"ValueError: optimizer got an empty parameter list"
代码如下:


import torch
import torch.nn as nn

class CenterLoss(nn.Module):
    def __init__(self, num_classes=10, feat_dim=2, use_gpu=True):
        super(CenterLoss, self).__init__()
        self.num_classes = num_classes
        self.feat_dim = feat_dim
        self.use_gpu = use_gpu

        if self.use_gpu:
            self.centers = nn.Parameter(torch.randn(self.num_classes, self.feat_dim).cuda())
        else:
            self.centers = nn.Parameter(torch.randn(self.num_classes, self.feat_dim))

    def forward(self, x, labels):
        batch_size = x.size(0)
        distmat = torch.pow(x, 2).sum(dim=1, keepdim=True).expand(batch_size, self.num_classes) + \
                  torch.pow(self.centers, 2).sum(dim=1, keepdim=True).expand(self.num_classes, batch_size).t()
        distmat.addmm_(1, -2, x, self.centers.t())

        classes = torch.arange(self.num_classes).long()
        if self.use_gpu: classes = classes.cuda()
        labels = labels.unsqueeze(1).expand(batch_size, self.num_classes)
        mask = labels.eq(classes.expand(batch_size, self.num_classes))

        dist = distmat * mask.float()
        loss = dist.clamp(min=1e-12, max=1e+12).sum() / batch_size

        return loss

if __name__ == "__name__":
    centerloss = CenterLoss()
    torch.optim.Adam(centerloss.parameters())
    

请帮忙解答一下这是什么原因呀

  • 写回答

2条回答 默认 最新

  • Elwin Wong 2022-09-17 15:15
    关注

    你的CenterLoss类中并没有定义网络层啊,所以就没有要优化的参数,所以报错传递一个空的参数列表。

    评论

报告相同问题?

问题事件

  • 创建了问题 9月17日

悬赏问题

  • ¥15 HLs设计手写数字识别程序编译通不过
  • ¥15 Stata外部命令安装问题求帮助!
  • ¥15 从键盘随机输入A-H中的一串字符串,用七段数码管方法进行绘制。提交代码及运行截图。
  • ¥15 TYPCE母转母,插入认方向
  • ¥15 如何用python向钉钉机器人发送可以放大的图片?
  • ¥15 matlab(相关搜索:紧聚焦)
  • ¥15 基于51单片机的厨房煤气泄露检测报警系统设计
  • ¥15 Arduino无法同时连接多个hx711模块,如何解决?
  • ¥50 需求一个up主付费课程
  • ¥20 模型在y分布之外的数据上预测能力不好如何解决