OwenLiu66
2021-05-18 11:21
采纳率: 0%
浏览 16

pytorch RNN出现RuntimeError: Trying to backward thr

chunk_size=64
net.train()
for epoch in tqdm(range(N_EPOCHS)):
    running_loss=0.0
    gc.collect()
    torch.cuda.empty_cache()
    totalSampleNum=0
    for song in range(n):
        print("training with song",song+1)
        wave, sr = librosa.load(str(song+1)+'.wav', sr=None)
        logstft = torch.log10(torch.abs(spectrogram(torch.from_numpy(wave).to(device)))+1)
        logstft = torch.t(logstft)
        totalSampleNum += len(logstft)
        print(logstft.shape)
        #normalize
        maxs = logstft.max(dim=1).values
        logstft /= maxs.reshape([len(logstft),1])
        reconstuct=logstft.t()[698:1025].t().unsqueeze(1)
        #train
        mainH1, mainH2, mainH3, mainH4, overtoneH1, overtoneH2, overtoneH3, overtoneH4, finalH = torch.zeros(4,1,220).to(device), torch.zeros(4,1,210).to(device), torch.zeros(4,1,200).to(device), torch.zeros(4,1,200).to(device), torch.zeros(4,1,768).to(device), torch.zeros(4,1,512).to(device), torch.zeros(4,1,448).to(device), torch.zeros(4,1,384).to(device), torch.zeros(3,1,500).to(device)
        for chunk in range(0,len(logstft),chunk_size):
            mainH1.detach_()
            mainH2.detach_()
            mainH4.detach_()
            mainH4.detach_()
            overtoneH1.detach_()
            overtoneH2.detach_()
            overtoneH3.detach_()
            overtoneH4.detach_()
            finalH.detach_()
            optimizer.zero_grad()
            prediction, mainH1, mainH2, mainH3, mainH4, overtoneH1, overtoneH2, overtoneH3, overtoneH4, finalH = net(logstft[chunk:chunk+chunk_size].unsqueeze(1), mainH1, mainH2, mainH3, mainH4, overtoneH1, overtoneH2, overtoneH3, overtoneH4, finalH)
            loss = loss_func(prediction, reconstuct[chunk:chunk+chunk_size])/chunk_size
            loss.backward()
            running_loss+=loss.detach().cpu().item()
            optimizer.step()
        #plt.figure()
        #librosa.display.specshow(np.concatenate((logstft.t().cpu().numpy()[0:698],prediction.detach().squeeze().t().cpu().numpy()), axis = 0),sr=sr)
        #plt.show()
    running_loss/=totalSampleNum
    print(" | epoch:"+str(epoch+1+start_epoch)+" | total loss:"+str(running_loss)+" | log loss:"+str(np.log10(running_loss)))
    err.append(running_loss)
    gc.collect()
    torch.cuda.empty_cache()
    if((epoch+1+start_epoch)%1==0):
        print("saving")
        PATH = './net_reshape_ep'+str(epoch+1+start_epoch)+'.pth'
        torch.save(net.state_dict(), PATH)
        print("saved. continue training...")

已经加上了.detach_(),还是会报错。

试了loss.backward(retain_graph=True),也不行

  • 写回答
  • 好问题 提建议
  • 追加酬金
  • 关注问题
  • 邀请回答

1条回答 默认 最新

  • OwenLiu66 2021-05-18 18:53

    大家不用答了,我发现是24行把3写成了4

    评论
    解决 无用
    打赏 举报

相关推荐 更多相似问题