OwenLiu66 2021-05-18 11:21

PyTorch RNN raises RuntimeError: Trying to backward thr

```python
chunk_size=64
net.train()
for epoch in tqdm(range(N_EPOCHS)):
    running_loss=0.0
    gc.collect()
    torch.cuda.empty_cache()
    totalSampleNum=0
    for song in range(n):
        print("training with song",song+1)
        wave, sr = librosa.load(str(song+1)+'.wav', sr=None)
        logstft = torch.log10(torch.abs(spectrogram(torch.from_numpy(wave).to(device)))+1)
        logstft = torch.t(logstft)
        totalSampleNum += len(logstft)
        print(logstft.shape)
        #normalize
        maxs = logstft.max(dim=1).values
        logstft /= maxs.reshape([len(logstft),1])
        reconstuct=logstft.t()[698:1025].t().unsqueeze(1)
        #train
        mainH1, mainH2, mainH3, mainH4, overtoneH1, overtoneH2, overtoneH3, overtoneH4, finalH = torch.zeros(4,1,220).to(device), torch.zeros(4,1,210).to(device), torch.zeros(4,1,200).to(device), torch.zeros(4,1,200).to(device), torch.zeros(4,1,768).to(device), torch.zeros(4,1,512).to(device), torch.zeros(4,1,448).to(device), torch.zeros(4,1,384).to(device), torch.zeros(3,1,500).to(device)
        for chunk in range(0,len(logstft),chunk_size):
            mainH1.detach_()
            mainH2.detach_()
            mainH3.detach_()  # posted as mainH4.detach_(): the line-24 typo found in the answer
            mainH4.detach_()
            overtoneH1.detach_()
            overtoneH2.detach_()
            overtoneH3.detach_()
            overtoneH4.detach_()
            finalH.detach_()
            optimizer.zero_grad()
            prediction, mainH1, mainH2, mainH3, mainH4, overtoneH1, overtoneH2, overtoneH3, overtoneH4, finalH = net(logstft[chunk:chunk+chunk_size].unsqueeze(1), mainH1, mainH2, mainH3, mainH4, overtoneH1, overtoneH2, overtoneH3, overtoneH4, finalH)
            loss = loss_func(prediction, reconstuct[chunk:chunk+chunk_size])/chunk_size
            loss.backward()
            running_loss+=loss.detach().cpu().item()
            optimizer.step()
        #plt.figure()
        #librosa.display.specshow(np.concatenate((logstft.t().cpu().numpy()[0:698],prediction.detach().squeeze().t().cpu().numpy()), axis = 0),sr=sr)
        #plt.show()
    running_loss/=totalSampleNum
    print(" | epoch:"+str(epoch+1+start_epoch)+" | total loss:"+str(running_loss)+" | log loss:"+str(np.log10(running_loss)))
    err.append(running_loss)
    gc.collect()
    torch.cuda.empty_cache()
    if((epoch+1+start_epoch)%1==0):
        print("saving")
        PATH = './net_reshape_ep'+str(epoch+1+start_epoch)+'.pth'
        torch.save(net.state_dict(), PATH)
        print("saved. continue training...")
```

I've already added .detach_(), but it still raises the error.
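Why adding .detach_() can still fail: every hidden state carried from one chunk into the next has to be detached, because loss.backward() frees the saved tensors of the graph it just traversed. If even one state stays attached, the next chunk's backward pass reaches back into that freed graph and PyTorch raises the "Trying to backward through the graph a second time" error. The sketch below reproduces this with a toy two-state recurrence; w, step and train are hypothetical names, not the poster's network. It uses the out-of-place h.detach(), which has the same effect here as the in-place h.detach_().

```python
import torch

torch.manual_seed(0)

# Hypothetical shared weight for a toy recurrence with two hidden states,
# standing in for mainH1..finalH in the question.
w = torch.randn(4, 4, requires_grad=True)

def step(x, h1, h2):
    h1 = torch.tanh(x @ w + h1)
    h2 = torch.tanh(h1 @ w + h2)
    return h1, h2

def train(detach_h2, n_chunks=3):
    h1, h2 = torch.zeros(1, 4), torch.zeros(1, 4)
    for chunk in range(n_chunks):
        h1 = h1.detach()
        if detach_h2:
            h2 = h2.detach()  # skipping this mirrors the missed mainH3
        h1, h2 = step(torch.randn(1, 4), h1, h2)
        loss = h2.sum()
        loss.backward()  # frees this chunk's saved tensors afterwards
        w.grad = None

train(detach_h2=True)  # every state detached: runs without error
try:
    train(detach_h2=False)  # h2 stays attached to the previous chunk's graph
except RuntimeError as e:
    print("RuntimeError:", str(e).split(".")[0])
```

The error only appears on the second chunk, which matches the symptom in the question: the first chunk trains, then the next backward pass fails.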

I also tried loss.backward(retain_graph=True), and that didn't work either.
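Why retain_graph=True is not a fix here: it keeps the saved buffers alive so the next backward pass can reach into earlier chunks' graphs, but those graphs still hold the weights as they were before optimizer.step() updated them in place. One plausible outcome, sketched below with a hypothetical single-weight recurrence and SGD (not the poster's model), is a different RuntimeError about a variable "modified by an inplace operation". Even where it did run, every backward pass would go through all earlier chunks, so time and memory would grow with the length of the song instead of staying bounded by chunk_size.

```python
import torch

torch.manual_seed(0)

# Hypothetical single-weight recurrence; h is deliberately never detached.
w = torch.randn(4, 4, requires_grad=True)
opt = torch.optim.SGD([w], lr=0.1)

h = torch.zeros(1, 4)
errors = []
for chunk in range(3):
    opt.zero_grad()
    h = torch.tanh(torch.randn(1, 4) @ w + h @ w)
    loss = h.sum()
    try:
        # Keeps every earlier chunk's graph alive...
        loss.backward(retain_graph=True)
    except RuntimeError as e:
        # ...but those graphs saved w before opt.step() changed it in place.
        errors.append(str(e))
        break
    opt.step()  # in-place update bumps w's version counter

print(len(errors), errors[0][:70] if errors else "no error")
```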


1 answer

  • OwenLiu66 2021-05-18 18:53

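    No need to answer, everyone: I found the problem. On line 24 I wrote 4 where it should have been 3 (mainH4.detach_() instead of mainH3.detach_()), so mainH3 was never detached.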
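A less error-prone way to handle many hidden states, sketched under assumptions: keep them in one list and detach them with a single comprehension, so a typo like mainH4/mainH3 cannot leave one state attached. The shapes are the ones initialised in the question; init_hidden, detach_all, toy_net and scale are hypothetical helpers. toy_net only stands in for the poster's net, assuming it accepts the states as positional arguments and returns (prediction, *new_states), as the call in the question does.

```python
import torch

# Shapes of the nine hidden states initialised in the question
# (mainH1..mainH4, overtoneH1..overtoneH4, finalH).
SHAPES = [(4, 1, 220), (4, 1, 210), (4, 1, 200), (4, 1, 200),
          (4, 1, 768), (4, 1, 512), (4, 1, 448), (4, 1, 384),
          (3, 1, 500)]

def init_hidden(device="cpu"):
    return [torch.zeros(shape, device=device) for shape in SHAPES]

def detach_all(hidden):
    # One comprehension detaches every state, so none can be skipped.
    return [h.detach() for h in hidden]

# Hypothetical stand-in for the poster's net, with one trainable scalar.
scale = torch.tensor(0.5, requires_grad=True)

def toy_net(x, *hidden):
    new_hidden = [torch.tanh(h * scale + x.mean()) for h in hidden]
    prediction = sum(h.mean() for h in new_hidden)
    return (prediction, *new_hidden)

hidden = init_hidden()
for chunk in range(3):
    hidden = detach_all(hidden)
    x = torch.randn(64, 1, 1025)  # hypothetical chunk of 64 frames
    prediction, *hidden = toy_net(x, *hidden)
    prediction.backward()  # only this chunk's graph is traversed
```

With the real model, the chunk loop would reduce to hidden = detach_all(hidden) followed by prediction, *hidden = net(chunk_input, *hidden).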
