chunk_size = 64  # spectrogram frames per truncated-BPTT step

net.train()
for epoch in tqdm(range(N_EPOCHS)):
    running_loss = 0.0
    gc.collect()
    torch.cuda.empty_cache()
    total_sample_num = 0
    for song in range(n):
        print("training with song", song + 1)
        wave, sr = librosa.load(str(song + 1) + '.wav', sr=None)
        # Log-magnitude spectrogram; the +1 keeps log10 finite at silent bins.
        logstft = torch.log10(torch.abs(spectrogram(torch.from_numpy(wave).to(device))) + 1)
        logstft = torch.t(logstft)  # -> (time, freq)
        total_sample_num += len(logstft)
        print(logstft.shape)
        # Normalize: divide every time frame by its own maximum.
        maxs = logstft.max(dim=1).values
        logstft /= maxs.reshape([len(logstft), 1])
        # Reconstruction target: frequency bins 698..1024 -> (time, 1, 327).
        reconstruct = logstft.t()[698:1025].t().unsqueeze(1)
        # Fresh recurrent hidden states at the start of each song.
        mainH1 = torch.zeros(4, 1, 220).to(device)
        mainH2 = torch.zeros(4, 1, 210).to(device)
        mainH3 = torch.zeros(4, 1, 200).to(device)
        mainH4 = torch.zeros(4, 1, 200).to(device)
        overtoneH1 = torch.zeros(4, 1, 768).to(device)
        overtoneH2 = torch.zeros(4, 1, 512).to(device)
        overtoneH3 = torch.zeros(4, 1, 448).to(device)
        overtoneH4 = torch.zeros(4, 1, 384).to(device)
        finalH = torch.zeros(3, 1, 500).to(device)
        for chunk in range(0, len(logstft), chunk_size):
            # BUG FIX: the original called mainH4.detach_() twice and never
            # detached mainH3, so mainH3 kept the previous chunk's autograd
            # graph alive and backward() raised "trying to backward through
            # the graph a second time" (the error described at the bottom of
            # this file; retain_graph=True only masks it and leaks memory).
            # Detach every hidden state in-place to truncate BPTT here.
            for h in (mainH1, mainH2, mainH3, mainH4,
                      overtoneH1, overtoneH2, overtoneH3, overtoneH4, finalH):
                h.detach_()
            optimizer.zero_grad()
            (prediction, mainH1, mainH2, mainH3, mainH4,
             overtoneH1, overtoneH2, overtoneH3, overtoneH4, finalH) = net(
                logstft[chunk:chunk + chunk_size].unsqueeze(1),
                mainH1, mainH2, mainH3, mainH4,
                overtoneH1, overtoneH2, overtoneH3, overtoneH4, finalH)
            # NOTE(review): dividing by chunk_size over-weights the (possibly
            # shorter) final chunk of a song — confirm loss_func's reduction.
            loss = loss_func(prediction, reconstruct[chunk:chunk + chunk_size]) / chunk_size
            loss.backward()
            running_loss += loss.detach().cpu().item()
            optimizer.step()
        # plt.figure()
        # librosa.display.specshow(np.concatenate((logstft.t().cpu().numpy()[0:698],prediction.detach().squeeze().t().cpu().numpy()), axis = 0),sr=sr)
        # plt.show()
    running_loss /= total_sample_num
    print(" | epoch:" + str(epoch + 1 + start_epoch) + " | total loss:" + str(running_loss) + " | log loss:" + str(np.log10(running_loss)))
    err.append(running_loss)
    gc.collect()
    torch.cuda.empty_cache()
    # Checkpoint every epoch (the % 1 keeps the save-interval easy to change).
    if (epoch + 1 + start_epoch) % 1 == 0:
        print("saving")
        PATH = './net_reshape_ep' + str(epoch + 1 + start_epoch) + '.pth'
        torch.save(net.state_dict(), PATH)
        print("saved. continue training...")
已经加上了 `.detach_()`，还是会报错。
也试了 `loss.backward(retain_graph=True)`，同样不行。