weixin_51793354 2022-10-24 14:12 采纳率: 54.5%
浏览 91
已结题

pytorch 的dataloader使用一个epoch后需要重新“装载”么?

pytorch 的dataloader使用一个epoch后需要重新“装载”么?
我这样的代码有问题么

    data_train = Weibo_Dataset("./data/train/", train=True)
    train_loader = DataLoader(data_train, batch_size=batch_size, shuffle=True)
    for epoch in range(epochs):
        train_loop(train_loader, model, loss_fn, optimizer)
    print("Done!")

需要像下面这样么?

    data_train = Weibo_Dataset("./data/train/", train=True)
    for epoch in range(epochs):
        train_loader = DataLoader(data_train, batch_size=batch_size, shuffle=True)
        train_loop(train_loader, model, loss_fn, optimizer)
    print("Done!")

下面是train_loop函数的定义

    def train_loop(dataloader, model, loss_fn, optimizer):
        model.train()
        size = len(dataloader.dataset)
        for batch, (X, y) in enumerate(dataloader):
            X = get_feature_vector(X, time_step)
            X = X.to(device)
            y = y.to(device)
            # Compute prediction and loss
            pred = model(X)
            loss = loss_fn(pred, y)
            # Backpropagation
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            if batch % 1 == 0:
                loss, current = loss.item(), batch * len(X)
                print(f"loss: {loss:>7f}  [{current:>5d}/{size:>5d}]")
            pass

  • 写回答

2条回答 默认 最新

  • ·星辰大海 2022-10-24 15:00
    关注

    你的数据很多吗? 一般情况下是不需要的,DataLoader本身是个可迭代的东西,如果设计得当的话大部分情况是不需要多次加载的。

    本回答被题主选为最佳回答 , 对您是否有帮助呢?
    评论
查看更多回答(1条)

报告相同问题?

问题事件

  • 已结题 (查看结题原因) 10月24日
  • 已采纳回答 10月24日
  • 修改了问题 10月24日
  • 赞助了问题酬金15元 10月24日
  • 展开全部

悬赏问题

  • ¥15 华为ensp模拟器中S5700交换机在配置过程中老是反复重启
  • ¥15 uniapp uview http 如何实现统一的请求异常信息提示?
  • ¥15 有了解d3和topogram.js库的吗?有偿请教
  • ¥100 任意维数的K均值聚类
  • ¥15 stamps做sbas-insar,时序沉降图怎么画
  • ¥15 买了个传感器,根据商家发的代码和步骤使用但是代码报错了不会改,有没有人可以看看
  • ¥15 关于#Java#的问题,如何解决?
  • ¥15 加热介质是液体,换热器壳侧导热系数和总的导热系数怎么算
  • ¥100 嵌入式系统基于PIC16F882和热敏电阻的数字温度计
  • ¥20 BAPI_PR_CHANGE how to add account assignment information for service line