
我在开发基于mlir的onnx编译器时遇到了这种pattern,想融合成1个op,这个op有2个输出。
使用RewritePattern可以方便灵活的抓住优化时机。
但看了文档和源码,RewritePattern只能针对单个根节点,不太理解mlir为什么要这么设计,是特性还是局限啊?
如果设置根节点为左边,但在rewrite时对左右都进行替换,是RewritePattern正确的用法吗?还是会有什么隐患?

关注引自免费微信小程序:皆我百晓生
在MLIR(Multi-Level Intermediate Representation)中,RewritePattern主要用于定义模式匹配和替换规则,通常用于代码优化阶段。每个RewritePattern实例对应一个特定的操作模式(Operation Pattern),它会在中间表示(Intermediate Representation, IR)遍历过程中寻找匹配项,并执行相应的替换操作。
在您的例子中,您想要合并三个操作:1x1x64、Concat 和 ×1×128 成一个新的操作,这个新操作有两个输出。然而,MLIR中的RewritePattern设计上确实倾向于处理单个根操作,这是为了保持模式匹配的简单性和效率。这并不意味着您不能实现多输出操作的合并,只是可能需要稍微变通一下方法。
一种可行的方法是创建两个不同的RewritePattern,分别处理输入和输出的连接。这样,您可以首先应用一个模式来合并前两个操作,然后应用第二个模式来连接结果到第三个操作。这里是一个简单的伪代码示例:
class ConcatAndMultiplyPattern(RewritePattern):
def __init__(self, context):
super().__init__("concat_and_multiply", OpMatchFinderCategory.LLVM_IR_BUILDER, context)
def match_and_rewrite(self, operation, rewriter):
# 检查是否有1x1x64和Concat操作
if len(operation.getOperands()) != 2 or not isinstance(operation.getOperand(0), mlir.opr.Conv) or not isinstance(operation, mlir.opr.Concatenate):
return False
# 创建新的乘法操作
multiply_op = mlir.opr.Mul.create(
operation.location,
operation.getResult().type,
operation.getOperand(0),
operation.getOperand(1)
)
# 替换原操作
rewriter.replaceOp(operation, multiply_op)
return True
class ConnectOutputPattern(RewritePattern):
def __init__(self, context):
super().__init__("connect_output", OpMatchFinderCategory.LLVM_IR_BUILDER, context)
def match_and_rewrite(self, operation, rewriter):
# 检查是否有乘法操作和×1×128操作
if not isinstance(operation, mlir.opr.Mul) or not isinstance(operation.getUsers()[0], mlir.opr.Cos):
return False
# 连接乘法操作的结果到Cos操作的输入
cos_op = operation.getUsers()[0]
cos_op.setOperand(0, operation.getResult())
return True
要使用这些模式,您需要将它们注册到一个RewritePatternSet中,然后在一个驱动程序中运行applyPatternsAndFoldGreedily函数,如下所示:
# 假设你已经有一个mlir模块module
patterns = RewritePatternSet(module.context)
patterns.add(ConcatAndMultiplyPattern(module.context))
patterns.add(ConnectOutputPattern(module.context))
# 应用模式并折叠操作
rewriter = module.context.createRewriter()
with mlir.PassManager.parse("convert-linalg-to-loops") as pm:
pm.addPass(mlir.CreatePassPipeline([patterns]))
pm.run(module)
# 打印修改后的模块
print(str(module))
这种方法虽然不是最直观的,但它允许您利用MLIR的RewritePattern系统完成多输出操作的合并。需要注意的是,这只是一个简化的例子,实际情况下可能需要处理更复杂的逻辑和错误检查。