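My CTCLoss won't go down. How do I fix this? Could someone please help? I'm sincerely asking for advice; I've been stuck on this for several days.
The code that computes the CTCLoss is as follows: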
data = next(train_iter)
cpu_images, cpu_texts = data
batch_size = cpu_images.size(0)
utils.loadData(image, cpu_images)
t, l = converter.encode(cpu_texts)
utils.loadData(text, t)
utils.loadData(length, l)
preds = crnn(image)
preds = preds.log_softmax(2)
preds_size = torch.IntTensor([preds.size(0)] * batch_size)
# print(preds.shape) # torch.Size([76, 64, 6464])
# print(text.shape) # torch.Size([320])
# print(preds_size.shape) # torch.Size([64])
# print(l.shape) # torch.Size([64])
# exit()
loss = criterion(preds, text, preds_size, length)
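A minimal, self-contained sketch for checking the loss setup in isolation, assuming `criterion` is `torch.nn.CTCLoss` with its default `blank=0` and that `converter.encode` maps characters to indices starting at 1 (index 0 must stay reserved for the blank). It follows the shape conventions in the comments above: `preds` is `(T, N, C)` log-probabilities, the targets are concatenated into one 1-D tensor (hence `torch.Size([320])` for a batch of 64), and there is one length per sample. The function name `ctc_sanity_check` and the small sizes are hypothetical, and free logits stand in for `crnn(image)`. If this loss drops steadily while the real training does not, the CTC input format is probably not the culprit, and the model, data, or learning rate are more likely suspects.

```python
import torch
import torch.nn as nn


def ctc_sanity_check(T=76, N=4, C=37, max_len=5, steps=200, seed=0):
    """Fit free logits to random targets with nn.CTCLoss.

    Returns (first_loss, last_loss). Shapes mirror the post:
    log_probs (T, N, C), targets concatenated 1-D, per-sample lengths.
    """
    torch.manual_seed(seed)
    # Labels are drawn from 1..C-1 because index 0 is the CTC blank
    # (assumption: the real converter also encodes characters from 1).
    target_lengths = torch.randint(1, max_len + 1, (N,))
    targets = torch.randint(1, C, (int(target_lengths.sum()),))
    # Every sample uses the full output length T, like preds_size above.
    input_lengths = torch.full((N,), T, dtype=torch.long)

    # Hypothetical stand-in for the network output crnn(image).
    logits = torch.randn(T, N, C, requires_grad=True)
    criterion = nn.CTCLoss(blank=0, reduction="mean", zero_infinity=True)
    optimizer = torch.optim.Adam([logits], lr=0.1)

    losses = []
    for _ in range(steps):
        log_probs = logits.log_softmax(2)  # nn.CTCLoss expects log-probabilities
        loss = criterion(log_probs, targets, input_lengths, target_lengths)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        losses.append(loss.item())
    return losses[0], losses[-1]


first, last = ctc_sanity_check()
print(f"first loss {first:.3f}, last loss {last:.3f}")
```

Running the same check with the real alphabet size (6464 classes in the printed shapes) and a batch produced by `converter.encode` is a quick way to confirm that no target index is 0 and that every target length fits within `T = 76`.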
The training loss is as follows: