斌-Bang 2024-08-17 15:06 采纳率: 14.3%
浏览 15

unet神经网络处理一维信号数据,上采样部分和跳跃连接部分一直报错,该怎么修改?


代码和图片如下:

import torch
from torch import nn
from torch.nn import functional as F


class ConvBlock(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(ConvBlock, self).__init__()
        self.conv = nn.Conv1d(in_channels, out_channels, kernel_size=7, padding=3, bias=False)
        self.relu = nn.ReLU()
    def forward(self, x):
        return self.relu(self.conv(x))


class DownSample(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(DownSample, self).__init__()
        self.conv = nn.Conv1d(in_channels, out_channels, kernel_size=7, stride=4, padding=3, bias=False)
    def forward(self, x):
        return self.conv(x)


class UpSample(nn.Module):
    def __init__(self, in_channels, out_channels, skip_channels):
        super(UpSample, self).__init__()
        self.up_conv = nn.ConvTranspose1d(in_channels, out_channels, kernel_size=4, stride=2, padding=1,
                                          output_padding=1, bias=False)
        # 上卷积操作后,特征图的通道数变为 out_channels
        # 拼接跳跃连接后,总通道数为 out_channels + skip_channels
        self.conv_block = ConvBlock((out_channels + skip_channels), out_channels)

    def forward(self, x, skip_connection):
        x = self.up_conv(x)
        skip_connection = F.interpolate(skip_connection, size=x.shape[2:], mode='nearest')
        x = torch.cat((x, skip_connection), dim=1)  # 在这里拼接上卷积结果和跳跃连接
        return self.conv_block(x)

class UNet(nn.Module):
    def __init__(self):
        super(UNet, self).__init__()
        self.c1 = ConvBlock(3, 8)
        self.enc1 = ConvBlock(8, 8)
        self.d1 = DownSample(8, 8)
        self.enc2 = ConvBlock(8, 11)
        self.d2 = DownSample(11, 11)
        self.enc3 = ConvBlock(11, 16)
        self.d3 = DownSample(16, 16)
        self.enc4 = ConvBlock(16, 22)
        self.d4 = DownSample(22, 22)
        self.enc5 = ConvBlock(22, 32)

        self.dec1 = UpSample(32, 22,22)
        self.enc6 = ConvBlock(44, 22)
        self.dec2 = UpSample(22, 16,16)
        self.enc7 = ConvBlock(32, 16)
        self.dec3 = UpSample(16, 11,11)
        self.enc8 = ConvBlock(22, 11)
        self.dec4 = UpSample(11, 8,8)
        self.enc9 = ConvBlock(16, 8)


        self.final_conv = nn.Conv1d(8, 1, kernel_size=1)
        self.softmax = nn.Softmax(dim=1)

    def forward(self, x):
        c1 = self.c1(x)
        enc1 = self.enc1(c1)
        d1 = self.d1(enc1)
        enc2 = self.enc2(d1)
        d2 = self.d2(enc2)
        enc3 = self.enc3(d2)
        d3 = self.d3(enc3)
        enc4 = self.enc4(d3)
        d4 = self.d4(enc4)
        enc5 =self.enc5(d4)

        dec1 = self.dec1(enc5, enc4)
        enc6 = self.enc6(dec1)
        dec2 = self.dec2(enc6, enc3)
        enc7 = self.enc6(dec2)
        dec3 = self.dec3(enc7, enc2)
        enc8 = self.enc6(dec3)
        dec4 = self.dec4(enc8, enc1)
        enc9 = self.enc6(dec4)

        final = self.final_conv(dec4)
        return self.softmax(final)


if __name__ == '__main__':
    x = torch.randn(1, 3, 3001)
    net = UNet()
    output = net(x).squeeze(1)
    print(output.shape)
![img](https://img-mid.csdnimg.cn/release/static/image/mid/ask/a48b02806d3c47a6bdce6498ede2c2b7.jpeg "#left")

  • 写回答

2条回答 默认 最新

  • 一直免费一直爽 2024-08-17 15:07
    关注

    以下回复参考:皆我百晓生券券喵儿等免费微信小程序作答:

    根据您提供的代码,我发现在UpSample类的forward方法中,您使用了torch.cat((x, skip_connection), dim=1)来拼接上卷积结果和跳跃连接。这里的问题是,拼接后的通道数可能超过了out_channels + skip_channels的限制。为了解决这个问题,您可以在拼接之前将跳跃连接的通道数调整为out_channels

    修改后的UpSample类如下:

    class UpSample(nn.Module):
        def __init__(self, in_channels, out_channels, skip_channels):
            super(UpSample, self).__init__()
            self.up_conv = nn.ConvTranspose1d(in_channels, out_channels, kernel_size=4, stride=2, padding=1,
                                              output_padding=1, bias=False)
            # 上卷积操作后,特征图的通道数变为 out_channels
            # 拼接跳跃连接后,总通道数为 out_channels + skip_channels
            self.conv_block = ConvBlock((out_channels + skip_channels), out_channels)
    
        def forward(self, x, skip_connection):
            x = self.up_conv(x)
            skip_connection = F.interpolate(skip_connection, size=x.shape[2:], mode='nearest')
            # 调整跳跃连接的通道数为 out_channels
            skip_connection = nn.Conv1d(skip_connection.shape[1], x.shape[1], kernel_size=1)(skip_connection)
            x = torch.cat((x, skip_connection), dim=1)  # 在这里拼接上卷积结果和跳跃连接
            return self.conv_block(x)
    

    这样修改后,您的UNet网络应该可以正常运行了。

    评论

报告相同问题?

问题事件

  • 创建了问题 8月17日

悬赏问题

  • ¥20 白日门传奇少一个启动区服和启动服务器的快捷键,东西都是全的 , 他们说套一个出来就行了 但我就是弄不好,谁看看,
  • ¥100 如何用js写一个游戏云存档
  • ¥15 ansys fluent计算闪退
  • ¥15 有关wireshark抓包的问题
  • ¥15 需要写计算过程,不要写代码,求解答,数据都在图上
  • ¥15 向数据表用newid方式插入GUID问题
  • ¥15 multisim电路设计
  • ¥20 用keil,写代码解决两个问题,用库函数
  • ¥50 ID中开关量采样信号通道、以及程序流程的设计
  • ¥15 U-Mamba/nnunetv2固定随机数种子