Zhyan1212 2022-09-07 12:59 采纳率: 63.6%
浏览 45
已结题

pytorch自编码器训练

batchsize设置问题,设置 batch_size = 1,32,64 , 训练速度很慢 ,而且loss也很大,达到上千,
数据的shape为 torch.Size([5152, 1, 2000]) 设置为 batch_size = len(model.encoder(tain_data)) = 5152 时,
训练速度变快 loss也降低很大,但效果并不好 最好才到 loss = 2 左右.
全连接层出nn.Linear(1998, 256)),这里的256是可以人为随意设置么,还是需要与输入有对应关系?

class autoencoder(nn.Module):
    def __init__(self):
        super(autoencoder, self).__init__()
        self.encoder = nn.Sequential(
            nn.Conv1d(1, 1, 3, 1, padding=0),
            nn.ReLU(),
            nn.Linear(1998, 256))

        self.decoder = nn.Sequential(
            nn.Linear(256, 1998),
            nn.ConvTranspose1d(1, 1, 3, 1, padding=0),
            nn.Sigmoid())
    def forward(self, x):
        x = self.encoder(x)
        x = self.decoder(x)
        return x
num_epochs = 100
dataloader = DataLoader(trian_data, batch_size=len(model.encoder(trian_data)), shuffle=True)
model = autoencoder()
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.01,
                             weight_decay=1e-5)
for epoch in range(num_epochs):
   total_loss = 0
   for data in dataloader:
        x = data
        x = Variable(x)
        # ===================forward=====================
        output = model(x)
        loss = criterion(output, x)
        # ===================backward====================
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        total_loss += loss.data
     # ===================log========================
   print('epoch [{}/{}], loss:{:.4f}'
       .format(epoch+1, num_epochs, total_loss))

》》》epoch [95/100], loss:2.0771
           epoch [96/100], loss:2.0772
           epoch [97/100], loss:2.0769
           epoch [98/100], loss:2.0767
           epoch [99/100], loss:2.0769
           epoch [100/100], loss:2.0766
  • 写回答

3条回答 默认 最新

  • 万里鹏程转瞬至 人工智能领域优质创作者 2022-09-07 14:21
    关注

    batch_size一般为32、64就可以了,不需要用到全部的数据,这样子会导致模型收敛慢。
    256只是编码器的输出位数,没有特定约束,跟1998没有任何关系,可以是任意数,只需要保证跟解码器的输入是一样的就行了

    本回答被题主选为最佳回答 , 对您是否有帮助呢?
    评论
查看更多回答(2条)

报告相同问题?

问题事件

  • 系统已结题 9月16日
  • 已采纳回答 9月8日
  • 创建了问题 9月7日

悬赏问题

  • ¥15 2024-五一综合模拟赛
  • ¥15 下图接收小电路,谁知道原理
  • ¥15 装 pytorch 的时候出了好多问题,遇到这种情况怎么处理?
  • ¥20 IOS游览器某宝手机网页版自动立即购买JavaScript脚本
  • ¥15 手机接入宽带网线,如何释放宽带全部速度
  • ¥30 关于#r语言#的问题:如何对R语言中mfgarch包中构建的garch-midas模型进行样本内长期波动率预测和样本外长期波动率预测
  • ¥15 ETLCloud 处理json多层级问题
  • ¥15 matlab中使用gurobi时报错
  • ¥15 这个主板怎么能扩出一两个sata口
  • ¥15 不是,这到底错哪儿了😭