2301_79944969 2024-03-21 15:57 采纳率: 50%
浏览 13

在动态创建可被numba中jit加速的函数时遇到numba版本问题

背景:在用python写核心算法为四阶龙格库塔法的软件(该软件用于数值法求解微分方程组,实际就是四阶龙格库塔法)时想用numba的jit加速,但是程序具有一定的动态性(比如求解的微分方程由用户自己输入,原来采用sympy库计算(但这个库不能被jit加速)

现在已经要ai写了一份代码 但是由于软件包版本不兼容,pycharm报错帮改改代码,我的numba库是0.59.1

import numba as nb
from numba import njit
from numba.core import ir
from numba.core.ir_utils import guard, find_topo_order
from numba import types

def generate_function(expression):
    # 创建函数签名
    sig = nb.float64(nb.float64)

    # 创建函数的 IR 表示
    func_ir = ir.FunctionIR(name="dynamic_func", arg_types=[nb.float64], return_type=nb.float64)
    func_ir.blocks.append(ir.Block())
    block = func_ir.blocks[0]

    # 获取函数参数
    x_var = block.add_arg("x", nb.float64)

    # 创建表达式计算的 IR 表示
    call_node = ir.Expr.binop(x_var, ir.Const(2., loc=None), ir.Add, loc=None)
    other_node = ir.Const(1., loc=None)
    expr_node = ir.Expr.binop(call_node, other_node, ir.Add, loc=None)
    result_var = block.add_var(name="result", loc=None)
    code = ir.Assign(expr_node, result_var, loc=None)
    block.body.append(code)

    # 创建返回语句的 IR 表示
    return_node = ir.Return(value=result_var, loc=None)
    block.body.append(return_node)

    # 构建函数并进行 Numba JIT 编译
    typing_ctx = nb.core.typing.Context()
    target_ctx = nb.core.target.create_target_context("cpu")
    typemap, return_type, locals = guard(find_topo_order(func_ir.blocks))
    dispatcher = nb.core.dispatcher.Dispatcher(typemap)
    pipeline_passes = nb.core.compiler.CompilerPassBuilder.default_passes(
        typingctx=typing_ctx, targetctx=target_ctx
    )
    pipeline = nb.core.compiler.CompilerPipeline(
        nb.core.compiler.DefaultPassBuilder.define_nopython_pipeline(typingctx=typing_ctx, targetctx=target_ctx),
        nb.core.compiler.DefaultPassBuilder.define_nopython_pipeline_after_monomorphization(typingctx=typing_ctx, targetctx=target_ctx),
    )
    lowered = pipeline.run(func_ir, None, types.float64, types.float64)
    typed_func = nb.core.compiler._Lower.lower_function(typing_ctx, target_ctx, lowered)
    lifted = nb.core.compiler._PreLower(typing_ctx).run(typed_func)
    optimized = nb.core.compiler._Optimize(typing_ctx, target_ctx).run(lifted)
    func = optimized.library.get_function(optimized.fndesc.llvm_func_name)

    return dynamic_func

# 生成动态函数并进行 Numba JIT 加速
expression = 'x ** 2 + 2 * x + 1'
dynamic_func = generate_function(expression)

# 在其他函数中使用动态生成的函数
def my_function():
    result = dynamic_func(5.)
    print(result)

# 调用并加速执行
my_function()


  • 写回答

2条回答 默认 最新

  • 码农阿豪@新空间代码工作室 Java领域优质创作者 2024-03-21 15:58
    关注
    让阿豪来帮你解答,本回答参考chatgpt3.5编写提供,如果还有疑问可以评论或留言
    在你给出的代码中,有一些不支持的语法和函数调用,导致程序无法正常运行。下面是你需要修改的内容:
    1. 替换 "from numba.core import ir" 为 "from numba.ir import ir_typeof"。
    2. 将 "def generate_function(expression)" 中的 "# 获取函数参数" 处的 "x_var = block.add_arg('x', nb.float64)" 改为 "x_var = ir.Var(name='x', type=nb.float64)"。
    3. 在 "my_function" 函数中,将 "result = dynamic_func(5.)" 改为 "result = dynamic_func(5.)()",因为 "dynamic_func" 是一个函数对象,需要加括号才能调用。
    4. 增加 "from numba.targets.registry import CPUDispatcher",该语句用于从 Numba 编译器中获取 CPU 编译器的注册表。 下面是修改后的代码:
    import numba as nb
    from numba import njit
    from numba.ir import FunctionIR, Var, Block, Expr, Return, Const
    from numba.ir_utils import guard, find_topo_order
    from numba.targets.registry import CPUDispatcher
    from numba import types
    def generate_function(expression):
        # 创建函数签名
        sig = nb.float64(nb.float64)
        # 创建函数的 IR 表示
        func_ir = FunctionIR(name="dynamic_func", arg_types=[nb.float64], return_type=nb.float64)
        func_ir.blocks.append(Block())
        block = func_ir.blocks[0]
        # 获取函数参数
        x_var = Var(name="x", type=nb.float64)
        # 创建表达式计算的 IR 表示
        call_node = Expr.binop(x_var, Const(2.0), op=ir_typeof('+'))
        other_node = Const(1.0)
        expr_node = Expr.binop(call_node, other_node, op=ir_typeof('+'))
        result_var = block.add_var(name="result", loc=None)
        code = Expr.assign(expr_node, result_var)
        block.body.append(code)
        # 创建返回语句的 IR 表示
        return_node = Return(value=result_var, loc=None)
        block.body.append(return_node)
        # 构建函数并进行 Numba JIT 编译
        typing_ctx = nb.core.typing.Context()
        target_ctx = nb.core.target.create_target_context('cpu')
        typemap, return_type, locals = guard(find_topo_order(func_ir.blocks))
        dispatcher = CPUDispatcher(typemap)
        pipeline_passes = nb.core.compiler.CompilerPassBuilder.default_passes(
            typingctx=typing_ctx, targetctx=target_ctx
        )
        pipeline = nb.core.compiler.CompilerPipeline(pipeline_passes)
        lowered = pipeline.run(func_ir, None, types.float64, types.float64)
        typed_func = nb.core.compiler._Lower.lower_function(typing_ctx, target_ctx, lowered)
        lifted = nb.core.compiler._PreLower(typing_ctx).run(typed_func)
        optimized = nb.core.compiler._Optimize(typing_ctx, target_ctx).run(lifted)
        func = optimized.library.get_function(optimized.fndesc.llvm_func_name)
        return func
    # 生成动态函数并进行 Numba JIT 加速
    expression = 'x ** 2 + 2 * x + 1'
    dynamic_func = generate_function(expression)
    # 在其他函数中使用动态生成的函数
    def my_function():
        result = dynamic_func(5.0)
        print(result)
    # 调用并加速执行
    my_function()
    
    评论

报告相同问题?

问题事件

  • 创建了问题 3月21日

悬赏问题

  • ¥15 AD9910输出波形向上偏移,波谷不为0V
  • ¥15 淘宝自动下单XPath自动点击插件无法点击特定<span>元素,如何解决?
  • ¥15 曙光1620-g30服务器安装硬盘后 看不到硬盘
  • ¥15 抖音直播广场scheme
  • ¥15 为什么我明明有这个文件调试器还显示错误?
  • ¥15 软件工程用例图的建立(相关搜索:软件工程用例图|画图)
  • ¥15 如何在arcgis中导出拓扑关系表
  • ¥15 处理数据集文本挖掘代码
  • ¥15 matlab2017
  • ¥15 在vxWorks下TCP/IP编程,总是connect()报错,连接服务器失败: errno = 0x41