Clichong 2022-05-25 21:20
浏览 66
已结题

使用torch.autograd.Function自定义激活函数时,如何在父类中对子类传入参数?

这里贴上yolov5的一个高效自定义激活函数的源码:

class MemoryEfficientMish(nn.Module):
    class F(torch.autograd.Function):
        @staticmethod
        def forward(ctx, x):
            ctx.save_for_backward(x)    # 表示forward()的结果要存起来,以后给backward()
            return x.mul(torch.tanh(F.softplus(x)))  # x * tanh(ln(1 + exp(x)))

        # grad_output是最终object对的forward()输出的导数, 也就是理解为上一层求导的结果
        # ctx是一个元祖
        @staticmethod
        def backward(ctx, grad_output):    # grad_output上一层求导的结果
            x = ctx.saved_tensors[0]       # ctx.saved_tensors得到之前forward()存的结果
            sx = torch.sigmoid(x)
            fx = F.softplus(x).tanh()
            return grad_output * (fx + x * sx * (1 - fx * fx))

    def forward(self, x):
        return self.F.apply(x)

这里我如何传入一个beta参数到子类F中去,也就是在父类MemoryEfficientMish中传入一个参数到子类F中,使得可以控制子类的forward与backward函数的返回。一个设想的伪代码实现如下:

class MemoryEfficientMish(nn.Module):

  # 可以传入参数beta,默认为1,也就是简化的版本
    def __init__(self, beta=1.):
        super().__init__()
        self.beta = beta

    class F(torch.autograd.Function):
        @staticmethod
        def forward(ctx, x):
            ctx.save_for_backward(x)    # 表示forward()的结果要存起来,以后给backward()
            
            # 传入参数beta使得可以控制返回函数
            if self.beta != 1.0:
                 return ...
            return x.mul(torch.tanh(F.softplus(x)))  # x * tanh(ln(1 + exp(x)))

        # grad_output是最终object对的forward()输出的导数, 也就是理解为上一层求导的结果
        # ctx是一个元祖
        @staticmethod
        def backward(ctx, grad_output):    # grad_output上一层求导的结果
            x = ctx.saved_tensors[0]       # ctx.saved_tensors得到之前forward()存的结果
            sx = torch.sigmoid(x)
            fx = F.softplus(x).tanh()

            # 传入参数beta使得可以控制返回函数
            if self.beta != 1.0:
                 return ...
            return grad_output * (fx + x * sx * (1 - fx * fx))

    def forward(self, x):
        return self.F.apply(x)
  • 写回答

0条回答 默认 最新

    报告相同问题?

    问题事件

    • 系统已结题 6月2日
    • 创建了问题 5月25日

    悬赏问题

    • ¥15 如何在scanpy上做差异基因和通路富集?
    • ¥20 关于#硬件工程#的问题,请各位专家解答!
    • ¥15 关于#matlab#的问题:期望的系统闭环传递函数为G(s)=wn^2/s^2+2¢wn+wn^2阻尼系数¢=0.707,使系统具有较小的超调量
    • ¥15 FLUENT如何实现在堆积颗粒的上表面加载高斯热源
    • ¥30 截图中的mathematics程序转换成matlab
    • ¥15 动力学代码报错,维度不匹配
    • ¥15 Power query添加列问题
    • ¥50 Kubernetes&Fission&Eleasticsearch
    • ¥15 報錯:Person is not mapped,如何解決?
    • ¥15 c++头文件不能识别CDialog