阿根廷必胜 2024-05-06 08:47 采纳率: 87.5%
浏览 11
已结题

paddle使用中二维卷积算子较大数据时无法运行出结果


class Conv2D(nn.Layer):
    def __init__(self, kernel_size, stride=1, padding=0,
                    weight_attr=paddle.ParamAttr(initializer=nn.initializer.Constant(value=1.0))):
        super(Conv2D, self).__init__()
        self.weight = paddle.create_parameter(shape=[kernel_size,kernel_size],
                                                dtype='float32',
                                                attr=weight_attr)
        # 步长
        self.stride = stride
        # 零填充
        self.padding = padding

    def forward(self, X):
        # 零填充
        new_X = paddle.zeros([X.shape[0], X.shape[1]+2*self.padding, X.shape[2]+2*self.padding])
        new_X[:, self.padding:X.shape[1]+self.padding, self.padding:X.shape[2]+self.padding] = X
        u, v = self.weight.shape
        output_w = (new_X.shape[1] - u) // self.stride + 1
        output_h = (new_X.shape[2] - v) // self.stride + 1
        output = paddle.zeros([X.shape[0], output_w, output_h])
        for i in range(0, output.shape[1]):
            for j in range(0, output.shape[2]):
                output[:, i, j] = paddle.sum(
                    new_X[:, self.stride*i:self.stride*i+u, self.stride*j:self.stride*j+v]*self.weight,
                    axis=[1,2])
        print(output.shape)
        # return output

inputs = paddle.randn(shape=[1,150 , 150])
conv2d_padding = Conv2D(kernel_size=3, padding=1)
outputs = conv2d_padding(inputs)

img

单步调试时是这个出不来atexit_call那有问题

img

  • 写回答

4条回答 默认 最新

  • 专家-赤兔[在线] 全栈领域优质创作者 2024-05-06 08:47
    关注

    引自免费微信小程序:皆我百晓生

    在您提供的代码中,我注意到几个可能的问题。首先,卷积层的权重初始化为常数1.0,这可能导致数值不稳定。其次,没有添加偏置项,这可能影响网络的表现。最后,当处理较大输入时,可能会遇到内存不足的问题。

    为了优化您的代码并解决这些问题,我建议进行以下更改:

    1. 更改权重初始化方式,例如使用Xavier正则化。
    2. 添加一个可学习的偏置项。
    3. 如果仍然存在内存问题,可以尝试使用paddle.to_tensor()将随机张量转换为指定设备(CPU或GPU)上的张量。

    以下是修改后的代码:

    import paddle
    from paddle import nn
    
    class Conv2D(nn.Layer):
        def __init__(self, kernel_size, stride=1, padding=0):
            super(Conv2D, self).__init__()
            
            # 使用 Xavier 正则化初始化权重
            weight_attr = nn.initializer.XavierUniform()
            
            self.weight = paddle.create_parameter(shape=[kernel_size, kernel_size],
                                                    dtype='float32',
                                                    attr=weight_attr)
    
            # 初始化偏置
            self.bias = paddle.create_parameter(shape=[1],
                                                dtype='float32',
                                                attr=nn.initializer.Constant(value=0.0))
    
            self.stride = stride
            self.padding = padding
    
        def forward(self, X):
            # 将输入数据转移到设备(CPU 或 GPU)
            X = paddle.to_tensor(X)
    
            # 零填充
            new_X = paddle.nn.functional.pad(X, [self.padding, self.padding, self.padding, self.padding])
    
            u, v = self.weight.shape
            output_w = (new_X.shape[2] - u) // self.stride + 1
            output_h = (new_X.shape[3] - v) // self.stride + 1
            output = paddle.zeros([X.shape[0], output_w, output_h])
    
            for i in range(0, output.shape[1]):
                for j in range(0, output.shape[2]):
                    output[:, i, j] = paddle.sum(
                        new_X[:, :, self.stride * i:self.stride * i + u, self.stride * j:self.stride * j + v] * self.weight,
                        axis=[2, 3]) + self.bias
    
            print(output.shape)
            return output
    
    inputs = paddle.randn(shape=[1, 150, 150])
    conv2d_padding = Conv2D(kernel_size=3, padding=1)
    outputs = conv2d_padding(inputs)
    

    这段代码应该能正常运行。如果您仍然遇到问题,请检查您的PaddlePaddle安装、CUDA版本以及GPU驱动程序是否兼容,并确保有足够的内存来处理较大的输入。如果有任何错误消息,请提供详细的错误信息以便进一步分析。

    本回答被题主选为最佳回答 , 对您是否有帮助呢?
    评论 编辑记录
查看更多回答(3条)

报告相同问题?

问题事件

  • 系统已结题 5月16日
  • 已采纳回答 5月8日
  • 创建了问题 5月6日

悬赏问题

  • ¥20 WPF MVVM模式 handycontrol 框架, hc:SearchBar 控件 Text="{Binding NavMenusKeyWords}" 绑定取不到值
  • ¥15 需要手写数字信号处理Dsp三个简单题 不用太复杂
  • ¥15 数字信号处理考试111
  • ¥100 关于#audobe audition#的问题,如何解决?
  • ¥15 allegro17.2生成bom表是空白的
  • ¥15 请问一下怎么打通CAN通讯
  • ¥20 如何在 rocky9.4 部署 CDH6.3.2?
  • ¥35 navicat将excel中的数据导入mysql出错
  • ¥15 rt-thread线程切换的问题
  • ¥15 高通uboot 打印ubi init err 22