我在使用JAX进行机器学习中求梯度的运算时遇到了无法解决的问题。
我将我的问题简化为如下代码:
import jax
import jax.numpy as jnp
def parameter(theta):
H = jnp.ones([8,8]) + theta*jnp.eye(8)
Q = jax.scipy.linalg.eigh(H)[1]
return jnp.abs(jnp.linalg.slogdet(Q)[1])
gradient_fn = jax.grad(parameter)
theta_value = 10.0
gradient = gradient_fn(theta_value)
print("Gradient with respect to theta:", gradient)
无论输入怎样的theta_value, 结果总是nan。
根据我的尝试,问题应该出现在这一步:Q = jax.scipy.linalg.eigh(H)。
请问为什么会出现nan呢?我有应该如何纠正这个错误(因为nan会导致梯度优化无法进行)