问题:模拟出来的函数拟不上数据,恳请大佬告知哪儿出了问题!!!
损失函数:L(w1,w2,b)=-∑(w1*x1+w2*x2+b)y
w1的梯度:-∑y*x1
w2的梯度: -∑y*x2
b的梯度: -∑y
学习率为 0.01
使用梯度下降法进行优化
import pandas
import matplotlib.pyplot as plt
import numpy as np
def ganzhiji(data_x1,data_x2,data_y):
w_1 = 0 # 参数w1
w_2 = 0 # 参数w2
b = 0 # 偏置
step = 0.01
threshold = 0.1
f_pre = -0.2
re_num = 0 # 循环次数
f_current = loss_function(w_1, w_2, b, data_x1, data_x2, data_y)
while abs(f_current - f_pre) > threshold and re_num <= 40000:
w_1, w_2, b = update(data_x1, data_x2, data_y, w_1, w_2, b, step)
f_pre = f_current
f_current = loss_function(w_1, w_2, b, data_x1, data_x2, data_y)
re_num = re_num + 1
# 训练完毕后计算精度
num = 0
for a in range(len(data_x1)):
if (w_1 * data_x1[a] + w_2 * data_x2[a] + b) * data_y[a] < 0:
num = num + 1
print("w1={},w2={},b={},精度为:{}".format(w_1, w_2, b, 1 - num/len(data_x1)))
return w_1,w_2,b
def loss_function(w_1, w_2, b, data_x1, data_x2, data_y): # 计算参数更新后损失函数的值
sum = 0
for a in range(len(data_x1)):
if (w_1 * data_x1[a] + w_2 * data_x2[a] + b) * data_y[a] < 0:
sum = sum - (w_1 * data_x1[a] + w_2 * data_x2[a] + b) * data_y[a]
return sum
def tidu(data_x1, data_x2, data_y, b): # 参数的梯度
t_w1 = 0
t_w2 = 0
t_b = 0
for a in range(len(data_x1)):
t_w1 = t_w1 - data_y[a] * data_x1[a]
t_w2 = t_w2 - data_y[a] * data_x2[a]
t_b =t_b - data_y[a]
return t_w1, t_w2, t_b
def update(data_x1, data_x2, data_y, w_1, w_2, b, step): # 将参数进行梯度下降
t_w1 , t_w2, t_b = tidu(data_x1, data_x2, data_y, b)
w_1 = w_1 - step * t_w1
w_2 =w_2 - step * t_w2
b = b - step * t_b
return w_1, w_2, b
if __name__=='__main__':
data = pandas.read_csv(r'C:\Users\科德的帝国\Desktop\ML_data.csv',engine='python')
data_x1 = data['x1']
data_x2 = data['x2']
data_y = data['y']
data_x1 = [a /10 for a in data_x1] # 将数据放缩
data_x2 = [a/10 for a in data_x2]
data_y = [a for a in data_y]
for a in range(len(data_y)): # 将数据文件中的类别y为0的值改为-1,以符合符号函数的输出
if data_y[a] == 0:
data_y[a] = -1
# 配置
plt.rcParams['font.sans-serif']=['SimHei'] # 显示中文
plt.rcParams['axes.unicode_minus']=False #用来正常显示负号
# 画散点图
for a in range(len(data_x1)):
if data_y[a] == 1:
plt.scatter(data_x1[a], data_x2[a],color='red')
else:
plt.scatter(data_x1[a], data_x2[a],color='blue')
# 画训练完毕的函数
w_1,w_2,b=ganzhiji(data_x1,data_x2,data_y)
x= np.arange(0,10,0.1)
y = [-(w_1*i+b)/w_2 for i in x]
plt.plot(x,y)
plt.show()
执行结果: