PyTorch 双向 LSTM 模型转换到 TensorRT (6.0.1.5) 时遇到 [8] Assertion failed: axis >= 0 && axis < nbDims:
问题复现:
import torch
import torch.nn as nn
class BidirectionLSTM(nn.Module):
    """Bidirectional LSTM followed by a per-timestep linear projection.

    The LSTM is built with ``batch_first=True``, so the batch dimension of
    the input comes first. Because the LSTM is bidirectional, its output
    feature size is ``hidden_size * 2``; the linear layer maps that to
    ``output_size``.

    Args:
        input_size: Number of features of each input timestep.
        hidden_size: Hidden size of the LSTM (per direction).
        output_size: Number of features produced for each timestep.
    """

    def __init__(self, input_size, hidden_size, output_size):
        super(BidirectionLSTM, self).__init__()
        self.rnn = nn.LSTM(input_size, hidden_size, bidirectional=True, batch_first=True)
        # Forward and backward hidden states are concatenated by nn.LSTM,
        # hence the doubled input width of the projection.
        self.linear = nn.Linear(hidden_size * 2, output_size)

    def forward(self, input):
        """Run the LSTM over ``input`` and project every timestep.

        The final hidden/cell states returned by the LSTM are discarded.
        """
        recurrent, _ = self.rnn(input)
        output = self.linear(recurrent)
        return output
class Model(nn.Module):
    """Two stacked ``BidirectionLSTM`` blocks applied in sequence.

    The first block takes 512 input features and emits 256; the second
    takes those 256 features and emits 256 again.
    """

    def __init__(self):
        super(Model, self).__init__()
        # The attribute name is kept as-is: it is part of the state_dict
        # keys and of the node names in the exported graph.
        self.SequenceModeling = nn.Sequential(
            BidirectionLSTM(512, 256, 256),
            BidirectionLSTM(256, 256, 256),
        )

    def forward(self, input):
        """Pass ``input`` through both recurrent blocks and return the result."""
        output = self.SequenceModeling(input)
        return output
if __name__ == '__main__':
    # Minimal reproduction: build the model, run it once, and export it to
    # ONNX so the resulting file can be fed to the TensorRT ONNX parser.
    model = Model()
    model.eval()  # export in inference mode

    # Input laid out as (batch=1, sequence=64, features=512); the LSTMs use
    # batch_first=True and the first block expects 512 input features.
    dummy_input = torch.rand((1, 64, 512))
    # Plain forward pass before exporting; the result is not used afterwards.
    dummy_output = model(dummy_input)

    # Writes "test.onnx" to the current working directory.
    torch_out = torch.onnx.export(
        model,
        dummy_input,
        "test.onnx",
        export_params=True,
        verbose=True,
        input_names=["input"],
        output_names=["output"],
    )
模型在 Netron 中可视化如下:
不知如何解决?