想成为博客专家的渣渣 2020-02-22 14:04 采纳率: 50%
浏览 643
已采纳

关于机器学习梯度下降求 w 和 b 的问题

# 类目的求解斜率和截距
class Linear_model(object):
    def __init__(self):
        self.w = np.random.randn(1)[0]

        self.b = np.random.randn(1)[0]
        print('----------------------起始随机生成的斜率和截距',self.w,self.b)

#   model就是方程f(x) = wx + b
    def model(self,x):
        return self.w * x + self.b

#   线性问题,原理都是最小二乘法
    def loss(self,x,y):
#         方程中几个未知数???
        cost = (y - self.model(x))**2

#         求偏导数 ,把其他的都当成已知数,求一个未知数的导数
#         导数是偏导数的一种特殊形式
        g_w = 2*(y - self.model(x))*(-x)
        g_b = 2*(y - self.model(x))*(-1)
        return g_w,g_b

#     梯度下降
    def gradient_descend(self,g_w,g_b,step = 0.01):
#         更新新的斜率和截距
        self.w = self.w - g_w*step
        self.b = self.b - g_b*step
        print('----------------------',self.w,self.b)

    def fit(self,X,y):
        w_last = self.w + 1
        b_last = self.b + 1
        precision = 0.00001
        max_count = 3000
        count = 0
        while True:
            if (np.abs(self.w - w_last) < precision) and (np.abs(self.b - b_last) < precision):
                break

            if count > max_count:
                break

#             更新斜率和截距
            g_w = 0
            g_b = 0
            size = X.shape[0]
            for xi,yi in zip(X,y):
                g_w += self.loss(xi,yi)[0]/size
                g_b += self.loss(xi,yi)[1]/size

            self.gradient_descend(g_w,g_b)
            count += 1

    def coef_(self):
            return self.w

    def intercept_(self):
            return self.b

请问
def fit(self,X,y):
w_last = self.w + 1
b_last = self.b + 1
这里为什么

w_last = self.w + 1
b_last = self.b + 1

加一是什么意思

  • 写回答

2条回答 默认 最新

  • 九洲歌同 2020-02-22 14:54
    关注

    fit函数里面的while循环里的第一个if语句是想判断当前求得的self.w的精度,如果self.w和b的改变小于precision即认为精度达到要求,退出循环!

    而第一次while循环时明显没有可以拿来比较的,所以自己定义一个与self.w差距为1的w_last 来保证if判断

    实际上你可以改成2,3,4,0.5都可以,随意。

    本回答被题主选为最佳回答 , 对您是否有帮助呢?
    评论
查看更多回答(1条)

报告相同问题?

悬赏问题

  • ¥15 用js遍历数据并对非空元素添加css样式
  • ¥15 使用autodl云训练,希望有直接运行的代码(关键词-数据集)
  • ¥50 python写segy数据出错
  • ¥20 关于线性结构的问题:希望能从头到尾完整地帮我改一下,困扰我很久了
  • ¥30 3D多模态医疗数据集-视觉问答
  • ¥20 设计一个二极管稳压值检测电路
  • ¥15 内网办公电脑进行向日葵
  • ¥15 如何输入双曲线的参数a然后画出双曲线?我输入处理函数加上后就没有用了,不知道怎么回事去掉后双曲线可以画出来
  • ¥15 soildworks装配体的尺寸问题
  • ¥100 有偿寻云闪付SDK转URL技术