eric-sjq 2023-10-04 22:43 采纳率: 66.7%
浏览 5
已结题

用pytorch训练对抗网络时矩阵乘法报错

报错内容:

RuntimeError: mat1 and mat2 shapes cannot be multiplied (6144x64 and 12288x1)
Traceback (most recent call last):
  File "train2.py", line 80, in <module>
    real_output = discriminator(images)
  File "Python38\site-packages\torch\nn\modules\module.py", line 1194, in _call_impl
    return forward_call(*input, **kwargs)
  File "train2.py", line 56, in forward
    x=self.fc(x)
  File "Python38\site-packages\torch\nn\modules\module.py", line 1194, in _call_impl
    return forward_call(*input, **kwargs)
  File "Python38\site-packages\torch\nn\modules\linear.py", line 114, in forward
    return F.linear(input, self.weight, self.bias)
RuntimeError: mat1 and mat2 shapes cannot be multiplied (6144x64 and 12288x1)

部分代码:


class Generator(nn.Module):  
    def __init__(self, input_dim, output_dim):  
        super(Generator, self).__init__()
        #print(output_dim)
        self.fc = nn.Linear(input_dim, output_dim)  
    def forward(self, x):
        x = x.to(self.fc.weight.dtype) 
        x = self.fc(x)
        return x

class Discriminator(nn.Module):  
    def __init__(self, input_dim):  
        super(Discriminator, self).__init__()  
        self.fc = nn.Linear(input_dim,1)  
  
    def forward(self, x):
        x = x.to(self.fc.weight.dtype)
        x=self.fc(x)  
        return x
  
# 初始化数据集和数据加载器  
transform = transforms.Compose([transforms.Resize((64, 64)), transforms.ToTensor()])  
dataset = ArtDataset('imgdata2/', transform=transform)  
dataloader = DataLoader(dataset, batch_size=32, shuffle=True)  

# 初始化生成器和判别器
generator = Generator(10, 64*64*3)   
discriminator = Discriminator(64*64*3) # 64x64的RGB图像

  
# 定义损失函数和优化器  
criterion = nn.BCEWithLogitsLoss()  
optimizer_G = torch.optim.Adam(generator.parameters(), lr=0.0002)  
optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=0.0002)  


num_epochs=100
for epoch in range(num_epochs): # 这里只训练100个epoch,实际可能需要更多  
    for i, (images, descriptions) in enumerate(dataloader):  
        # 训练判别器  
        real_output = discriminator(images)  
        real_loss = criterion(real_output, torch.ones_like(real_output))  
        real_output = real_output.type(torch.int32)  # 添加这一行
        fake_output = discriminator(generator(descriptions))  
        fake_loss = criterion(fake_output, torch.zeros_like(fake_output))  
        fake_output = fake_output.type(torch.int32)  # 添加这一行
        d_loss = real_loss + fake_loss  
        optimizer_D.zero_grad()  
        d_loss.backward()  
        optimizer_D.step()
  
        # 训练生成器  
        z = torch.randn(images.shape[0], 10)#.to(device) # 生成随机噪声  
        fake_images = generator(z)  
        fake_output = discriminator(fake_images)  
        g_loss = criterion(fake_output, torch.ones_like(fake_output))  
        optimizer_G.zero_grad()  
        g_loss.backward()  
        optimizer_G.step()  
  
    print(f'Epoch [{epoch+1}/{num_epochs}], d_loss: {d_loss.item():.4f}, g_loss: {g_loss.item():.4f}')  

实在找不到哪有问题了,谁能帮忙看一下。

  • 写回答

2条回答 默认 最新

  • 专家-赤兔[在线] 优质创作者: 编程框架技术领域 2023-10-05 08:13
    关注

    引自强大而免费的微信小程序:皆我百晓生。


    在代码中出现的错误是由于矩阵乘法的问题。根据错误信息,mat1的shape是(6144, 64),mat2的shape是(12288, 1)。可见,mat1的列数和mat2的行数不匹配,导致无法进行矩阵乘法运算。

    要解决这个问题,可以检查以下几个地方:

    1. 检查输入的维度是否正确。Generator的input_dim应该与Discriminator的output_dim相匹配。
    2. 检查数据加载器中的图像大小。输入图像的维度应该与模型的输入维度匹配。
    3. 检查损失函数和优化器的设置。确保损失函数和优化器的输入维度匹配。

    另外,由于生成器的输出是一个图像,可以考虑使用合适的激活函数(例如sigmoid)来将输出映射到0到1之间的范围,以便与判别器的输出进行比较。

    希望这些提示能帮助您解决问题。如果问题仍然存在,请提供更多的代码和数据集的信息以便进一步排查问题。

    本回答被题主选为最佳回答 , 对您是否有帮助呢?
    评论
查看更多回答(1条)

报告相同问题?

问题事件

  • 系统已结题 10月13日
  • 已采纳回答 10月5日
  • 创建了问题 10月4日