export failure: ONNX export failed: Couldn't export operator aten::unfold
1条回答 默认 最新
关注 不知道你这个问题是否已经解决, 如果还没有解决的话:- 以下回答来自chatgpt:
答案:
这个问题的根本原因是 ONNX 无法导出
aten::unfold
运算符。为了解决这个问题,我们可以修改 PyTorch 模型,以使用 ONNX 能够导出的运算符来代替aten::unfold
。我们可以使用以下步骤解决这个问题:- 检查模型中是否有使用
aten::unfold
运算符。我们可以使用以下代码来查找模型中的所有运算符及其类型:
```python import torch from torch.onnx import register_custom_op_symbolic
def print_model_ops(model): for name, module in model.named_modules(): for i, (name, op) in enumerate(module.named_parameters(recurse=False)): print(f"{name} ({module.class.name}.{op.class.name}): {op.cpu().contiguous().view(-1).numpy()}")
for i, (name, op) in enumerate(module.named_buffers(recurse=False)): print(f"{name} ({module.__class__.__name__}.{op.__class__.__name__}): {op.cpu().contiguous().view(-1).numpy()}") for i, (name, op) in enumerate(module.named_children()): print(f"{name} ({module.__class__.__name__}):") print(op) for i, (name, op) in enumerate(module.named_modules()): print(f"{name} ({module.__class__.__name__}):") print(op) if hasattr(module, "op"): print(f"{name} ({module.__class__.__name__}.{module.op.__class__.__name__}): {module.op}") # Register custom symbolic function for operator register_custom_op_symbolic("mynamespace::myop", my_custom_symbolic_linear)
model = torch.load("mymodel.pt") print_model_ops(model) ```
如果这个函数输出了
aten::unfold
,表示模型中使用了这个运算符。- 修改模型,以使用 ONNX 支持的运算符代替
aten::unfold
。如果我们需要aten::unfold
,我们可以手动实现这个函数。以下是一个简单的实现:
```python def my_unfold(input, kernel_size, dilation=1, padding=0, stride=1): # Calcuate output shape B, C, D, H, W = input.size() DH, DW = tuple(kernel_size) PD, PH, PW = tuple(padding) SH, SW = tuple(stride) OH = (D + 2 * PD - DH - (DH-1) * (dilation - 1)) // SH + 1 OW = (W + 2 * PW - DW - (DW-1) * (dilation - 1)) // SW + 1
# Pad input input_pad = torch.nn.functional.pad(input, (PW, PW, PH, PH, PD, PD)) # Reshape input to allow 2D convolution input_2d = input_pad.view(B*C*D, 1, H+2*PH, W+2*PW) # Create convolution kernel for each output pixel idx = torch.arange(OH*OW).reshape(OH, OW) idh = idx // OW idw = idx % OW kh = (idh*SH).reshape(-1, 1, 1, 1) + torch.arange(DH).reshape(1, 1, -1, 1).to(idh.device) kw = (idw*SW).reshape(-1, 1, 1, 1) + torch.arange(DW).reshape(1, 1, 1, -1).to(idw.device) kernel = input_2d[:, :, kh, kw] # Do 2D convolution output_2d = torch.nn.functional.conv2d(input_2d, kernel, groups=B*C*D) # Reshape output to original size output = output_2d.view(B, C, OH, OW) return output
```
我们需要将所有的
aten::unfold
替换为my_unfold
。以下是一个简单的脚本,可以自动找到模型中的所有aten::unfold
运算符,并将它们替换为my_unfold
:```python import torch from torch.onnx import SymbolicShapeFinder
class ReplaceUnfold(torch.nn.Module): def init(self): super(ReplaceUnfold, self).init()
def forward(self, x): return my_unfold(x)
def replace_unfold(model): # Register custom symbolic function for my_unfold def my_unfold_symbolic(g, input, kernel_size, dilation, padding, stride): return g.op("mynamespace::my_unfold", input, kernel_size, dilation, padding, stride)
register_custom_op_symbolic("mynamespace::my_unfold", my_unfold_symbolic) # Find all aten::unfold ops in the model symbolic_shape_finder = SymbolicShapeFinder() symbolic_shape_finder.check_masks = True symbolic_shape_finder.find_shapes(model, torch.randn((1, 3, 224, 224))) for node in symbolic_shape_finder.nodes: if node.kind() == "aten::unfold": unfold_node = node unfold_name = unfold_node.outputs()[0].debugName() unfold_inputs = unfold_node.inputs() # Replace aten::unfold with my_unfold my_unfold_node = graph.create("mynamespace::my_unfold", [*unfold_inputs[0:1], *unfold_inputs[1].node().inputs()[1:], *unfold_inputs[2:]], 1) my_unfold_name = my_unfold_node.outputs()[0].debugName() unfold_node.outputs()[0].replaceAllUsesWith(my_unfold_node.outputs()[0]) graph.eraseUnusedNodeOutputNames() my_unfold_node.outputs()[0].copyMetadata(unfold_node.outputs()[0]) my_unfold_node.moveAfter(unfold_node) graph.eraseOutput(unfold_node.outputs()[0]) graph.insertOutput(my_unfold_node.outputs()[0], 0) graph.node(unfold_node).destroy()
model = torch.load("mymodel.pt") replace_unfold(model) torch.onnx.export(model, (torch.randn((1, 3, 224, 224)),), "mymodel.onnx") ```
- 导出修改后的模型,并验证是否成功。我们可以使用以下代码来加载并检查导出的 ONNX 模型:
```python import onnx
model = onnx.load("mymodel.onnx") onnx.checker.check_model(model) ```
- 检查模型中是否有使用
如果你已经解决了该问题, 非常希望你能够分享一下解决方案, 写成博客, 将相关链接放在评论区, 以帮助更多的人 ^-^解决 无用评论 打赏 举报- 以下回答来自chatgpt:
悬赏问题
- ¥15 爬虫爬取网站的一些信息
- ¥15 关于vue2中methods使用call修改this指向的问题
- ¥15 idea自动补全键位冲突
- ¥15 请教一下写代码,代码好难
- ¥15 iis10中如何阻止别人网站重定向到我的网站
- ¥15 滑块验证码移动速度不一致问题
- ¥15 Utunbu中vscode下cern root工作台中写的程序root的头文件无法包含
- ¥15 麒麟V10桌面版SP1如何配置bonding
- ¥15 Marscode IDE 如何预览新建的 HTML 文件
- ¥15 K8S部署二进制集群过程中calico一直报错