m0_70245668 2022-09-17 16:46 采纳率: 0%
浏览 81
已结题

基于MATLAB用lstm手写代码预测数据

基于MATLAB用LSTM源代码预测数据,程序能正常运行,结果却趋近于一条直线,不知道哪里出现的问题,下面是我写的代码,请给看一下。有偿
%%
clear
clc
%% 数据导入
load('G4_4ring.mat')
in = 10;%输入样本的数据量
out = 5;%输出样本的数据量
dataX = G4_4ring(1:200,:);
dataTrain =(dataX-mean(dataX))/std(dataX);
for i = 1 : 185
data_1(:,i) = dataTrain(i:i+9,1);
end
for i = 1 : 185
data_2(:,i) = dataTrain(i+10:i+14,1);
end

datatrain_in = data_1;%训练集输入
datatrain_out = data_2;%训练集输出
%% 设置BP参数

net.in = in;%输入节点
net.hid = 5;%隐藏节点
net.out = out;%输出节点

%权重偏置初始化
wf = 2*(rand(net.hid , net.hid+net.in) - 1/2);
bf = 2*(rand(net.hid,1) - 1/2);

wi = 2*(rand(net.hid , net.hid+net.in) - 1/2);
bi = 2*(rand(net.hid , 1) - 1/2);

wc = 2*(rand(net.hid ,net.hid+net.in) - 1/2);
bc = 2*(rand(net.hid , 1) - 1/2);

wo = 2*(rand(net.hid , net.hid+net.in) - 1/2);
bo = 2*(rand(net.hid , 1) - 1/2);

wy = 2*(rand(net.out , net.hid) - 1/2);
by = 2*(rand(net.out , 1) - 1/2);
%训练参数
learningrate = 0.001;
maxiter = 2;
iteration = 0;

%% 开始训练
for i = 1:maxiter
%% 前向计算

for t = 1:10
    if t==1
       Zx1 = wf*[zeros(out,1);datatrain_in(:,1)] + bf;%调整ht的维度
       Zx2 = wi*[zeros(out,1);datatrain_in(:,1)]+bi;
       Zx3 = wc*[zeros(out,1);datatrain_in(:,1)]+bc;
       Zx4 = wo*[zeros(out,1);datatrain_in(:,1)]+bo;
       for j=1:5

           f_t(j,1) = Sigmoid(Zx1(j,:));%遗忘门

           i_t(j,1) = Sigmoid(Zx2(j,:));

           c_t_current(j,1) = tanh(Zx3(j,:));%输入门

           o_t(j,1) = Sigmoid(Zx4(j,:));

       end         

       c_t(:,1) = i_t(:,1).*c_t_current(:,1);%更新细胞状态
       h_t(:,1) = o_t(:,1).*Tanh(c_t(:,1));%输出门
       z_t = wy*h_t(:,1) + by;
       y_t(:,1) = Sigmoid(z_t);
    
   else
       Zx1 = wf*[h_t(:,t-1);datatrain_in(:,t)]+bf;
       Zx2 = wi*[h_t(:,t-1);datatrain_in(:,t)]+bi;
       Zx3 = wc*[h_t(:,t-1);datatrain_in(:,t)]+bc;
       Zx4 = wo*[h_t(:,t-1);datatrain_in(:,t)]+bo;
       
        f_t(:,t) = Sigmoid(Zx1);%遗忘门
   
        i_t(:,t) = Sigmoid(Zx2);
  
        c_t_current(:,t) = Tanh(Zx3);%输入门

        o_t(:,t) = Sigmoid(Zx4);

    

       c_t(:,t) = f_t(:,t).*c_t(:,t-1)+i_t(:,t).*c_t_current(:,t);%更新细胞状态
   
    
       h_t(:,t) = o_t(:,t).*Tanh(c_t(:,t));%输出门

       z_t(:,t) = wy*h_t(:,t) + by;
       y_t(:,t) = Sigmoid(z_t(:,t));
    end
    

end
 for   t=1:10
Dby(:,t) = y_t(:,t) - datatrain_out(:,t);
    sse_train = sumsqr(Dby);%误差平方和
 end
fprintf('第 %d 次迭代  误差:  %f\n' , i , sse_train);
iteration = iteration + 1;


%% 反向传播
for t = 10 : -1 : 2
    if t==10
       dby = y_t(:,t) - datatrain_out(:,t);
       dwy = (y_t(:,t) - datatrain_out(:,t))*h_t(:,t)';

       dh_t = wy'*dby;
       dc_t = dh_t.*o_t(:,t).*(1-(Tanh(c_t(:,t))).^2);
      
       dwo = dh_t.*Tanh(c_t(:,t)).*o_t(:,t).*(1-o_t(:,t))*[h_t(:,t-1);datatrain_in(:,t)]';
           

       dwi= dc_t.*c_t_current(:,t).*i_t(:,t).*(1-i_t(:,t))*[h_t(:,t-1);datatrain_in(:,t)]';
          

       dwf = dc_t.*c_t(:,t-1).*f_t(:,t).*(1-f_t(:,t))*[h_t(:,t-1);datatrain_in(:,t)]';
           

       dwc = dc_t.*i_t(:,t).*(1-(c_t_current(:,t)).^2)*[h_t(:,t-1);datatrain_in(:,t)]';
          


       dbo = dh_t.*Tanh(c_t(:,t)).*o_t(:,t).*(1-o_t(:,t));
       dbi = dc_t.*c_t_current(:,t).*i_t(:,t).*(1-i_t(:,t));
       dbf = dc_t.*c_t(:,t-1).*f_t(:,t).*(1-f_t(:,t));
       dbc = dc_t.*i_t(:,t).*(1-(c_t_current(:,t)).^2);
    %更新
    Wf = dwf;
    Wi = dwi;
    Wc = dwc;
    Wo = dwo;

    Bf = dbf;
    Bi = dbi;
    Bc = dbc;
    Bo = dbo;

    Wy = dwy;
    By = dby;  
      
    else
       
        dby = y_t(:,t) - datatrain_out(:,t);
        dwy =(y_t(:,t) - datatrain_out(:,t))*h_t(:,t)';

        dh_t = wy'*dby+ wo(:,1:5)'*(dh_t.*Tanh(c_t(:,t+1)).*o_t(:,t+1).*(1-o_t(:,t+1))) + wi(:,1:5)'*(dc_t.*c_t_current(:,t+1).*i_t(:,t+1).*(1-i_t(:,t+1))) + wf(:,1:5)'*(dc_t.*c_t(:,t).*f_t(:,t+1).*(1-f_t(:,t+1))) + wc(:,1:5)'*(dc_t.*i_t(:,t+1).*(1-(c_t_current(:,t+1)).^2));

        dc_t = dh_t.*o_t(:,t).*(1-(Tanh(c_t(:,t))).^2)+dc_t.*f_t(:,t+1);
 
        dwo = dh_t.*Tanh(c_t(:,t)).*o_t(:,t).*(1-o_t(:,t))*[h_t(:,t-1);datatrain_in(:,t)]';
       

        dwi = dc_t.*c_t_current(:,t).*i_t(:,t).*(1-i_t(:,t))*[h_t(:,t-1);datatrain_in(:,t)]';
      

        dwf = dc_t.*c_t(:,t-1).*f_t(:,t).*(1-f_t(:,t))*[h_t(:,t-1);datatrain_in(:,t)]';
       

        dwc = dc_t.*i_t(:,t).*(1-(c_t_current(:,t)).^2)*[h_t(:,t-1);datatrain_in(:,t)]';
      
    
        dbo = dh_t.*Tanh(c_t(:,t)).*o_t(:,t).*(1-o_t(:,t));
        dbi = dc_t.*c_t_current(:,t).*i_t(:,t).*(1-i_t(:,t));
        dbf = dc_t.*c_t(:,t-1).*f_t(:,t).*(1-f_t(:,t));
        dbc = dc_t.*i_t(:,t).*(1-(c_t_current(:,t)).^2);
      %更新
    Wf = Wf+dwf;
    Wi = Wi+dwi;
    Wc = Wc+dwc;
    Wo = Wo+dwo;

    Bf = Bf+dbf;
    Bi = Bi+dbi;
    Bc = Bc+dbc;
    Bo = Bo+dbo;

    Wy = Wy+dwy;
    By = By+dby;

  end



end

%% 更新
wf = wf+learningrateWf;
wi = wi+learningrate
Wi;
wc = wc+learningrateWc;
wo = wo+learningrate
Wo;

bf = bf+learningrate*Bf;
bi = bi+learningrate*Bi;
bc = bc+learningrate*Bc;
bo = bo+learningrate*Bo;

wy = wy+learningrate*Wy;
by = by+learningrate*By;

end

%% 测试
dataY = G4_4ring(201:250,:);
dataTest =(dataY-mean(dataY))/std(dataY);
for i = 1: 35
data_3(:,i) = dataTest(i:i+9,1);
end
for i = 1: 35
YTest(:,i) = dataTest(i+10:i+14,1);
end

%测试数据
datatest_in = data_3;%测试集输入
datatest_out = YTest;%测试集输出

%预测过程
for m = 1:35
if m==1
Zx1_T = wf*[zeros(out,1);datatest_in(:,m)]+bf;
Zx2_T = wi*[zeros(out,1);datatest_in(:,m)]+bi;
Zx3_T = wc*[zeros(out,1);datatest_in(:,m)]+bc;
Zx4_T = wo*[zeros(out,1);datatest_in(:,m)]+bo;

  f_t_T(:,m) = Sigmoid(Zx1_T);%遗忘门

  i_t_T(:,m) = Sigmoid(Zx2_T);

  c_t_current_T(:,m) = tanh(Zx3_T);%输入门

  o_t_T(:,m) = Sigmoid(Zx4_T);

  

  c_t_T(:,m) = i_t_T(:,m).*c_t_current_T(:,m);%更新细胞状态
  h_t_T(:,m) = o_t_T(:,m).*Tanh(c_t_T(:,m));%输出门

  z_t_T(:,m) = wy*h_t_T(:,m) + by;
  Y_t(:,m) =Sigmoid(z_t_T(:,m));


  e_test(:,m) = Y_t(:,m) - datatest_out(:,m);%误差

else
Zx1_T = wf*[h_t_T(:,m-1);datatrain_in(:,m)]+bf;
Zx2_T = wi*[h_t_T(:,m-1);datatrain_in(:,m)]+bi;
Zx3_T = wc*[h_t_T(:,m-1);datatrain_in(:,m)]+bc;
Zx4_T = wo*[h_t_T(:,m-1);datatrain_in(:,m)]+bo;

  f_t_T(:,m) = Sigmoid(Zx1_T);%遗忘门
    
  i_t_T(:,m) = Sigmoid(Zx2_T);
   
  c_t_current_T(:,m) = Tanh(Zx3_T);%输入门

  o_t_T(:,m) = Sigmoid(Zx4_T);
 
    
  c_t_T(:,m) = f_t_T(:,m).*c_t_T(:,m-1)+i_t_T(:,m).*c_t_current_T(:,m);%更新细胞状态
   
   
  h_t_T(:,m) = o_t_T(:,m).*Tanh(c_t_T(:,m));%输出

  z_t_T(:,m) = wy*h_t_T(:,m) + by;

  Y_t(:,m) = Sigmoid(z_t_T(:,m));

  e_test(:,m) = Y_t(:,m) - datatest_out(:,m);%误差
    
    

end
end

YPred = Y_tstd(dataY) + mean(dataY)
YTest = YTest
std(dataY) + mean(dataY);
XTest = datatest_in*std(dataY)+ mean(dataY);
%% 画图
figure('Position', [10, 10, 900, 400]);
S = XTest(1,:);
Y = [S;YPred];
D1 = XTest(1,:);
D2 = XTest(2:10,35)';
d = [D1,D2];
for i= 1 : 35
%figure('Position', [10, 10, 900, 400]);
plot(d(1:40),'k','LineWidth',2)

hold on
%plot(i+2O0:i+2O9,XTest(:,i),'y','LineWidth',2)
%plot(i+210:i+214,YTes(:,i), '.-k','LineWidth',2)

plot(i:i+5,Y(:,i), 'x-r')%

end

  • 写回答

1条回答 默认 最新

  • CSDN-Ada助手 CSDN-AI 官方账号 2022-09-17 17:17
    关注
    • 给你找了一篇非常好的博客,你可以看看是否有帮助,链接:LSTM matlab实现
    评论

报告相同问题?

问题事件

  • 已结题 (查看结题原因) 9月17日
  • 修改了问题 9月17日
  • 创建了问题 9月17日

悬赏问题

  • ¥15 HFSS 中的 H 场图与 MATLAB 中绘制的 B1 场 部分对应不上
  • ¥15 如何在scanpy上做差异基因和通路富集?
  • ¥20 关于#硬件工程#的问题,请各位专家解答!
  • ¥15 关于#matlab#的问题:期望的系统闭环传递函数为G(s)=wn^2/s^2+2¢wn+wn^2阻尼系数¢=0.707,使系统具有较小的超调量
  • ¥15 FLUENT如何实现在堆积颗粒的上表面加载高斯热源
  • ¥30 截图中的mathematics程序转换成matlab
  • ¥15 动力学代码报错,维度不匹配
  • ¥15 Power query添加列问题
  • ¥50 Kubernetes&Fission&Eleasticsearch
  • ¥15 報錯:Person is not mapped,如何解決?