tjdnbj 2024-03-08 11:44 采纳率: 41.2%
浏览 4
已结题

深度学习循环神经网络提问

以下是我的代码,想问一下为什么会运行后会出现RuntimeError: expected scalar type Double but found Float的错误呢?该如何修改呢?

import time
import math
import zipfile

import numpy as np
import torch
from torch import nn,optim
import torch.nn.functional as F

import sys
sys.path.append("C:/Users/zyx20/Desktop/深度学习编程/pythonProject")
import d2lzh_pytorch as d2l
device=torch.device('cuda' if torch.cuda.is_available() else 'cpu')

with zipfile.ZipFile('C:/Users/zyx20/Desktop/深度学习编程/data20201205-master/Data20201205/jaychou_lyrics.txt.zip') as zin:
    with zin.open('jaychou_lyrics.txt') as f:
        corpus_chars=f.read().decode('utf-8')
corpus_chars=corpus_chars.replace('\n','').replace('\r','')
idx_to_char=list(set(corpus_chars))
char_to_idx=dict([(char,i) for i,char in enumerate(idx_to_char)])
vocab_size=len(char_to_idx)
corpus_indices=[char_to_idx[char] for char in corpus_chars]

def one_hot(x,n_class,dtype=torch.float32):
    x=x.long()
    res=torch.zeros(x.shape[0],n_class,dtype=dtype,device=x.device)
    res.scatter_(1,x.view(-1,1),1)
    return res

x=torch.tensor([0,2])
one_hot(x,vocab_size)

def to_oneshot(x,n_class):
    return [one_hot(x[:,i],n_class) for i in range(x.shape[1])]

num_inputs,num_hiddens,num_outputs=vocab_size,256,vocab_size

def get_params():
    def _one(shape):
        ts=torch.tensor(np.random.normal(0,0.1,size=shape),device=device,dtype=torch.float32)
        return torch.nn.Parameter(ts,requires_grad=True)
    #隐藏层参数
    w_xh=_one((num_inputs,num_hiddens))
    w_hh = _one((num_hiddens, num_hiddens))
    b_h=torch.nn.Parameter(torch.zeros(num_hiddens,device=device,requires_grad=True))
    #输出层参数
    w_hq=_one((num_hiddens,num_outputs))
    b_q=torch.nn.Parameter(torch.zeros(num_outputs,device=device,requires_grad=True))
    return nn.ParameterList([w_xh,w_hh,b_h,w_hq,b_q])
#返回初始化隐藏状态
def init_rnn_state(batch_size,bum_hiddens,device):
    return (torch.zeros(batch_size,num_hiddens).to(device),)
#定义在一个时间步里如何计算隐藏状态和输出,RNN函数
def rnn(inputs,state,params):
    w_xh, w_hh, b_h, w_hq, b_q=params
    h,=state
    outputs=[]
    for x in inputs:
        h=torch.tanh(torch.matmul(x,w_xh)+torch.matmul(h,w_hh)+b_h)
        y=torch.matmul(h,w_hq)+b_q
        outputs.append(y)
    return outputs,(h,)
#定义预测函数
def predict_rnn(prefix,num_chars,rnn,params,init_rnn_state,num_hiddens,vocab_size,device,idx_to_char,char_to_idx):
    state=init_rnn_state(1,num_hiddens,device)
    output=[char_to_idx[prefix[0]]]
    for t in range(num_chars+len(prefix)-1):
        x=to_oneshot(torch.tensor([[output[-1]]],device=device),vocab_size)
        (y,state)=rnn(x,state,params)
        if t<len(prefix)-1:
            output.append(char_to_idx[prefix[t+1]])
        else:
            output.append(int(y[0].argmax(dim=1).item()))
    return ''.join([idx_to_char[i] for i in output])
#裁剪梯度
def grad_clipping(params,theta,device):
    norm=torch.tensor([0.0],device=device)
    for param in params:
        norm+=(param.grad.data**2).sum()
    norm=norm.sqrt().item()
    if norm>theta:
        for param in params:
            param.grad.data *=(theta/norm)
#定义模型训练函数
def train_and_predict_rnn(rnn,get_params,init_rnn_state,num_hiddens,vocab_size,device,corpus_indices,
                          idx_to_char,char_to_idx,is_random_iter,num_epochs,num_steps,lr,clipping_theta,
                          batch_size,pred_period,pred_len,prefixes):
    if is_random_iter:
        data_iter_fn=d2l.data_iter_random
    else:
        data_iter_fn =d2l.data_iter_consecutive
    params=get_params()
    loss=nn.CrossEntropyLoss()

    for epoch in range(num_epochs):
        if not is_random_iter:
            state=init_rnn_state(batch_size,num_hiddens,device)
        l_sum,n,start=0.0,0,time.time()
        data_iter=data_iter_fn(corpus_indices,batch_size,num_steps,device)
        for x,y in data_iter:
            if is_random_iter:
                state=init_rnn_state(batch_size,num_hiddens,device)
            else:
                for s in state:
                    s.detach_()

            inputs=to_oneshot(x,vocab_size)
            (output,state)=rnn(inputs,state,params)
            outputs=torch.cat(outputs,dim=0)
            y=torch.transpose(y,0,1).contiguous().view(-1)
            l=loss(outputs,y.long())

            #梯度清零
            if params[0].grad is not None:
                for param in params:
                    param.grad.data_zero()
            l.backward()
            grad_clipping(params,clipping_theta,device)
            d2l.sgd(params,lr,1)
            l_sum+=l.item()*y.shape[0]
            n+=y.shape[0]

        if (epoch+1)%pred_period==0:
            print('epoch %d,perplexity %f,time %.2f sec'%(epoch+1,math.exp(l_sum/n),time.time()-start))
            for prefix in prefixes:
                print(' -',predict_rnn(prefix,pred_len,rnn,params,init_rnn_state,num_hiddens,vocab_size,device,idx_to_char,char_to_idx))

num_epochs,num_steps,batch_size,lr,clipping_theta=250,35,32,1e2,1e2-1
pred_period,pred_len,prefixes=50,50,['分开','不分开']
train_and_predict_rnn(rnn,get_params,init_rnn_state,num_hiddens,vocab_size,device,corpus_indices,
                          idx_to_char,char_to_idx,True,num_epochs,num_steps,lr,clipping_theta,
                          batch_size,pred_period,pred_len,prefixes)

  • 写回答

4条回答 默认 最新

  • 百锦再@新空间代码工作室 全栈领域优质创作者 2024-03-08 11:44
    关注

    在PyTorch中,RuntimeError: expected scalar type Double but found Float的错误通常是因为数据类型不匹配导致的。在您的代码中,可能是因为在将数据传递到模型中时数据类型不匹配导致的。

    您可以尝试在将数据传递给模型之前将数据类型转换为正确的类型。例如,在以下代码片段中:

    def one_hot(x, n_class, dtype=torch.float32):x = x.long()res = torch.zeros(x.shape[0], n_class, dtype=dtype, device=x.device)res.scatter_(1, x.view(-1, 1), 1)return res 
    x = torch.tensor([0, 2])one_hot(x, vocab_size)

    可以尝试将数据类型转换为正确的类型,例如将dtype=torch.float32更改为dtype=torch.float64,或者确保模型接收的数据类型与模型参数的数据类型匹配。

    另外,您还可以在模型定义部分中检查数据类型是否匹配,确保模型参数和输入数据的数据类型一致。

    通过这些调整,您可以避免RuntimeError: expected scalar type Double but found Float的错误。希望对您有帮助!如果您需要进一步的帮助,请告诉我。


    有问题你别着急,评论留言都可以,看到马上就回复,尽量及时补充齐
    本回答被题主选为最佳回答 , 对您是否有帮助呢?
    评论
查看更多回答(3条)

报告相同问题?

问题事件

  • 系统已结题 3月16日
  • 已采纳回答 3月8日
  • 创建了问题 3月8日

悬赏问题

  • ¥15 vscode编译ros找不到头文件,cmake.list文件出问题,如何解决?(语言-c++|操作系统-linux)
  • ¥15 通过AT指令控制esp8266发送信息
  • ¥15 有哪些AI工具提供可以通过代码上传EXCEL文件的API接口,并反馈分析结果
  • ¥15 二维装箱算法、矩形排列算法(相关搜索:二维装箱)
  • ¥20 nrf2401上电之后执行特定任务概率性一直处于最大重发状态
  • ¥15 二分图中俩集合中节点数与连边概率的关系
  • ¥20 wordpress如何限制ip访问频率
  • ¥15 自研小游戏,需要后台服务器存储用户数据关卡配置等数据
  • ¥15 请求解答odoo17外发加工某工序的实操方法
  • ¥20 IDEA ssm项目 跳转页面报错500