AngryEcho 2021-11-23 21:35
浏览 45
已结题

有哪个损失函数可以单纯的计算两个张量(图像)之间的误差,并且不要求Height和Width维度相同?或者有哪个方法可以重塑(扩展)需要计算梯度的张量?

在做 时间序列预测 / 未来帧预测的时候,我遇到了一个问题,那就是从NN输出的我预测的图像(比如第0.1秒的帧)和原本第0.1的真实图像做loss计算时,我不知道该如何选择损失函数。我尝试了2种损失函数:MSELoss和CrossEntropy,但都会出现问题,下面是具体情况:
我的训练代码

```python
    for epoch in range(50):
        print('epoch {}'.format(epoch + 1))
        train_loss = 0.
        train_acc = 0.
        for batch_Group, batch_target in train_loader:
            for Group, target in zip(batch_Group, batch_target):
                inputs_20, target_1 = Variable(Group).cuda(), Variable(target).cuda()
                prediction = model(inputs_20)
                print(prediction.shape)
                print(target_1.shape)

                for i in range(BATCH_SIZE):
                    print(prediction[i].shape)
                    print(target_1.shape)
                    loss = loss_func(prediction[i], target_1)
                    optimizer.zero_grad() 
                    loss.backward()  # 计算梯度/反向传播
                    optimizer.step()  # 更新网络参数

我的预测帧是【 3, 684,76】,原图像的帧是【3,686,76】
在我尝试第一种nn.CrossEntropy时,遇到了一个报错

img

我理解报错原因,预测帧作为input的【3, 684,76】分别代表batch_size、number of classes 和 图像维度

如果要计算,target只允许有2个维度(我的pytorch是1.9.1),分别是batch_size 和 图像维度

img

但我并不是做图像分类,我只是做图像预测,没有类别可言,更没有类别的可能性之言,所以我无法使用Crossentropy

在我尝试第二种nn.MSELoss时,遇到的报错是:

img

同样,我也理解报错原因,因为维度不相同
并且我还知道pytorch的 broadcasting机制,如果我的target是【3】或者【3,1,1】,我也可以解决这个bug,但我觉得这会影响我的误差计算
所以我现在需要一个损失函数,可以单纯的计算两张图像之间的误差(可以是像素级的差异),并且不要求Height和Width完全一致,也就是【channels,H,W】中的H和W
或者,你也可以告诉我如何将NN输出的【3,672,64】通过某个tensor的方法扩展为【3,686,76】,因为我使用其他扩展tensor的方法不可以。比如reshape会告诉我元素数不够,expand会告诉我只能扩展起始维度为1。而tensor.resize_又告诉我计算梯度的张量不可以重塑!
  • 写回答

0条回答 默认 最新

    报告相同问题?

    问题事件

    • 系统已结题 12月1日
    • 创建了问题 11月23日

    悬赏问题

    • ¥20 matlab计算中误差
    • ¥15 对于相关问题的求解与代码
    • ¥15 ubuntu子系统密码忘记
    • ¥15 信号傅里叶变换在matlab上遇到的小问题请求帮助
    • ¥15 保护模式-系统加载-段寄存器
    • ¥15 电脑桌面设定一个区域禁止鼠标操作
    • ¥15 求NPF226060磁芯的详细资料
    • ¥15 使用R语言marginaleffects包进行边际效应图绘制
    • ¥20 usb设备兼容性问题
    • ¥15 错误(10048): “调用exui内部功能”库命令的参数“参数4”不能接受空数据。怎么解决啊