I get an error when I feed the model inputs of different sizes. What could be causing it, and how can I fix it?
RuntimeError: shape '[1, 96, 13, 7, 13, 7]' is invalid for input of size 884736
More details below.
With a 224*224 input, the output is as expected:
input = torch.rand(1, 3, 224, 224)
print(m1.max_vit_base_224().forward(input)[0].shape)
print(m1.max_vit_base_224().forward(input)[1].shape)
print(m1.max_vit_base_224().forward(input)[2].shape)
print(m1.max_vit_base_224().forward(input)[3].shape)
# Output:
#torch.Size([1, 96, 56, 56])
#torch.Size([1, 192, 28, 28])
#torch.Size([1, 384, 14, 14])
#torch.Size([1, 768, 7, 7])
But when I switch to a different input size, such as 384*384, I get the following error:
Traceback (most recent call last):
File "G:/swinmask/Swin-Transformer-Object-Detection/somettttrrrryy.py", line 10, in <module>
print(m1.max_vit_base_224().forward(input)[0].shape)
File "G:\swinmask\Swin-Transformer-Object-Detection\mmdet\models\backbones\maxvit.py", line 688, in forward
output = stage(output)
File "D:\Program\anaconda3\envs\swin_det\lib\site-packages\torch\nn\modules\module.py", line 889, in _call_impl
result = self.forward(*input, **kwargs)
File "G:\swinmask\Swin-Transformer-Object-Detection\mmdet\models\backbones\maxvit.py", line 560, in forward
output = self.blocks(input)
File "D:\Program\anaconda3\envs\swin_det\lib\site-packages\torch\nn\modules\module.py", line 889, in _call_impl
result = self.forward(*input, **kwargs)
File "D:\Program\anaconda3\envs\swin_det\lib\site-packages\torch\nn\modules\container.py", line 119, in forward
input = module(input)
File "D:\Program\anaconda3\envs\swin_det\lib\site-packages\torch\nn\modules\module.py", line 889, in _call_impl
result = self.forward(*input, **kwargs)
File "G:\swinmask\Swin-Transformer-Object-Detection\mmdet\models\backbones\maxvit.py", line 492, in forward
output = self.grid_transformer(self.block_transformer(self.mb_conv(input)))
File "D:\Program\anaconda3\envs\swin_det\lib\site-packages\torch\nn\modules\module.py", line 889, in _call_impl
result = self.forward(*input, **kwargs)
File "G:\swinmask\Swin-Transformer-Object-Detection\mmdet\models\backbones\maxvit.py", line 399, in forward
input_partitioned = self.partition_function(input, self.grid_window_size)
File "G:\swinmask\Swin-Transformer-Object-Detection\mmdet\models\backbones\maxvit.py", line 133, in window_partition
windows = input.view(B, C, H // window_size[0], window_size[0], W // window_size[1], window_size[1])
RuntimeError: shape '[1, 96, 13, 7, 13, 7]' is invalid for input of size 884736
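The numbers in the error can be checked by hand, which may help narrow down the cause. The sketch below is plain Python and does not touch the model. It assumes total strides of 4/8/16/32 (read off the 224 -> 56/28/14/7 shapes above) and a 7x7 window (read off the failing shape `[1, 96, 13, 7, 13, 7]`); the helper names are made up for illustration and are not part of maxvit.py.

```python
# Assumptions read off the post, not taken from maxvit.py itself:
# total strides 4/8/16/32 and a 7x7 partition window.
STRIDES = (4, 8, 16, 32)
WINDOW = 7


def stage_sizes(image_size, strides=STRIDES):
    """Feature-map side length at each stage for a square input."""
    return [image_size // s for s in strides]


def first_bad_stage(image_size, window=WINDOW, strides=STRIDES):
    """Index of the first stage whose side is not a multiple of window, else None."""
    for i, side in enumerate(stage_sizes(image_size, strides)):
        if side % window != 0:
            return i
    return None


def round_up(side, window=WINDOW):
    """Smallest multiple of window that is >= side (the side length after padding)."""
    return -(-side // window) * window


side = stage_sizes(384)[0]                              # 384 // 4 = 96
actual_numel = 1 * 96 * side * side                     # elements the tensor really has
wanted_numel = 1 * 96 * (side // WINDOW * WINDOW) ** 2  # elements the view asks for
print(side, actual_numel, wanted_numel)
print(first_bad_stage(224), first_bad_stage(384), first_bad_stage(448))
print(round_up(side))
```

Under these assumptions the first-stage feature map for a 384*384 input is 96*96, and 96 is not a multiple of 7, so `H // window_size[0]` truncates to 13 and the requested view covers only 91*91 positions per channel instead of 96*96. Two directions that might work: use an input whose side is a multiple of 224 (7 * 32), such as 448, or pad the feature map up to a multiple of the window size before `window_partition` (96 -> 98 here) and crop it back afterwards. Either one would need to be checked against the actual maxvit.py code.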