旅者心 2020-02-01 21:24 采纳率: 0%
浏览 7108

pytorch expected Tensor as element 0 in argument 0, but got tuple

做风格迁移学习时,在 target_feature=model(style_img).clone()发生错误,expected Tensor as element 0 in argument 0, but got tuple
然而style_img的格式是torch.cuda.FloatTensor

import torch
import torch.nn as nn
from torch.autograd import Variable
import torch.optim as optim

from PIL import Image
import matplotlib.pyplot as plt
import torchvision.transforms as transforms
import torchvision.models as models
import copy

use_cuda=torch.cuda.is_available()
dtype=torch.cuda.FloatTensor if use_cuda else torch.FloatTensor

style='images/timg.jpg'
content='images/123.jpg'

style_weight=1000
content_weight=1

imsize=128

loader=transforms.Compose([
transforms.Resize(imsize),
transforms.ToTensor()
])

def image_loader(image_name):
image=Image.open(image_name)
image=Variable(loader(image))
image=image.unsqueeze(0)
return image
style_img=image_loader(style).type(dtype)
content_img=image_loader(content).type(dtype)

print(style_img)
assert style_img.size()==content_img.size()

def imshow(tensor,title=None):
image=tensor.clone().cpu()
image=image.view(3,imsize,imsize)
image=unloader(image)
plt.imshow(image)
if title is not None:
plt.title(title)
plt.pause(0.001)
unloader=transforms.ToPILImage()

plt.ion()

#

plt.figure()

imshow(style_img.data,title='Style Image')

#

plt.figure()

imshow(content_img.data,title='Content Image')

加载网络

cnn=models.vgg19(pretrained=True).features
if use_cuda:
cnn=cnn.cuda()

class ContentLoss(nn.Module):
def init(self,target,weight):
super(ContentLoss,self).__init__()
self.target=target.detach()*weight
self.weight=weight
self.criterion=nn.MSELoss()
def forward(self, *input):
self.loss=self.criterion(input*self.weight,self.target)
self.output=input
return self.output
def backward(self,retain_graph=True):
self.loss.backward(retain_graph=retain_graph)
return self.loss

class StyleLoss(nn.Module):
def init(self,target,weight):
super(StyleLoss,self).__init__()
self.target=target.detach()*weight
self.weight=weight
#self.gram=GramMatrix()
self.criterion=nn.MSELoss()
def forward(self,input):
self.output=input.clone()
input=input.cuda()
self_G=Gram(input)
self_G.mul(self.weight)
self.loss=self.criterion(self_G,self.target)
return self.output
def backward(self,retain_graph=True):
self.loss.backward(retain_graph=retain_graph)
return self.loss

def Gram(input):
a,b,c,d=input.size()
features=input.view(a*b,c*d)
G=torch.mm(features,features.t())
return G.div(a*b*c*d)

content_layers=['conv_4']
style_layers=['conv_1','conv_2','conv_3','conv_4','conv_5']

content_losses=[]
style_losses=[]

model=nn.Sequential()

if use_cuda:
model=model.cuda()

i=1
for layer in list(cnn):
if isinstance(layer,nn.Conv2d):
name='conv_'+str(i)
model.add_module(name,layer)

     if name in content_layers:
         target=model(content_img).clone()
         content_loss=ContentLoss(target,content_weight)
         content_loss=content_loss.cuda() if use_cuda else content_loss
         model.add_module('content_loss'+str(i),content_loss)
         content_losses.append(content_loss)

     if name in style_layers:
         target_feature=model(style_img).clone()
         target_feature=target_feature.cuda() if use_cuda else target_feature
         target_feature_gram=Gram(target_feature)
         style_loss=StyleLoss(target_feature_gram,style_weight)
         style_loss=style_loss.cuda() if use_cuda else style_loss
         model.add_module("style_loss"+str(i),style_loss)
         style_losses.append(style_loss)

 if isinstance(layer,nn.ReLU):
     name='relu'+str(i)
     model.add_module(name,layer)
     i+=1
 if isinstance(layer,nn.MaxPool2d):
     name="pool"+str(i)
     model.add_module(name,layer)

input_img=Variable(torch.randn(content_img.data.size()))
if use_cuda:
input_img=input_img.cuda()
content_img=content_img.cuda()
style_img=style_img.cuda()
plt.figure()
imshow(input_img.data,title='Input Image')

  • 写回答

1条回答 默认 最新

  • 关注
    评论

报告相同问题?

悬赏问题

  • ¥15 请教一下各位,为什么我这个没有实现模拟点击
  • ¥15 执行 virtuoso 命令后,界面没有,cadence 启动不起来
  • ¥50 comfyui下连接animatediff节点生成视频质量非常差的原因
  • ¥20 有关区间dp的问题求解
  • ¥15 多电路系统共用电源的串扰问题
  • ¥15 slam rangenet++配置
  • ¥15 有没有研究水声通信方面的帮我改俩matlab代码
  • ¥15 ubuntu子系统密码忘记
  • ¥15 保护模式-系统加载-段寄存器
  • ¥15 电脑桌面设定一个区域禁止鼠标操作