跑代码的时候,这段代码遇见了报错:
def train_emb(self, images, captions, lengths, image_lengths=None, warmup_alpha=None):
"""One training step given images and captions."""
self.Eiters += 1
self.logger.update('Eit', self.Eiters)
self.logger.update('lr', self.optimizer.param_groups[0]['lr'])
captions_all = captions.reshape(captions.size(0) * captions.size(1), captions.size(2))
caption_lens = lengths.reshape(-1)
# compute the embeddings #(256, 1024)
img_emb, cap_emb = self.forward_emb(images, captions_all, caption_lens, image_lengths=image_lengths)
# measure accuracy and record loss
self.optimizer.zero_grad()
loss = self.forward_loss(img_emb, cap_emb)
if warmup_alpha is not None:
loss = loss * warmup_alpha
# compute gradient and update
loss.backward(retain_graph=True)
# Adversarial Training
img_real = img_emb.detach() # img_real(256,1024)
cap_real = cap_emb.detach() # cap_real(256,1024)
# Generate fake embeddings
img_fake = self.img_gen(cap_emb).detach() # cap_emb(256,1024), img_fake(256,2048)
cap_fake = self.txt_gen(img_emb).detach() #img_emb(256,1024), cap_fake(256,1024)
# Train discriminators
img_real.requires_grad = True
img_fake.requires_grad = True
cap_real.requires_grad = True
cap_fake.requires_grad = True
disc_img_real = self.img_disc(img_real) # img_real(256,1024) disc_img_real()
disc_img_fake = self.img_disc(img_fake) #img_fake(256,1024) disc_img_fake()
disc_cap_real = self.txt_disc(cap_real) #cap_real(256,1024) dis_cap_real()
disc_cap_fake = self.txt_disc(cap_fake) #cap_fake(256,1024) dis_cap_fake()
disc_loss_img = self.gan_criterion(disc_img_real, True) + self.gan_criterion(disc_img_fake, False)
disc_loss_cap = self.gan_criterion(disc_cap_real, True) + self.gan_criterion(disc_cap_fake, False)
total_disc_loss = disc_loss_img + disc_loss_cap
total_disc_loss.backward(retain_graph=True)
clip_grad_norm_(self.params, self.grad_clip)
self.optimizer.step()
# Train generators
#self.gen_optim.zero_grad()
img_fake.requires_grad = False
cap_fake.requires_grad = False
self.optimizer.zero_grad() # Clear gradients for generator training
gen_img = self.img_gen(cap_emb) #gen_img(256,1024)
gen_cap = self.txt_gen(img_emb)#gen_cap(256,1024)
disc_img_fake_for_gen = self.img_disc(gen_img) #disc_img_fake_for_gen(256,1)
disc_cap_fake_for_gen = self.txt_disc(gen_cap) #disc_cap_fake_for_gen(256,1)
gen_loss_img = self.gan_criterion(disc_img_fake_for_gen, True)
gen_loss_cap = self.gan_criterion(disc_cap_fake_for_gen, True)
total_gen_loss = gen_loss_img + gen_loss_cap
total_gen_loss.backward()
clip_grad_norm_(self.params, self.grad_clip)
self.optimizer.step()
debug调试发现报错在
total_gen_loss.backward()
报错内容是:
Traceback (most recent call last):
File "train.py", line 274, in <module>
main()
File "train.py", line 99, in main
train(opt, train_loader, model, epoch, val_loader)
File "train.py", line 155, in train
model.train_emb(images, captions, lengths, image_lengths=img_lengths)
File "/home/s1/ESA-main4/ESA_BERT/lib/vse.py", line 290, in train_emb
total_gen_loss.backward()
File "/home/s1/anaconda3/envs/s1_new/lib/python3.8/site-packages/torch/_tensor.py", line 525, in backward
torch.autograd.backward(
File "/home/s1/anaconda3/envs/s1_new/lib/python3.8/site-packages/torch/autograd/__init__.py", line 267, in backward
_engine_run_backward(
File "/home/s1/anaconda3/envs/s1_new/lib/python3.8/site-packages/torch/autograd/graph.py", line 744, in _engine_run_backward
return Variable._execution_engine.run_backward( # Calls into the C++ engine to run the backward pass
RuntimeError: Output 0 of TBackward0 is a view and its base or another view of its base has been modified inplace. This view is the output of a function that returns multiple views. Such functions do not allow the output views to be modified inplace. You should replace the inplace operation by an out-of-place one.
请问怎么解决这个报错啊?