Hold_C 2020-08-29 12:38 采纳率: 20%
浏览 100
已采纳

感知机的实现(二分类)——模拟出来的直线斜率没问题,但是截距有问题

问题:模拟出来的函数拟不上数据,恳请大佬告知哪儿出了问题!!!

损失函数:L(w1,w2,b)=-∑(w1*x1+w2*x2+b)y
w1的梯度:-∑y*x1
w2的梯度: -∑y*x2
b的梯度: -∑y
学习率为 0.01
使用梯度下降法进行优化

import pandas
import matplotlib.pyplot as plt
import numpy as np
def ganzhiji(data_x1,data_x2,data_y):
    w_1 = 0   # 参数w1
    w_2 = 0   # 参数w2   
    b = 0     # 偏置
    step = 0.01
    threshold = 0.1
    f_pre = -0.2
    re_num = 0  # 循环次数
    f_current = loss_function(w_1, w_2, b, data_x1, data_x2, data_y)
    while abs(f_current - f_pre) > threshold and re_num <= 40000:
        w_1, w_2, b = update(data_x1, data_x2, data_y, w_1, w_2, b, step)
        f_pre = f_current
        f_current = loss_function(w_1, w_2, b, data_x1, data_x2, data_y)
        re_num = re_num + 1

    # 训练完毕后计算精度
    num = 0
    for a in range(len(data_x1)):
        if (w_1 * data_x1[a] + w_2 * data_x2[a] + b) * data_y[a] < 0:
            num = num + 1
    print("w1={},w2={},b={},精度为:{}".format(w_1, w_2, b, 1 - num/len(data_x1)))
    return w_1,w_2,b

def loss_function(w_1, w_2, b, data_x1, data_x2, data_y):  # 计算参数更新后损失函数的值
   sum = 0
   for a in range(len(data_x1)):
       if (w_1 * data_x1[a] + w_2 * data_x2[a] + b) * data_y[a] < 0:
           sum = sum - (w_1 * data_x1[a] + w_2 * data_x2[a] + b) * data_y[a]
   return sum

def tidu(data_x1, data_x2, data_y, b):  # 参数的梯度
    t_w1 = 0
    t_w2 = 0
    t_b = 0
    for a in range(len(data_x1)):
        t_w1 = t_w1 - data_y[a] * data_x1[a]
        t_w2 = t_w2 - data_y[a] * data_x2[a]
        t_b =t_b - data_y[a]
    return t_w1, t_w2, t_b

def update(data_x1, data_x2, data_y, w_1, w_2, b, step):  # 将参数进行梯度下降
    t_w1 , t_w2, t_b = tidu(data_x1, data_x2, data_y, b)
    w_1 = w_1 - step * t_w1
    w_2 =w_2 - step * t_w2
    b = b - step * t_b
    return w_1, w_2, b

if __name__=='__main__':
    data = pandas.read_csv(r'C:\Users\科德的帝国\Desktop\ML_data.csv',engine='python')
    data_x1 = data['x1']
    data_x2 = data['x2']
    data_y = data['y']

    data_x1 = [a /10 for a in data_x1]         # 将数据放缩
    data_x2 = [a/10 for a in data_x2]
    data_y = [a for a in data_y]
    for a in range(len(data_y)):               # 将数据文件中的类别y为0的值改为-1,以符合符号函数的输出 
       if data_y[a] == 0:
          data_y[a] = -1

    # 配置
    plt.rcParams['font.sans-serif']=['SimHei'] # 显示中文
    plt.rcParams['axes.unicode_minus']=False #用来正常显示负号

    # 画散点图
    for a in range(len(data_x1)):
       if data_y[a] == 1:
            plt.scatter(data_x1[a], data_x2[a],color='red')
       else:
            plt.scatter(data_x1[a], data_x2[a],color='blue')

    # 画训练完毕的函数        
    w_1,w_2,b=ganzhiji(data_x1,data_x2,data_y)
    x= np.arange(0,10,0.1)
    y = [-(w_1*i+b)/w_2 for i in x]
    plt.plot(x,y)
    plt.show()

执行结果:
图片说明

  • 写回答

1条回答 默认 最新

  • dabocaiqq 2020-08-29 13:41
    关注
    本回答被题主选为最佳回答 , 对您是否有帮助呢?
    评论

报告相同问题?

悬赏问题

  • ¥15 安卓adb backup备份应用数据失败
  • ¥15 eclipse运行项目时遇到的问题
  • ¥15 关于#c##的问题:最近需要用CAT工具Trados进行一些开发
  • ¥15 南大pa1 小游戏没有界面,并且报了如下错误,尝试过换显卡驱动,但是好像不行
  • ¥15 没有证书,nginx怎么反向代理到只能接受https的公网网站
  • ¥50 成都蓉城足球俱乐部小程序抢票
  • ¥15 yolov7训练自己的数据集
  • ¥15 esp8266与51单片机连接问题(标签-单片机|关键词-串口)(相关搜索:51单片机|单片机|测试代码)
  • ¥15 电力市场出清matlab yalmip kkt 双层优化问题
  • ¥30 ros小车路径规划实现不了,如何解决?(操作系统-ubuntu)