背景:在用python写核心算法为四阶龙格库塔法的软件(该软件用于数值法求解微分方程组,实际就是四阶龙格库塔法)时想用numba的jit加速,但是程序具有一定的动态性(比如求解的微分方程由用户自己输入,原来采用sympy库计算(但这个库不能被jit加速)
现在已经要ai写了一份代码 但是由于软件包版本不兼容,pycharm报错帮改改代码,我的numba库是0.59.1
import numba as nb
from numba import njit
from numba.core import ir
from numba.core.ir_utils import guard, find_topo_order
from numba import types
def generate_function(expression):
# 创建函数签名
sig = nb.float64(nb.float64)
# 创建函数的 IR 表示
func_ir = ir.FunctionIR(name="dynamic_func", arg_types=[nb.float64], return_type=nb.float64)
func_ir.blocks.append(ir.Block())
block = func_ir.blocks[0]
# 获取函数参数
x_var = block.add_arg("x", nb.float64)
# 创建表达式计算的 IR 表示
call_node = ir.Expr.binop(x_var, ir.Const(2., loc=None), ir.Add, loc=None)
other_node = ir.Const(1., loc=None)
expr_node = ir.Expr.binop(call_node, other_node, ir.Add, loc=None)
result_var = block.add_var(name="result", loc=None)
code = ir.Assign(expr_node, result_var, loc=None)
block.body.append(code)
# 创建返回语句的 IR 表示
return_node = ir.Return(value=result_var, loc=None)
block.body.append(return_node)
# 构建函数并进行 Numba JIT 编译
typing_ctx = nb.core.typing.Context()
target_ctx = nb.core.target.create_target_context("cpu")
typemap, return_type, locals = guard(find_topo_order(func_ir.blocks))
dispatcher = nb.core.dispatcher.Dispatcher(typemap)
pipeline_passes = nb.core.compiler.CompilerPassBuilder.default_passes(
typingctx=typing_ctx, targetctx=target_ctx
)
pipeline = nb.core.compiler.CompilerPipeline(
nb.core.compiler.DefaultPassBuilder.define_nopython_pipeline(typingctx=typing_ctx, targetctx=target_ctx),
nb.core.compiler.DefaultPassBuilder.define_nopython_pipeline_after_monomorphization(typingctx=typing_ctx, targetctx=target_ctx),
)
lowered = pipeline.run(func_ir, None, types.float64, types.float64)
typed_func = nb.core.compiler._Lower.lower_function(typing_ctx, target_ctx, lowered)
lifted = nb.core.compiler._PreLower(typing_ctx).run(typed_func)
optimized = nb.core.compiler._Optimize(typing_ctx, target_ctx).run(lifted)
func = optimized.library.get_function(optimized.fndesc.llvm_func_name)
return dynamic_func
# 生成动态函数并进行 Numba JIT 加速
expression = 'x ** 2 + 2 * x + 1'
dynamic_func = generate_function(expression)
# 在其他函数中使用动态生成的函数
def my_function():
result = dynamic_func(5.)
print(result)
# 调用并加速执行
my_function()