在使用 CNN 进行张量融合时报错：
RuntimeError: mat1 and mat2 shapes cannot be multiplied (10x18240 and 10x18240)
模型代码如下：
class MyCNN(nn.Module):
    """Two-layer CNN that fuses ``xp`` and ``xc`` on the channel axis and
    outputs one value per sample.

    Cause of ``RuntimeError: mat1 and mat2 shapes cannot be multiplied
    (10x18240 and 10x18240)``: ``nn.Linear(in_features, out_features)`` takes
    the *input* size first. The original ``nn.Linear(10, 18240)`` expected 10
    input features, but the flattened conv output has 18240 features
    (32 channels * 5 * 114 for a 20x456 input). ``fc1`` must therefore be
    ``nn.Linear(18240, hidden_dim)`` and ``fc2`` ``nn.Linear(hidden_dim, 1)``.

    Args:
        in_channels: channel count after concatenation (1 channel from
            ``xp`` plus the channels of ``xc``); the original hard-coded 6.
        height: input spatial height (original example: 20).
        width: input spatial width (original example: 456).
        hidden_dim: width of the hidden fully connected layer.

    The defaults reproduce the original setup, so ``MyCNN()`` still works.
    """

    def __init__(self, in_channels=6, height=20, width=456, hidden_dim=128):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels, 16, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(16, 32, kernel_size=3, padding=1)
        # kernel_size=3 with padding=1 keeps H and W unchanged; each 2x2
        # max-pool (stride 2, no padding) floors H and W by half. Deriving
        # the size here avoids hard-coding 18240 and silently breaking when
        # the input resolution changes.
        flat_features = 32 * (height // 2 // 2) * (width // 2 // 2)
        self.fc1 = nn.Linear(flat_features, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, 1)

    def forward(self, xp, xc):
        """Run the network on one batch.

        Args:
            xp: tensor that gains a channel axis via ``unsqueeze(1)``;
                presumably (batch, height, width) — TODO confirm with caller.
            xc: tensor concatenated with ``xp`` on dim 1; presumably
                (batch, in_channels - 1, height, width), e.g. (10, 5, 20, 456)
                for the original (10, 6, 20, 456) fused shape.

        Returns:
            Tensor of shape (batch,). Only the trailing size-1 axis is
            squeezed, so a batch of 1 stays 1-D instead of becoming a scalar.

        Raises:
            ValueError: if the flattened feature size does not match what
                ``fc1`` was built for (wrong height/width at construction).
        """
        x = torch.cat((xp.unsqueeze(1), xc), dim=1)
        x = F.max_pool2d(F.relu(self.conv1(x)), 2)
        x = F.max_pool2d(F.relu(self.conv2(x)), 2)
        # Keep the batch axis and flatten everything else.
        x = torch.flatten(x, 1)
        if x.size(1) != self.fc1.in_features:
            raise ValueError(
                f"flattened conv output has {x.size(1)} features but fc1 "
                f"expects {self.fc1.in_features}; pass the input height/width "
                f"to MyCNN(...) so the layer sizes match"
            )
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return x.squeeze(-1)
为什么会出现这个问题？应该如何修改？