~~557 2021-09-09 08:56 采纳率: 76.7%
浏览 35
已结题

交叉验证求L1、L2的最优参数

请帮我看一下代码对不对,有没有逻辑上的错误

import pandas as pd
import numpy as np
from sklearn import model_selection, linear_model
from sklearn.linear_model import Lasso, LassoCV
from sklearn.metrics import mean_squared_error
import matplotlib.pyplot as plt

data = pd.read_csv(r'D:\PyCharm\projects\')

#拆分为训练集和测试集
predictors = data.columns[2:]
x_train, x_test, y_train, y_test = model_selection.train_test_split(data[predictors], data.fuel,
                                                               test_size=0.25, random_state=1234)
#构造不同的lambda值
Lambdas = np.logspace(-5, -2, 200)
#设置交叉验证的参数,使用均方误差评估
lasso_cv = LassoCV(alphas=Lambdas, normalize=True, cv=10, max_iter=10000)
lasso_cv.fit(x_train, y_train)

#测试不同的α值对预测性能的影响
def test_lasso_alpha(*data):
    alphas = np.logspace(-5, -2, 200)
    MSE = []
    for i, alpha in enumerate(alphas):
        lassoRegression = linear_model.Lasso(alpha=alpha)
        lassoRegression.fit(x_train, y_train)
        lasso_pred = lasso_cv.predict(x_test)
        MSE.append(mean_squared_error(y_test, lasso_pred))
    return alphas, MSE

def show_plot(alphas, MSE):
    figure = plt.figure()
    ax = figure.add_subplot(1, 1, 1)
    ax.plot(alphas, MSE)
    ax.set_xlabel(r"$\alpha$")
    ax.set_ylabel(r"MSE")
    ax.set_xscale("log")
    ax.set_title("lasso")
    plt.show()

if __name__=='__main__':
    alphas, MSE = test_lasso_alpha(x_train, x_test, y_train, y_test)
    show_plot(alphas, MSE)



#基于最佳lambda值建模
lasso = Lasso(alpha=lasso_cv.alpha_, normalize=True, max_iter=10000)
lasso.fit(x_train, y_train)
#打印回归系数
print('最优参数:', lasso_cv.alpha_)
print(pd.Series(index=['Intercept']+x_train.columns.tolist(),
                data=[lasso.intercept_]+lasso.coef_.tolist()))

#模型评估
lasso_pred = lasso.predict(x_test)
#均方误差
MSE = mean_squared_error(y_test, lasso_pred)
print('均方误差:', MSE)

train_score = lasso.score(x_train, y_train)  # 模型对训练样本得准确性
test_score = lasso.score(x_test, y_test)  # 模型对测试集的准确性
print(train_score)
print(test_score)

  • 写回答

1条回答 默认 最新

  • 关注
    data = pd.read_csv(r'D:\PyCharm\projects\')
    

    r字符串的最后一个字符不能是 \ 建议改成

    data = pd.read_csv('D:\\PyCharm\\projects\\')
    
    本回答被题主选为最佳回答 , 对您是否有帮助呢?
    评论

报告相同问题?

问题事件

  • 系统已结题 9月23日
  • 已采纳回答 9月15日
  • 创建了问题 9月9日

悬赏问题

  • ¥15 微信会员卡等级和折扣规则
  • ¥15 微信公众平台自制会员卡可以通过收款码收款码收款进行自动积分吗
  • ¥15 随身WiFi网络灯亮但是没有网络,如何解决?
  • ¥15 gdf格式的脑电数据如何处理matlab
  • ¥20 重新写的代码替换了之后运行hbuliderx就这样了
  • ¥100 监控抖音用户作品更新可以微信公众号提醒
  • ¥15 UE5 如何可以不渲染HDRIBackdrop背景
  • ¥70 2048小游戏毕设项目
  • ¥20 mysql架构,按照姓名分表
  • ¥15 MATLAB实现区间[a,b]上的Gauss-Legendre积分