The___Shy_ 2022-01-27 14:14 采纳率: 57.1%
浏览 37

机器学习线性回归中遇到的问题

问题遇到的现象和发生背景

关于一元线性回归y=wx+b中 w 和 b 的初始化问题
实现3.1234
x+2.98的线性回归

问题相关代码,请勿粘贴截图
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline

np.random.seed(5)
x_data=np.linspace(0,100,500)                                                                      //生成[0,100]间有500个数据的等差数列
y_data=3.1234*x_data+2.980+np.random.randn(*x_data.shape)*0.5          //加入噪声

x_data=tf.cast(x_data,dtype=tf.float32)
y_data=tf.cast(y_data,dtype=tf.float32)

def model(x,w,b):
    return tf.multiply(x,w)+b
//损失函数
def MSEloss(x,y,w,b):
    err=y-model(x,w,b)
    sqrerr=tf.square(err)
    return tf.reduce_mean(sqrerr)
//初始化
w=tf.Variable(1.0,tf.float32)
b=tf.Variable(1.0,tf.float32)

traning_epoches=10
step=0
display_step=20
loss_list=[]
learning_rate=0.0001
optimizer=tf.optimizers.SGD(learning_rate)
//训练模型
for epoch in range(traning_epoches):
    for nx,ny in zip(x_data,y_data):
        with tf.GradientTape() as tape:
            loss=MSEloss(nx,ny,w,b)
            loss_list.append(loss)
            delta=tape.gradient(loss,[w,b])
        optimizer.apply_gradients(zip(delta,[w,b]))
        step=step+1
        if step % display_step==0:
            print("Traning Epoch:",'%02d' %(epoch+1),"Setp: %03d" %(step),"Loss=%.6f" %(loss))
    plt.plot(x_data,w.numpy()*x_data+b.numpy())
x_test=5.79
prd=model(x_test,w.numpy(),b.numpy())
print("预测值:%f" %prd)
target=3.1234*x_test+2.98
print("目标值:%f" %target)

运行结果及报错内容

当变量b的初值设为1.0的时候,预测出的结果如下
预测值:18.959854
目标值:21.064486
w= <tf.Variable 'Variable:0' shape=() dtype=float32, numpy=3.093086>
b= <tf.Variable 'Variable:0' shape=() dtype=float32, numpy=1.050886>

当变量b的初值设为3.0,接近2.98,即b=tf.Variable(3.0,tf.float32)
预测结果如下
w= <tf.Variable 'Variable:0' shape=() dtype=float32, numpy=3.07139>
b= <tf.Variable 'Variable:0' shape=() dtype=float32, numpy=3.222441>
预测值:21.005789
目标值:21.064486

为什么b的初值设为和真实值差不多,训练结果会更好,但是不是说初值设置多少是无影响的吗?
是哪里有问题?

我的解答思路和尝试过的方法
我想要达到的结果
  • 写回答

1条回答 默认 最新

  • AnFany 2022-01-28 10:12
    关注

    训练次数多一点就没差别了

    评论

报告相同问题?

问题事件

  • 创建了问题 1月27日

悬赏问题

  • ¥30 为什么会失败呢,该如何调整
  • ¥50 如何在不能联网影子模式下的电脑解决usb锁
  • ¥20 服务器redhat5.8网络问题
  • ¥15 如何利用c++ MFC绘制复杂网络多层图
  • ¥20 要做柴油机燃烧室优化 需要保持压缩比不变 请问怎么用AVL fire ESE软件里面的 compensation volume 来使用补偿体积来保持压缩比不变
  • ¥15 python螺旋图像
  • ¥15 算能的sail库的运用
  • ¥15 'Content-Type': 'application/x-www-form-urlencoded' 请教 这种post请求参数,该如何填写??重点是下面那个冒号啊
  • ¥15 找代写python里的jango设计在线书店
  • ¥15 请教如何关于Msg文件解析