问题遇到的现象和发生背景
以前打印网络模型,都是用summary函数来打印,但是这次改进的网络报错,不过直接用print(model)能打印出来,这是怎么回事
也用过其他博主提出的解决方法,如下
#修改前
summary[m_key]["input_shape"] = list(input[0].size())
#修改后
if isinstance(input[0], torch.Tensor):
summary[m_key]["input_shape"] = list(input[0].size())
elif isinstance(input[0], list):
summary[m_key]["input_shape"] = list(np.array(input[0]).shape)
不过这个针对input出错,而我报错是在output那里,所以一直没有找到解决问题的办法
问题相关代码
from TransUnet import *
from torchsummary import summary
model = get_transNet(2)
print(model)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device)
print(model)
summary(model, (1, 512, 512))
运行结果及报错内容
在用print(model)可以正常打印网络
但是用summary却不能正常打印模型结构
我想要达到的结果
可以使用summary打印网络结构