Does a PyTorch DataLoader need to be re-created ("reloaded") after each epoch?
Is there anything wrong with my code written like this?
data_train = Weibo_Dataset("./data/train/", train=True)
train_loader = DataLoader(data_train, batch_size=batch_size, shuffle=True)
for epoch in range(epochs):
    train_loop(train_loader, model, loss_fn, optimizer)
print("Done!")
Or does it need to be written like the following, re-creating the DataLoader every epoch?
data_train = Weibo_Dataset("./data/train/", train=True)
for epoch in range(epochs):
    train_loader = DataLoader(data_train, batch_size=batch_size, shuffle=True)
    train_loop(train_loader, model, loss_fn, optimizer)
print("Done!")
Here is the definition of the train_loop function:
def train_loop(dataloader, model, loss_fn, optimizer):
    model.train()
    size = len(dataloader.dataset)
    for batch, (X, y) in enumerate(dataloader):
        X = get_feature_vector(X, time_step)
        X = X.to(device)
        y = y.to(device)

        # Compute prediction and loss
        pred = model(X)
        loss = loss_fn(pred, y)

        # Backpropagation
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if batch % 1 == 0:  # "% 1" is always 0, so this logs every batch
            loss, current = loss.item(), batch * len(X)
            print(f"loss: {loss:>7f} [{current:>5d}/{size:>5d}]")
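For reference, here is a stripped-down, self-contained version of the first pattern (the DataLoader is built once and then reused in every epoch). The random tensors, linear model, and the omission of get_feature_vector are placeholders I made up so the sketch runs on its own; they are not my actual data or model.

import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset

device = "cuda" if torch.cuda.is_available() else "cpu"

# Dummy stand-ins for Weibo_Dataset and the real model.
X_all = torch.randn(100, 8)
y_all = torch.randint(0, 2, (100,))
data_train = TensorDataset(X_all, y_all)

model = nn.Linear(8, 2).to(device)
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

train_loader = DataLoader(data_train, batch_size=16, shuffle=True)

def train_loop(dataloader, model, loss_fn, optimizer):
    model.train()
    size = len(dataloader.dataset)
    for batch, (X, y) in enumerate(dataloader):
        X, y = X.to(device), y.to(device)
        pred = model(X)
        loss = loss_fn(pred, y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        print(f"loss: {loss.item():>7f} [{batch * len(X):>5d}/{size:>5d}]")

for epoch in range(3):
    # The same DataLoader object is iterated once per epoch.
    train_loop(train_loader, model, loss_fn, optimizer)
print("Done!")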