wnljl 2021-04-23 20:13 采纳率: 0%
浏览 39

python关于神经网络请问这个代码什么意思

import scipy.integrate
import autograd.numpy as np
from autograd.extend import primitive, defvjp_argnums
from autograd import make_vjp
from autograd.misc import flatten
from autograd.builtins import tuple

odeint = primitive(scipy.integrate.odeint)


def grad_odeint_all(yt, func, y0, t, func_args, **kwargs):
    T,D = np.shape(yt)
    flat_args, unflatten = flatten(func_args)

    def flat_func(y, t, flat_args):
        return func(y, t, *unflatten(flat_args))

    def unpack(x):
        # y , vjp_y , vjp_t , vjp_args
        return x[0:D], x[D:2 * D], x[2 * D], x[2 * D + 1:]

    def augmented_dynamics(augmented_state, t, flat_args):
        # Orginal system augmented with vjp_y , vjp_t and vjp_args .
        y, vjp_y, _, _ = unpack(augmented_state)
        vjp_all, dy_dt = make_vjp(flat_func, (0, 1, 2))(y, t, flat_args)
        vjp_y, vjp_t, vjp_args = vjp_all((vjp_y))
        return np.hstack((dy_dt, vjp_y, vjp_t, vjp_args))

    def vjp_all(g, **kwargs):
        vjp_y = g[-1, :]
        vjp_t0 = 0
        time_vjp_list = []
        vjp_args = np.zeros(np.size(flat_args))

        for i in range(T - 1, 0, -1):
            # Compute e f f e c t o f moving c u r r e n t tim e .
            vjp_cur_t = np.dot(func(yt[i, :], t[i], *func_args), g[i, :])
            time_vjp_list.append(vjp_cur_t)
            vjp_t0 = vjp_t0 - vjp_cur_t

            # Run a u gme nte d s y st em b a c kw a r d s t o t h e p r e v i o u s o b s e r v a t i o n .
            aug_y0 = np.hstack((yt[i, :], vjp_y, vjp_t0, vjp_args))
            aug_ans = odeint(augmented_dynamics, aug_y0, np.array([t[i], t[i - 1]]), tuple((flat_args,)), **kwargs)
            _, vjp_y, vjp_t0, vjp_args = unpack(aug_ans[1])

            # Add g r a d i e n t f rom c u r r e n t o u t p u t .
            vjp_y = vjp_y + g[i - 1, :]

        time_vjp_list.append(vjp_t0)
        vjp_times = np.hstack(time_vjp_list)[:: -1]

        return None, vjp_y, vjp_times, unflatten(vjp_args)

    return vjp_all


def grad_argnums_wrapper(all_vjp_builder):
    def build_

  • 写回答

1条回答 默认 最新

  • 有问必答小助手 2021-04-26 14:20
    关注

    你好,我是有问必答小助手,非常抱歉,本次您提出的有问必答问题,技术专家团超时未为您做出解答

    本次提问扣除的有问必答次数,将会以问答VIP体验卡(1次有问必答机会、商城购买实体图书享受95折优惠)的形式为您补发到账户。

    ​​​​因为有问必答VIP体验卡有效期仅有1天,您在需要使用的时候【私信】联系我,我会为您补发。

    评论

报告相同问题?

悬赏问题

  • ¥35 平滑拟合曲线该如何生成
  • ¥100 c语言,请帮蒟蒻写一个题的范例作参考
  • ¥15 名为“Product”的列已属于此 DataTable
  • ¥15 安卓adb backup备份应用数据失败
  • ¥15 eclipse运行项目时遇到的问题
  • ¥15 关于#c##的问题:最近需要用CAT工具Trados进行一些开发
  • ¥15 南大pa1 小游戏没有界面,并且报了如下错误,尝试过换显卡驱动,但是好像不行
  • ¥15 自己瞎改改,结果现在又运行不了了
  • ¥15 链式存储应该如何解决
  • ¥15 没有证书,nginx怎么反向代理到只能接受https的公网网站