~~557 2021-09-09 08:56 采纳率: 76.7%
浏览 36
已结题

交叉验证求L1、L2的最优参数

请帮我看一下代码对不对,有没有逻辑上的错误

import pandas as pd
import numpy as np
from sklearn import model_selection, linear_model
from sklearn.linear_model import Lasso, LassoCV
from sklearn.metrics import mean_squared_error
import matplotlib.pyplot as plt

data = pd.read_csv(r'D:\PyCharm\projects\')

#拆分为训练集和测试集
predictors = data.columns[2:]
x_train, x_test, y_train, y_test = model_selection.train_test_split(data[predictors], data.fuel,
                                                               test_size=0.25, random_state=1234)
#构造不同的lambda值
Lambdas = np.logspace(-5, -2, 200)
#设置交叉验证的参数,使用均方误差评估
lasso_cv = LassoCV(alphas=Lambdas, normalize=True, cv=10, max_iter=10000)
lasso_cv.fit(x_train, y_train)

#测试不同的α值对预测性能的影响
def test_lasso_alpha(*data):
    alphas = np.logspace(-5, -2, 200)
    MSE = []
    for i, alpha in enumerate(alphas):
        lassoRegression = linear_model.Lasso(alpha=alpha)
        lassoRegression.fit(x_train, y_train)
        lasso_pred = lasso_cv.predict(x_test)
        MSE.append(mean_squared_error(y_test, lasso_pred))
    return alphas, MSE

def show_plot(alphas, MSE):
    figure = plt.figure()
    ax = figure.add_subplot(1, 1, 1)
    ax.plot(alphas, MSE)
    ax.set_xlabel(r"$\alpha$")
    ax.set_ylabel(r"MSE")
    ax.set_xscale("log")
    ax.set_title("lasso")
    plt.show()

if __name__=='__main__':
    alphas, MSE = test_lasso_alpha(x_train, x_test, y_train, y_test)
    show_plot(alphas, MSE)



#基于最佳lambda值建模
lasso = Lasso(alpha=lasso_cv.alpha_, normalize=True, max_iter=10000)
lasso.fit(x_train, y_train)
#打印回归系数
print('最优参数:', lasso_cv.alpha_)
print(pd.Series(index=['Intercept']+x_train.columns.tolist(),
                data=[lasso.intercept_]+lasso.coef_.tolist()))

#模型评估
lasso_pred = lasso.predict(x_test)
#均方误差
MSE = mean_squared_error(y_test, lasso_pred)
print('均方误差:', MSE)

train_score = lasso.score(x_train, y_train)  # 模型对训练样本得准确性
test_score = lasso.score(x_test, y_test)  # 模型对测试集的准确性
print(train_score)
print(test_score)

  • 写回答

1条回答 默认 最新

  • 关注
    data = pd.read_csv(r'D:\PyCharm\projects\')
    

    r字符串的最后一个字符不能是 \ 建议改成

    data = pd.read_csv('D:\\PyCharm\\projects\\')
    
    本回答被题主选为最佳回答 , 对您是否有帮助呢?
    评论

报告相同问题?

问题事件

  • 系统已结题 9月23日
  • 已采纳回答 9月15日
  • 创建了问题 9月9日

悬赏问题

  • ¥15 这个复选框什么作用?
  • ¥15 单通道放大电路的工作原理
  • ¥30 YOLO检测微调结果p为1
  • ¥20 求快手直播间榜单匿名采集ID用户名简单能学会的
  • ¥15 DS18B20内部ADC模数转换器
  • ¥15 做个有关计算的小程序
  • ¥15 MPI读取tif文件无法正常给各进程分配路径
  • ¥15 如何用MATLAB实现以下三个公式(有相互嵌套)
  • ¥30 关于#算法#的问题:运用EViews第九版本进行一系列计量经济学的时间数列数据回归分析预测问题 求各位帮我解答一下
  • ¥15 setInterval 页面闪烁,怎么解决