Zhyan1212 2022-09-07 12:59 采纳率: 63.6%
浏览 45
已结题

pytorch自编码器训练

batchsize设置问题,设置 batch_size = 1,32,64 , 训练速度很慢 ,而且loss也很大,达到上千,
数据的shape为 torch.Size([5152, 1, 2000]) 设置为 batch_size = len(model.encoder(tain_data)) = 5152 时,
训练速度变快 loss也降低很大,但效果并不好 最好才到 loss = 2 左右.
全连接层出nn.Linear(1998, 256)),这里的256是可以人为随意设置么,还是需要与输入有对应关系?

class autoencoder(nn.Module):
    def __init__(self):
        super(autoencoder, self).__init__()
        self.encoder = nn.Sequential(
            nn.Conv1d(1, 1, 3, 1, padding=0),
            nn.ReLU(),
            nn.Linear(1998, 256))

        self.decoder = nn.Sequential(
            nn.Linear(256, 1998),
            nn.ConvTranspose1d(1, 1, 3, 1, padding=0),
            nn.Sigmoid())
    def forward(self, x):
        x = self.encoder(x)
        x = self.decoder(x)
        return x
num_epochs = 100
dataloader = DataLoader(trian_data, batch_size=len(model.encoder(trian_data)), shuffle=True)
model = autoencoder()
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.01,
                             weight_decay=1e-5)
for epoch in range(num_epochs):
   total_loss = 0
   for data in dataloader:
        x = data
        x = Variable(x)
        # ===================forward=====================
        output = model(x)
        loss = criterion(output, x)
        # ===================backward====================
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        total_loss += loss.data
     # ===================log========================
   print('epoch [{}/{}], loss:{:.4f}'
       .format(epoch+1, num_epochs, total_loss))

》》》epoch [95/100], loss:2.0771
           epoch [96/100], loss:2.0772
           epoch [97/100], loss:2.0769
           epoch [98/100], loss:2.0767
           epoch [99/100], loss:2.0769
           epoch [100/100], loss:2.0766
  • 写回答

3条回答 默认 最新

  • 万里鹏程转瞬至 人工智能领域优质创作者 2022-09-07 14:21
    关注

    batch_size一般为32、64就可以了,不需要用到全部的数据,这样子会导致模型收敛慢。
    256只是编码器的输出位数,没有特定约束,跟1998没有任何关系,可以是任意数,只需要保证跟解码器的输入是一样的就行了

    本回答被题主选为最佳回答 , 对您是否有帮助呢?
    评论
查看更多回答(2条)

报告相同问题?

问题事件

  • 系统已结题 9月16日
  • 已采纳回答 9月8日
  • 创建了问题 9月7日

悬赏问题

  • ¥60 求一个简单的网页(标签-安全|关键词-上传)
  • ¥35 lstm时间序列共享单车预测,loss值优化,参数优化算法
  • ¥15 基于卷积神经网络的声纹识别
  • ¥15 Python中的request,如何使用ssr节点,通过代理requests网页。本人在泰国,需要用大陆ip才能玩网页游戏,合法合规。
  • ¥100 为什么这个恒流源电路不能恒流?
  • ¥15 有偿求跨组件数据流路径图
  • ¥15 写一个方法checkPerson,入参实体类Person,出参布尔值
  • ¥15 我想咨询一下路面纹理三维点云数据处理的一些问题,上传的坐标文件里是怎么对无序点进行编号的,以及xy坐标在处理的时候是进行整体模型分片处理的吗
  • ¥15 一直显示正在等待HID—ISP
  • ¥15 Python turtle 画图