报错内容:
RuntimeError: mat1 and mat2 shapes cannot be multiplied (6144x64 and 12288x1)
Traceback (most recent call last):
File "train2.py", line 80, in <module>
real_output = discriminator(images)
File "Python38\site-packages\torch\nn\modules\module.py", line 1194, in _call_impl
return forward_call(*input, **kwargs)
File "train2.py", line 56, in forward
x=self.fc(x)
File "Python38\site-packages\torch\nn\modules\module.py", line 1194, in _call_impl
return forward_call(*input, **kwargs)
File "Python38\site-packages\torch\nn\modules\linear.py", line 114, in forward
return F.linear(input, self.weight, self.bias)
RuntimeError: mat1 and mat2 shapes cannot be multiplied (6144x64 and 12288x1)
部分代码:
class Generator(nn.Module):
def __init__(self, input_dim, output_dim):
super(Generator, self).__init__()
#print(output_dim)
self.fc = nn.Linear(input_dim, output_dim)
def forward(self, x):
x = x.to(self.fc.weight.dtype)
x = self.fc(x)
return x
class Discriminator(nn.Module):
def __init__(self, input_dim):
super(Discriminator, self).__init__()
self.fc = nn.Linear(input_dim,1)
def forward(self, x):
x = x.to(self.fc.weight.dtype)
x=self.fc(x)
return x
# 初始化数据集和数据加载器
transform = transforms.Compose([transforms.Resize((64, 64)), transforms.ToTensor()])
dataset = ArtDataset('imgdata2/', transform=transform)
dataloader = DataLoader(dataset, batch_size=32, shuffle=True)
# 初始化生成器和判别器
generator = Generator(10, 64*64*3)
discriminator = Discriminator(64*64*3) # 64x64的RGB图像
# 定义损失函数和优化器
criterion = nn.BCEWithLogitsLoss()
optimizer_G = torch.optim.Adam(generator.parameters(), lr=0.0002)
optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=0.0002)
num_epochs=100
for epoch in range(num_epochs): # 这里只训练100个epoch,实际可能需要更多
for i, (images, descriptions) in enumerate(dataloader):
# 训练判别器
real_output = discriminator(images)
real_loss = criterion(real_output, torch.ones_like(real_output))
real_output = real_output.type(torch.int32) # 添加这一行
fake_output = discriminator(generator(descriptions))
fake_loss = criterion(fake_output, torch.zeros_like(fake_output))
fake_output = fake_output.type(torch.int32) # 添加这一行
d_loss = real_loss + fake_loss
optimizer_D.zero_grad()
d_loss.backward()
optimizer_D.step()
# 训练生成器
z = torch.randn(images.shape[0], 10)#.to(device) # 生成随机噪声
fake_images = generator(z)
fake_output = discriminator(fake_images)
g_loss = criterion(fake_output, torch.ones_like(fake_output))
optimizer_G.zero_grad()
g_loss.backward()
optimizer_G.step()
print(f'Epoch [{epoch+1}/{num_epochs}], d_loss: {d_loss.item():.4f}, g_loss: {g_loss.item():.4f}')
实在找不到哪有问题了,谁能帮忙看一下。