这里贴上yolov5的一个高效自定义激活函数的源码:
class MemoryEfficientMish(nn.Module):
class F(torch.autograd.Function):
@staticmethod
def forward(ctx, x):
ctx.save_for_backward(x) # 表示forward()的结果要存起来,以后给backward()
return x.mul(torch.tanh(F.softplus(x))) # x * tanh(ln(1 + exp(x)))
# grad_output是最终object对的forward()输出的导数, 也就是理解为上一层求导的结果
# ctx是一个元祖
@staticmethod
def backward(ctx, grad_output): # grad_output上一层求导的结果
x = ctx.saved_tensors[0] # ctx.saved_tensors得到之前forward()存的结果
sx = torch.sigmoid(x)
fx = F.softplus(x).tanh()
return grad_output * (fx + x * sx * (1 - fx * fx))
def forward(self, x):
return self.F.apply(x)
这里我如何传入一个beta参数到子类F中去,也就是在父类MemoryEfficientMish中传入一个参数到子类F中,使得可以控制子类的forward与backward函数的返回。一个设想的伪代码实现如下:
class MemoryEfficientMish(nn.Module):
# 可以传入参数beta,默认为1,也就是简化的版本
def __init__(self, beta=1.):
super().__init__()
self.beta = beta
class F(torch.autograd.Function):
@staticmethod
def forward(ctx, x):
ctx.save_for_backward(x) # 表示forward()的结果要存起来,以后给backward()
# 传入参数beta使得可以控制返回函数
if self.beta != 1.0:
return ...
return x.mul(torch.tanh(F.softplus(x))) # x * tanh(ln(1 + exp(x)))
# grad_output是最终object对的forward()输出的导数, 也就是理解为上一层求导的结果
# ctx是一个元祖
@staticmethod
def backward(ctx, grad_output): # grad_output上一层求导的结果
x = ctx.saved_tensors[0] # ctx.saved_tensors得到之前forward()存的结果
sx = torch.sigmoid(x)
fx = F.softplus(x).tanh()
# 传入参数beta使得可以控制返回函数
if self.beta != 1.0:
return ...
return grad_output * (fx + x * sx * (1 - fx * fx))
def forward(self, x):
return self.F.apply(x)