过去的猴子 2024-04-12 15:31 采纳率: 0%
浏览 7

mlir RewritePattern为什么不支持匹配多个root op?

img


我在开发基于mlir的onnx编译器时遇到了这种pattern,想融合成1个op,这个op有2个输出。
使用RewritePattern可以方便灵活的抓住优化时机。
但看了文档和源码,RewritePattern只能针对单个根节点,不太理解mlir为什么要这么设计,是特性还是局限啊?
如果设置根节点为左边,但在rewrite时对左右都进行替换,是RewritePattern正确的用法吗?还是会有什么隐患?

  • 写回答

5条回答 默认 最新

  • 专家-赤兔[在线] 优质创作者: 编程框架技术领域 2024-04-12 15:31
    关注

    引自免费微信小程序:皆我百晓生

    在MLIR(Multi-Level Intermediate Representation)中,RewritePattern主要用于定义模式匹配和替换规则,通常用于代码优化阶段。每个RewritePattern实例对应一个特定的操作模式(Operation Pattern),它会在中间表示(Intermediate Representation, IR)遍历过程中寻找匹配项,并执行相应的替换操作。

    在您的例子中,您想要合并三个操作:1x1x64、Concat 和 ×1×128 成一个新的操作,这个新操作有两个输出。然而,MLIR中的RewritePattern设计上确实倾向于处理单个根操作,这是为了保持模式匹配的简单性和效率。这并不意味着您不能实现多输出操作的合并,只是可能需要稍微变通一下方法。

    一种可行的方法是创建两个不同的RewritePattern,分别处理输入和输出的连接。这样,您可以首先应用一个模式来合并前两个操作,然后应用第二个模式来连接结果到第三个操作。这里是一个简单的伪代码示例:

    class ConcatAndMultiplyPattern(RewritePattern):
        def __init__(self, context):
            super().__init__("concat_and_multiply", OpMatchFinderCategory.LLVM_IR_BUILDER, context)
    
        def match_and_rewrite(self, operation, rewriter):
            # 检查是否有1x1x64和Concat操作
            if len(operation.getOperands()) != 2 or not isinstance(operation.getOperand(0), mlir.opr.Conv) or not isinstance(operation, mlir.opr.Concatenate):
                return False
    
            # 创建新的乘法操作
            multiply_op = mlir.opr.Mul.create(
                operation.location,
                operation.getResult().type,
                operation.getOperand(0),
                operation.getOperand(1)
            )
    
            # 替换原操作
            rewriter.replaceOp(operation, multiply_op)
    
            return True
    
    class ConnectOutputPattern(RewritePattern):
        def __init__(self, context):
            super().__init__("connect_output", OpMatchFinderCategory.LLVM_IR_BUILDER, context)
    
        def match_and_rewrite(self, operation, rewriter):
            # 检查是否有乘法操作和×1×128操作
            if not isinstance(operation, mlir.opr.Mul) or not isinstance(operation.getUsers()[0], mlir.opr.Cos):
                return False
    
            # 连接乘法操作的结果到Cos操作的输入
            cos_op = operation.getUsers()[0]
            cos_op.setOperand(0, operation.getResult())
    
            return True
    

    要使用这些模式,您需要将它们注册到一个RewritePatternSet中,然后在一个驱动程序中运行applyPatternsAndFoldGreedily函数,如下所示:

    # 假设你已经有一个mlir模块module
    patterns = RewritePatternSet(module.context)
    patterns.add(ConcatAndMultiplyPattern(module.context))
    patterns.add(ConnectOutputPattern(module.context))
    
    # 应用模式并折叠操作
    rewriter = module.context.createRewriter()
    with mlir.PassManager.parse("convert-linalg-to-loops") as pm:
        pm.addPass(mlir.CreatePassPipeline([patterns]))
        pm.run(module)
    
    # 打印修改后的模块
    print(str(module))
    

    这种方法虽然不是最直观的,但它允许您利用MLIR的RewritePattern系统完成多输出操作的合并。需要注意的是,这只是一个简化的例子,实际情况下可能需要处理更复杂的逻辑和错误检查。

    评论 编辑记录

报告相同问题?

问题事件

  • 创建了问题 4月12日