强扭的甜不瓜 2022-04-22 08:11 采纳率: 75%
浏览 232
已结题

机器学习pycharm线性回归代码讲解

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

from linear_regression import LinearRegression

data = pd.read_csv('../data/world-happiness-report-2017.csv')

train_data = data.sample(frac = 0.8)
test_data = data.drop(train_data.index)

input_param_name = 'Economy..GDP.per.Capita.'
output_param_name = 'Happiness.Score'

x_train = train_data[[input_param_name]].values
y_train = train_data[[output_param_name]].values

x_test = test_data[input_param_name].values
y_test = test_data[output_param_name].values

plt.scatter(x_train,y_train,label='Train data')
plt.scatter(x_test,y_test,label='test data')
plt.xlabel(input_param_name)
plt.ylabel(output_param_name)
plt.title('Happy')
plt.legend()
plt.show()

num_iterations = 500
learning_rate = 0.01

linear_regression = LinearRegression(x_train,y_train)
(theta,cost_history) = linear_regression.train(learning_rate,num_iterations)

print ('开始时的损失:',cost_history[0])
print ('训练后的损失:',cost_history[-1])

plt.plot(range(num_iterations),cost_history)
plt.xlabel('Iter')
plt.ylabel('cost')
plt.title('GD')
plt.show()

predictions_num = 100

x_predictions = np.linspace(x_train.min(),x_train.max(),predictions_num).reshape(predictions_num,1)
y_predictions = linear_regression.predict(x_predictions)

plt.scatter(x_train,y_train,label='Train data')
plt.scatter(x_test,y_test,label='test data')
plt.plot(x_predictions,y_predictions,'r',label = 'Prediction')
plt.xlabel(input_param_name)
plt.ylabel(output_param_name)
plt.title('Happy')
plt.legend()
plt.show()
有没有会的给我简单讲解一下每段代码都是干什么的 谢谢啦

展开全部

  • 写回答

2条回答 默认 最新

  • 关注

    注释给你写好了,如有帮助,请点击我评论上方【采纳该答案】按钮支持一下,谢谢!

    import numpy as np
    import pandas as pd
    import matplotlib.pyplot as plt
    
    from linear_regression import LinearRegression#前几行都是导入包
    
    data = pd.read_csv('../data/world-happiness-report-2017.csv')#读取csv文件到data
    
    # 得到训练和测试数据
    train_data = data.sample(frac = 0.8)#frac 抽取行的比例 例如frac=0.8,就是抽取其中80%
    test_data = data.drop(train_data.index)#使用drop函数删除表中index
    
    input_param_name = 'Economy..GDP.per.Capita.'# 输入特征名字
    output_param_name = 'Happiness.Score'# 输出特征名字
    
    x_train = train_data[[input_param_name]].values
    # .values表示转换成ndarray格式 [input_param_name]表示列值
    y_train = train_data[[output_param_name]].values
    # .values表示转换成ndarray格式 [output_par  am_name]表示列值
    
    x_test = test_data[input_param_name].values#上面是训练集,这是测试集
    y_test = test_data[output_param_name].values
    
    # 散点图绘制
    plt.scatter(x_train,y_train,label='Train data')#
    plt.scatter(x_test,y_test,label='test data')
    plt.xlabel(input_param_name)
    plt.ylabel(output_param_name)
    plt.title('Happy')
    plt.legend()
    plt.show()
    
    num_iterations = 500# 迭代次数
    learning_rate = 0.01# 学习率
    
    linear_regression = LinearRegression(x_train,y_train)
    (theta,cost_history) = linear_regression.train(learning_rate,num_iterations)
    # 调用train模块传入学习率和和迭代次数
    
    print ('开始时的损失:',cost_history[0])
    # cost_history[0]表示开始的
    print ('训练后的损失:',cost_history[-1])
    # cost_history[-1]表示最后的那次
    plt.plot(range(num_iterations),cost_history)
    plt.xlabel('Iter')
    plt.ylabel('cost')
    plt.title('GD')
    plt.show()
    
    predictions_num = 100
    
    x_predictions = np.linspace(x_train.min(),x_train.max(),predictions_num).reshape(predictions_num,1)
    y_predictions = linear_regression.predict(x_predictions)
    
    plt.scatter(x_train,y_train,label='Train data')
    plt.scatter(x_test,y_test,label='test data')
    plt.plot(x_predictions,y_predictions,'r',label = 'Prediction')
    plt.xlabel(input_param_name)
    plt.ylabel(output_param_name)
    plt.title('Happy')
    plt.legend()
    plt.show()
    
    

    展开全部

    本回答被题主选为最佳回答 , 对您是否有帮助呢?
    评论

    如有帮助,请点击我评论上方【采纳该答案】按钮支持一下,谢谢!以后有什么问题可以互相交流。
    参考https://blog.csdn.net/weixin_53660567/article/details/123048523

    2
    回复
    强扭的甜不瓜 回复 CSDN专家-深度学习进阶 2022-04-22 15:20

    谢谢大佬!

    回复
    zerozero353 回复 CSDN专家-深度学习进阶 2024-03-15 12:33

    那数据库怎么建立?

    回复
查看更多回答(1条)
编辑
预览

报告相同问题?

问题事件

  • 系统已结题 4月29日
  • 已采纳回答 4月22日
  • 创建了问题 4月22日
手机看
程序员都在用的中文IT技术交流社区

程序员都在用的中文IT技术交流社区

专业的中文 IT 技术社区,与千万技术人共成长

专业的中文 IT 技术社区,与千万技术人共成长

关注【CSDN】视频号,行业资讯、技术分享精彩不断,直播好礼送不停!

关注【CSDN】视频号,行业资讯、技术分享精彩不断,直播好礼送不停!

客服 返回
顶部