loncco 2022-05-09 17:48
浏览 11
已结题

用神经网络拟合曲线失败了,求解答

不知道为什么自己弄出的测试数据都是可以拟合的,但是用到实际的数据就不行了。


from __future__ import print_function
import numpy as np
import torch
import torch.autograd
from torch.autograd import Variable
import matplotlib.pyplot as plt
import scipy.io as sio


p = sio.loadmat('C:\long\文献\ToStudents\DelayExtractionCode\S13PY.mat')
h = p['Data_all']
H = torch.from_numpy(np.real(h))
# H = H.float()
Freq = p['Freq_all']
f = torch.from_numpy(Freq)


# 声明模型
class Net(torch.nn.Module):
    def __init__(self, n_feature, n_hidden, n_output):
        super(Net, self).__init__()
        self.hidden = torch.nn.Linear(n_feature, n_hidden1, bias=False)
        self.relu = torch.nn.ELU()
        self.predict = torch.nn.Linear(n_hidden, n_output, bias=False)

    def forward(self, x):
        x = self.hidden(x)
![img](https://img-mid.csdnimg.cn/release/static/image/mid/ask/879036980256179.png "#left")

        x = self.relu(x)
        x = self.predict(x)
        # out = F.log_softmax(x)
        return x


net = Net(n_feature=1, n_hidden=100,  n_output=1)
# print(net)
net = net.double()

print(net)
loss_func = torch.nn.MSELoss()
optimizer = torch.optim.RMSprop(net.parameters(), lr=0.0001)

for t in range(1000):

    x, y = Variable(f), Variable(H)
    optimizer.zero_grad()                                  
    prediction = net(x)                                   
    loss = loss_func(prediction, y)    
    loss.backward()                                        
    optimizer.step()                                      

    if t % 5 == 0:
        plt.cla()
        plt.scatter(x.data.numpy(), y.data.numpy())
        plt.plot(x.data.numpy(), prediction.data.numpy(), 'r-', lw=5)

plt.ioff()
plt.show()
  • 写回答

0条回答 默认 最新

    报告相同问题?

    问题事件

    • 系统已结题 5月17日
    • 创建了问题 5月9日

    悬赏问题

    • ¥15 HFSS 中的 H 场图与 MATLAB 中绘制的 B1 场 部分对应不上
    • ¥15 如何在scanpy上做差异基因和通路富集?
    • ¥20 关于#硬件工程#的问题,请各位专家解答!
    • ¥15 关于#matlab#的问题:期望的系统闭环传递函数为G(s)=wn^2/s^2+2¢wn+wn^2阻尼系数¢=0.707,使系统具有较小的超调量
    • ¥15 FLUENT如何实现在堆积颗粒的上表面加载高斯热源
    • ¥30 截图中的mathematics程序转换成matlab
    • ¥15 动力学代码报错,维度不匹配
    • ¥15 Power query添加列问题
    • ¥50 Kubernetes&Fission&Eleasticsearch
    • ¥15 報錯:Person is not mapped,如何解決?