def forward(self, x):
    # Encode the input into VAE latents without building an autograd graph
    with torch.no_grad():
        x = self.vae.encode(x).latent_dist.sample().mul_(0.15)
    # Trainable layers operate in latent space
    for layer in self.layers:
        x = layer(x)
    # Decode the latents back to pixel space, again without gradient tracking
    with torch.no_grad():
        x = self.vae.decode(x / 0.18215).sample
    x.requires_grad_(True)
    return x
The above is a forward function that uses a VAE to encode and decode. I find that during training the gradient norm (grad_norm) is 0 from the very start, yet the loss still decreases slowly. Why is that? If I remove the VAE, the gradient norm is normal, as in the following version:
def forward(self, x):
    for layer in self.layers:
        x = layer(x)
    return x
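
To make the problem concrete, here is a minimal diagnostic sketch that mirrors the first forward and prints grad_fn after each stage, showing where the autograd graph is (or is not) being built. It assumes `model` exposes `vae` and `layers` exactly as in the code above (e.g. a diffusers AutoencoderKL plus an nn.ModuleList of trainable layers); the names `inspect_graph`, `model`, and `x` are hypothetical:

import torch

def inspect_graph(model, x):
    # Stage 1: encode under no_grad
    with torch.no_grad():
        z = model.vae.encode(x).latent_dist.sample().mul_(0.15)
    print(z.grad_fn)          # None: encode ran under no_grad (harmless, x is just the input)
    # Stage 2: trainable layers outside no_grad
    for layer in model.layers:
        z = layer(z)
    print(z.grad_fn)          # a Function node: the layers' parameters are tracked here
    # Stage 3: decode under no_grad
    with torch.no_grad():
        out = model.vae.decode(z / 0.18215).sample
    print(out.grad_fn)        # None: this no_grad block cut the graph after the layers
    out.requires_grad_(True)  # makes `out` a fresh leaf; it does not reconnect the graph
    print(out.grad_fn)        # still None, so a loss computed on `out` cannot reach the layers
    return out

Printing grad_fn this way localizes exactly which no_grad block detaches the trainable layers from the loss.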