qq_42887487 2022-09-15 10:28 采纳率: 66.7%
浏览 60
已结题

matlab与pytorch拟合效果对比

为什么对于一个很简单的非线性拟合问题,就一个隐藏层+激活函数,pytorch与matlab对比效果差很多?

import torch
import numpy as np
import matplotlib.pyplot as plt

import scipy.io as scio
from torch.utils.data import random_split
from torch.optim import lr_scheduler

from torch.optim import lr_scheduler

from sklearn.metrics import r2_score

import os
os.environ['KMP_DUPLICATE_LIB_OK']='True'


def mse_cal(data_loader, net):
    """mse计算函数
    
    :param data_loader:加载好的数据
    :param net: 模型
    :return:根据输入的数据,输出其MSE计算结果
    """
    data = data_loader.dataset                # 还原Dataset类
    X = data[:][0]                            # 还原数据的特征
    y = data[:][1]                            # 还原数据的标签
    yhat = net(X)
    return F.mse_loss(yhat, y)






class GenData(Dataset):
    def __init__(self, features, labels):           
        self.features = features                    
        self.labels = labels                       
        self.lens = len(features)                  

    def __getitem__(self, index):
        return self.features[index,:],self.labels[index]    

    def __len__(self):
        return self.lens
    

    
    
    
def split_loader(features, labels, batch_size=10, rate=0.7):
    data = GenData(features, labels) 
    num_train = int(data.lens * rate)
    num_test = data.lens - num_train
    data_train, data_test = random_split(data, [num_train, num_test])
    train_loader = DataLoader(data_train, batch_size=batch_size, shuffle=True)
    test_loader = DataLoader(data_test, batch_size=batch_size, shuffle=False)
    return(train_loader, test_loader)


class Model(nn.Module):
    def __init__(self,in_features=2,hidden1=6,out_features=1,BN_Model=None,act_fun=torch.relu):
        super(Model,self).__init__()
        self.linear1=nn.Linear(in_features,hidden1)
        self.normalize1=nn.BatchNorm1d(hidden1)
        self.linear2=nn.Linear(hidden1,out_features)
        self.BN_Model=BN_Model
        self.act_fun=act_fun
    def forward(self,x):
        if self.BN_Model=='pre':
            x=self.act_fun(self.normalize1(self.linear1(x)))
            output=self.linear2(x)
        if self.BN_Model==None:
            x=self.act_fun(self.linear1(x))
            output=self.linear2(x)
        
        return output


torch.manual_seed(420)
input_=torch.rand(size=(2000,2))
output_=(input_[:,0]*input_[:,1]).reshape(-1,1)


torch.manual_seed(420)
train_loader,test_loader=split_loader(input_,output_,batch_size=50,rate=0.75)


torch.manual_seed(24)
net1=Model(BN_Model='pre',hidden1=8,act_fun=torch.sigmoid)


def fit_rec_sc(net,
              criterion,
              optimizer,
              train_data,
              test_data,
              scheduler,
              epochs=3,
              cla=False,
              eva=mse_cal):
    train_l=[]
    test_l=[]
    
    for epoch in range(epochs):
        net.train()
        for X,y in train_data:
            if cla==True:
                y=y.flatten().long()
            yhat=net.forward(X)
            loss=criterion(yhat,y)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        scheduler.step()
        net.eval()
        train_l.append(eva(train_data,net).detach())
        test_l.append(eva(test_data,net).detach())
    return train_l,test_l


lr_lambda=lambda epoch:0.95**epoch

optimizer=torch.optim.Adam(net1.parameters(),lr=0.001)

scheduler=lr_scheduler.LambdaLR(optimizer,lr_lambda)


train_l,test_l=fit_rec_sc(net=net1,
                         criterion=nn.MSELoss(),
                         optimizer=optimizer,
                         train_data=train_loader,
                         test_data=test_loader,
                         scheduler=scheduler,
                         epochs=400,
                         cla=False,
                         eva=mse_cal)


plt.plot(train_l,label='train_mse')
plt.plot(test_l,label='test_mse')
plt.xlabel('epochs')
plt.ylabel('MSE')
plt.legend(loc=1)
plt.show()


mse_cal(test_loader,net1)


yhat=net1.forward(test_loader.dataset[:][0])
y=test_loader.dataset[:][1]

plt.plot(yhat.detach(),'bo-',label='pred')
plt.plot(y,'ro-',label='real')


r2_score(y.detach().numpy(),yhat.detach().numpy())

matlab代码



%% 初始化
clear
close all
clc

%% 读取数据
input=rand(2,2000);
output=input(1,:).*input(2,:);

%% 训练集、测试集
input_train = input(:,1:1500);
output_train =output(1:1500);
input_test =input(:,1501:end);
output_test =output(1501:end);

%% 数据归一化
[inputn,inputps]=mapminmax(input_train,0,1);
[outputn,outputps]=mapminmax(output_train);
inputn_test=mapminmax('apply',input_test,inputps);

%% 构建BP神经网络
net=newff(inputn,outputn,8);

% 网络参数
net.trainParam.epochs=1000;         % 训练次数
net.trainParam.lr=0.01;                   % 学习速率
net.trainParam.goal=0.0000000001;        % 训练目标最小误差
% net.dividefcn='';
%% BP神经网络训练
net=train(net,inputn,outputn);

%% BP神经网络测试
an=sim(net,inputn_test); %用训练好的模型进行仿真 
test_simu=mapminmax('reverse',an,outputps); % 预测结果反归一化

error=test_simu-output_test;      %预测值和真实值的误差


%y1为预测值 y为实际值
R2=1 - (sum((output_test- test_simu).^2) / sum((output_test - mean(output_test)).^2));


%%真实值与预测值误差比较

disp(['r2误差为: ',num2str(R2)])

figure(1)
plot(output_test,'bo-')
hold on
plot(test_simu,'r*-')
hold on
plot(error,'square','MarkerFaceColor','b')
legend('期望值','预测值','误差')
xlabel('数据组数'),ylabel('值'),title('测试集预测值和期望值的误差对比'),set(gca,'fontsize',12)
%计算误差
[~,len]=size(output_test);
MAE1=sum(abs(error./output_test))/len;
MSE1=error*error'/len;
RMSE1=MSE1^(1/2);
disp(['-----------------------误差计算--------------------------'])
disp(['平均绝对误差MAE为:',num2str(MAE1)])
disp(['均方误差MSE为:       ',num2str(MSE1)])
disp(['均方根误差RMSE为:  ',num2str(RMSE1)])


img

img

分别为pytorch/matlab效果图

  • 写回答

1条回答 默认 最新

  • submarineas 2022-09-15 11:53
    关注

    首先题主的epoch,matlab比pytorch多了600轮,一个1k,一个400,另外就是lr,对不上。然后可能还有其它值,没有细看了。那一个步长小,学习慢,轮数少,和一个相反的,没有可比性,可能一个欠拟合,一个过拟合

    本回答被题主选为最佳回答 , 对您是否有帮助呢?
    评论

报告相同问题?

问题事件

  • 系统已结题 9月25日
  • 已采纳回答 9月17日
  • 创建了问题 9月15日

悬赏问题

  • ¥170 如图所示配置eNSP
  • ¥20 docker里部署springboot项目,访问不到扬声器
  • ¥15 netty整合springboot之后自动重连失效
  • ¥15 悬赏!微信开发者工具报错,求帮改
  • ¥20 wireshark抓不到vlan
  • ¥20 关于#stm32#的问题:需要指导自动酸碱滴定仪的原理图程序代码及仿真
  • ¥20 设计一款异域新娘的视频相亲软件需要哪些技术支持
  • ¥15 stata安慰剂检验作图但是真实值不出现在图上
  • ¥15 c程序不知道为什么得不到结果
  • ¥15 键盘指令混乱情况下的启动盘系统重装