Hold_C 2020-08-29 12:38 Acceptance rate: 20%
Views: 100
Answer accepted

Perceptron implementation (binary classification): the fitted line's slope looks right, but its intercept is off

Problem: the fitted function does not fit the data. Could someone please tell me where the mistake is?

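Loss function: L(w1, w2, b) = -∑(w1*x1 + w2*x2 + b)y
Gradient with respect to w1: -∑y*x1
Gradient with respect to w2: -∑y*x2
Gradient with respect to b: -∑y
Learning rate: 0.01
Optimization: gradient descent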
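The sums in the formulas above are written without a range. In the usual textbook statement of the perceptron loss, the sum runs only over the set M of currently misclassified points, L(w, b) = -∑_{i∈M} y_i(w·x_i + b), so the gradients -∑_{i∈M} y_i·x_i and -∑_{i∈M} y_i change whenever M changes. Below is a minimal NumPy sketch of one full-batch gradient step under that formulation. The function names and the six toy points are hypothetical, and points with y(w·x + b) ≤ 0 are counted as misclassified (an assumption of this sketch, chosen so that the all-zero starting point is not stuck with a zero gradient).

```python
import numpy as np

def misclassified(w, b, X, y):
    # Boolean mask of points with y * (w·x + b) <= 0.
    # With <= (rather than <), the all-zero start marks every point.
    return y * (X @ w + b) <= 0

def perceptron_loss(w, b, X, y):
    # L(w, b) = -sum over misclassified i of y_i * (w·x_i + b); always >= 0
    m = misclassified(w, b, X, y)
    return np.sum(-y[m] * (X[m] @ w + b))

def perceptron_grad(w, b, X, y):
    # Gradients restricted to the misclassified set M:
    #   dL/dw = -sum_{i in M} y_i * x_i,   dL/db = -sum_{i in M} y_i
    m = misclassified(w, b, X, y)
    grad_w = -(X[m].T @ y[m])
    grad_b = -np.sum(y[m])
    return grad_w, grad_b

def batch_step(w, b, X, y, step=0.01):
    # One full-batch gradient-descent update, mirroring the post's update()
    grad_w, grad_b = perceptron_grad(w, b, X, y)
    return w - step * grad_w, b - step * grad_b

# Hypothetical toy data: columns are x1, x2; labels are +1 / -1
X = np.array([[3, 3], [1, 1], [4, 3], [1, 2], [3, 4], [2, 1]], dtype=float)
y = np.array([1, -1, 1, -1, 1, -1])

w, b = np.zeros(2), 0.0
w, b = batch_step(w, b, X, y)
# After this first step w is about (0.06, 0.06) and b is 0, so the three
# -1 points are still misclassified and the next gradient is different.
```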

```python
import pandas
import matplotlib.pyplot as plt
import numpy as np

def ganzhiji(data_x1, data_x2, data_y):  # "ganzhiji" = perceptron: train, then report accuracy
    w_1 = 0   # parameter w1
    w_2 = 0   # parameter w2
    b = 0     # bias
    step = 0.01       # learning rate
    threshold = 0.1   # stop once the loss changes by less than this
    f_pre = -0.2      # initial "previous" loss value
    re_num = 0        # iteration count
    f_current = loss_function(w_1, w_2, b, data_x1, data_x2, data_y)
    while abs(f_current - f_pre) > threshold and re_num <= 40000:
        w_1, w_2, b = update(data_x1, data_x2, data_y, w_1, w_2, b, step)
        f_pre = f_current
        f_current = loss_function(w_1, w_2, b, data_x1, data_x2, data_y)
        re_num = re_num + 1

    # After training, compute the accuracy
    num = 0
    for a in range(len(data_x1)):
        if (w_1 * data_x1[a] + w_2 * data_x2[a] + b) * data_y[a] < 0:
            num = num + 1
    print("w1={}, w2={}, b={}, accuracy: {}".format(w_1, w_2, b, 1 - num / len(data_x1)))
    return w_1, w_2, b

def loss_function(w_1, w_2, b, data_x1, data_x2, data_y):  # loss value after a parameter update
    total = 0
    for a in range(len(data_x1)):
        # only misclassified points contribute to the loss
        if (w_1 * data_x1[a] + w_2 * data_x2[a] + b) * data_y[a] < 0:
            total = total - (w_1 * data_x1[a] + w_2 * data_x2[a] + b) * data_y[a]
    return total

def tidu(data_x1, data_x2, data_y, b):  # "tidu" = gradient: gradients of the parameters
    t_w1 = 0
    t_w2 = 0
    t_b = 0
    for a in range(len(data_x1)):  # sums over all samples
        t_w1 = t_w1 - data_y[a] * data_x1[a]
        t_w2 = t_w2 - data_y[a] * data_x2[a]
        t_b = t_b - data_y[a]
    return t_w1, t_w2, t_b

def update(data_x1, data_x2, data_y, w_1, w_2, b, step):  # one gradient-descent step on the parameters
    t_w1, t_w2, t_b = tidu(data_x1, data_x2, data_y, b)
    w_1 = w_1 - step * t_w1
    w_2 = w_2 - step * t_w2
    b = b - step * t_b
    return w_1, w_2, b

if __name__ == '__main__':
    data = pandas.read_csv(r'C:\Users\科德的帝国\Desktop\ML_data.csv', engine='python')
    data_x1 = data['x1']
    data_x2 = data['x2']
    data_y = data['y']

    data_x1 = [a / 10 for a in data_x1]   # rescale the data
    data_x2 = [a / 10 for a in data_x2]
    data_y = [a for a in data_y]
    for a in range(len(data_y)):          # relabel class y = 0 as -1 to match the output of the sign function
        if data_y[a] == 0:
            data_y[a] = -1

    # Plot settings
    plt.rcParams['font.sans-serif'] = ['SimHei']  # render Chinese characters
    plt.rcParams['axes.unicode_minus'] = False    # display minus signs correctly

    # Scatter plot of the two classes
    for a in range(len(data_x1)):
        if data_y[a] == 1:
            plt.scatter(data_x1[a], data_x2[a], color='red')
        else:
            plt.scatter(data_x1[a], data_x2[a], color='blue')

    # Plot the trained decision boundary
    w_1, w_2, b = ganzhiji(data_x1, data_x2, data_y)
    x = np.arange(0, 10, 0.1)
    y = [-(w_1 * i + b) / w_2 for i in x]
    plt.plot(x, y)
    plt.show()
```
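In the posted `tidu`, the loop runs over every sample, so the returned gradient never depends on `w_1`, `w_2` or `b`. Every call to `update` therefore adds the same vector, and after k iterations w_1 = 0.01·k·∑y·x1, w_2 = 0.01·k·∑y·x2 and b = 0.01·k·∑y: the boundary w_1·x1 + w_2·x2 + b = 0 is the same line for every k. Its direction follows ∑y·x, which can look plausible, while b is proportional to ∑y, i.e. to the difference between the numbers of +1 and -1 labels, regardless of where the two classes lie. For comparison, here is a sketch of the per-sample (stochastic) perceptron update, a standard variant swapped in here rather than the post's full-batch step: each misclassified point immediately moves (w, b), and training stops after a pass with no mistakes instead of using the loss-difference threshold. The function names and the toy data are hypothetical, and the sketch assumes linearly separable data (otherwise it just stops at `max_epochs`). Starting from w = 0, b = 0, the learning rate only rescales (w, b) and does not move the line (up to floating-point rounding), so `step=1.0` is used to keep the arithmetic exact.

```python
import numpy as np

def train_perceptron(X, y, step=1.0, max_epochs=1000):
    # Per-sample perceptron (primal form); labels y must be +1 / -1.
    w = np.zeros(X.shape[1])
    b = 0.0
    for _ in range(max_epochs):
        mistakes = 0
        for xi, yi in zip(X, y):
            # Misclassified (or exactly on the line): move the boundary toward xi
            if yi * (xi @ w + b) <= 0:
                w += step * yi * xi
                b += step * yi
                mistakes += 1
        if mistakes == 0:  # a full pass with no mistakes: done
            break
    return w, b

def accuracy(w, b, X, y):
    # Fraction of points strictly on the correct side of w·x + b = 0
    return float(np.mean(y * (X @ w + b) > 0))

def boundary_x2(w, b, x1):
    # Solve w1*x1 + w2*x2 + b = 0 for x2 (same formula as the post's plot; needs w2 != 0)
    return -(w[0] * x1 + b) / w[1]

# Hypothetical toy data whose separating line does not pass through the origin
X = np.array([[3, 3], [1, 1], [4, 3], [1, 2], [3, 4], [2, 1]], dtype=float)
y = np.array([1, -1, 1, -1, 1, -1])

w, b = train_perceptron(X, y)
# With this point order the loop stops after four passes at w = (1, 1), b = -5,
# i.e. the line x1 + x2 = 5, which separates all six points.
xs = np.arange(0, 5, 0.5)
ys = boundary_x2(w, b, xs)  # points on the learned line, e.g. for plt.plot(xs, ys)
```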

Result:
[Figure: plot produced by the code above (image not preserved)]


1 answer

  • dabocaiqq 2020-08-29 13:41
    This answer was selected by the asker as the best answer.
