本人希望将pytorch LSTM模型转换为ONNX,但报错,代码如下:
import torch
import torch.nn as nn
from torch.nn import Module, LSTM, Linear
rnn = nn.LSTM(input_size=10,hidden_size=20,num_layers=2)
inputs = torch.randn(5,3,10)
h0 = torch.randn(2,3,20)
c0 = torch.randn(2,3,20)
num_directions=1
output,(h_n,c_n) = rnn(inputs,(h0,c0))
linear = Linear(20,2)
input = output
y = linear(input)
import torch.onnx as onnx
onnx.export(rnn,(inputs,(h0,c0)),'xxx.onnx')