weixin_62663500 2025-05-29 11:57 采纳率: 0%
浏览 9

logistic回归X与logitP线性关系的图像观察

在logitic回归中,想通过图像查看连续性变量x与logitP是否满足线性关系,通过python实现,请Deepseek写了代码,请帮忙看一下这个代码有问题吗?谢谢!

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import statsmodels.api as sm
from sklearn.linear_model import LogisticRegression
from statsmodels.nonparametric.smoothers_lowess import lowess

def plot_logit_linearity(data, x_col, y_col, frac=0.3):
 """
    绘制连续性自变量与logitP的关系图,检查线性假设
    
    参数:
    - data: 包含自变量和因变量的DataFrame
    - x_col: 要检查的连续性自变量列名
    - y_col: 二分类因变量列名(0/1)
    - frac: LOWESS平滑的窗口大小(0-1之间)
    """
    # 准备数据
    X = data[[x_col]].values
    y = data[y_col].values
    
    # 拟合逻辑回归模型
    logit_model = LogisticRegression()
    logit_model.fit(X, y)
    
    # 计算预测概率和logit转换
    pred_prob = logit_model.predict_proba(X)[:, 1]
    logit_p = np.log(pred_prob / (1 - pred_prob))
    
    # 创建图形
    plt.figure(figsize=(10, 6))
    
    # 绘制原始数据点(抖动处理,便于观察密度)
    plt.scatter(X, logit_p, alpha=0.5, color='blue', label='Data points', s=10)
    
    # 计算并绘制LOWESS平滑曲线
    lowess_curve = lowess(logit_p.flatten(), X.flatten(), frac=frac)
    plt.plot(lowess_curve[:, 0], lowess_curve[:, 1], 
             color='red', linewidth=2, label='LOWESS Smoothing')
    
    # 绘制线性拟合线
    coef = logit_model.coef_[0][0]
    intercept = logit_model.intercept_[0]
    x_vals = np.array([X.min(), X.max()])
    y_vals = intercept + coef * x_vals
    plt.plot(x_vals, y_vals, 'g--', linewidth=2, label='Linear Fit')
    
    # 添加图形元素
    plt.xlabel(f'Independent Variable: {x_col}', fontsize=12)
    plt.ylabel('Logit(P)', fontsize=12)
    plt.title(f'Linearity Check: {x_col} vs. Logit(P)', fontsize=14)
    plt.legend(fontsize=10)
    plt.grid(True, alpha=0.3)
    
    plt.show()

# 示例使用
if __name__ == "__main__":
    # 生成示例数据
    np.random.seed(42)
    n_samples = 500
    data = pd.DataFrame({
        'age': np.random.normal(45, 15, n_samples),
        'disease': np.random.binomial(1, 0.5, n_samples)
    })
    # 确保年龄为正数
    data['age'] = data['age'].clip(lower=18, upper=90)
    
    # 调用函数绘制图形
    plot_logit_linearity(data, 'age', 'disease', frac=0.3)
  • 写回答

5条回答 默认 最新

  • 檀越@新空间 2025-05-29 11:58
    关注

    上午好☀️☀️☀️️
    本答案参考通义千问

    你的代码总体上是正确的,能够完成绘制连续性变量 (X) 和 (\text{logit}(P)) 的线性关系图的任务。不过,有一些地方可以优化或改进,以确保代码更健壮、可读性更高,并且符合最佳实践。以下是对代码的分析以及改进建议:


    代码分析

    1. 逻辑正确性

      • 使用 LogisticRegression 模型计算预测概率和 (\text{logit}(P)),这是合理的。
      • 使用 lowess 平滑曲线来检查线性关系,这是一种常见的方法。
    2. 潜在问题

      • 抖动处理plt.scatter 中没有对 (X) 值进行抖动处理,可能导致重叠的数据点遮挡彼此。
      • 数据范围问题x_vals 的生成基于 X.min()X.max(),但 X 是一维数组,需要先从 DataFrame 中提取出来。
      • 默认参数frac 参数的默认值为 0.3,可能需要根据具体数据调整。
      • 异常处理:如果输入数据包含缺失值或非数值类型,代码可能会报错。
    3. 改进建议

      • 添加数据预处理步骤,确保输入数据的质量。
      • 改进抖动处理,使散点图更清晰。
      • 提供更多的注释,增强代码的可读性。

    修改后的代码

    import numpy as np
    import pandas as pd
    import matplotlib.pyplot as plt
    import statsmodels.api as sm
    from sklearn.linear_model import LogisticRegression
    from statsmodels.nonparametric.smoothers_lowess import lowess
    
    def plot_logit_linearity(data, x_col, y_col, frac=0.3):
        """
        绘制连续性自变量与logitP的关系图,检查线性假设
        
        参数:
        - data: 包含自变量和因变量的DataFrame
        - x_col: 要检查的连续性自变量列名
        - y_col: 二分类因变量列名(0/1)
        - frac: LOWESS平滑的窗口大小(0-1之间)
        """
        # 数据预处理
        if not isinstance(data, pd.DataFrame):
            raise ValueError("data 必须是 Pandas DataFrame")
        if x_col not in data.columns or y_col not in data.columns:
            raise ValueError(f"{x_col} 或 {y_col} 不在数据集中")
        if not np.all(np.isin(data[y_col].unique(), [0, 1])):
            raise ValueError(f"{y_col} 列必须是二分类变量 (0/1)")
        
        # 提取数据
        X = data[[x_col]].values.ravel()  # 确保是一维数组
        y = data[y_col].values
        
        # 拟合逻辑回归模型
        logit_model = LogisticRegression()
        logit_model.fit(X.reshape(-1, 1), y)
        
        # 计算预测概率和logit转换
        pred_prob = logit_model.predict_proba(X.reshape(-1, 1))[:, 1]
        logit_p = np.log(pred_prob / (1 - pred_prob))
        
        # 创建图形
        plt.figure(figsize=(10, 6))
        
        # 抖动处理:避免数据点重叠
        jittered_x = X + np.random.normal(loc=0, scale=0.02, size=X.shape)  # 小幅度随机抖动
        
        # 绘制原始数据点
        plt.scatter(jittered_x, logit_p, alpha=0.7, color='blue', label='Data points', s=10)
        
        # 计算并绘制LOWESS平滑曲线
        lowess_curve = lowess(logit_p, X, frac=frac)
        plt.plot(lowess_curve[:, 0], lowess_curve[:, 1], 
                 color='red', linewidth=2, label='LOWESS Smoothing')
        
        # 绘制线性拟合线
        coef = logit_model.coef_[0][0]
        intercept = logit_model.intercept_[0]
        x_vals = np.array([X.min(), X.max()])
        y_vals = intercept + coef * x_vals
        plt.plot(x_vals, y_vals, 'g--', linewidth=2, label='Linear Fit')
        
        # 添加图形元素
        plt.xlabel(f'Independent Variable: {x_col}', fontsize=12)
        plt.ylabel('Logit(P)', fontsize=12)
        plt.title(f'Linearity Check: {x_col} vs. Logit(P)', fontsize=14)
        plt.legend(fontsize=10)
        plt.grid(True, alpha=0.3)
        
        plt.show()
    
    # 示例使用
    if __name__ == "__main__":
        # 生成示例数据
        np.random.seed(42)
        n_samples = 500
        data = pd.DataFrame({
            'age': np.random.normal(45, 15, n_samples),
            'disease': np.random.binomial(1, 0.5, n_samples)
        })
        # 确保年龄为正数
        data['age'] = data['age'].clip(lower=18, upper=90)
        
        # 调用函数绘制图形
        plot_logit_linearity(data, 'age', 'disease', frac=0.3)
    

    改进点总结

    1. 数据预处理

      • 添加了输入数据的校验,确保 x_coly_col 存在于数据集中,并且 y_col 是二分类变量。
      • 使用 .ravel()X 转换为一维数组,避免潜在的维度问题。
    2. 抖动处理

      • 使用 np.random.normal 对 (X) 进行小幅随机抖动,避免数据点重叠。
    3. 代码可读性

      • 增加了注释,解释每一步的目的。
      • 使用更具描述性的变量名,例如 jittered_x
    4. 错误处理

      • 如果输入数据不符合要求,抛出明确的错误信息。

    运行结果

    运行修改后的代码后,会生成一个包含以下内容的图表:

    • 蓝色散点:原始数据点(经过抖动处理)。
    • 红色曲线:LOWESS 平滑曲线。
    • 绿色虚线:逻辑回归模型的线性拟合线。

    通过观察红色曲线与绿色虚线的吻合程度,可以判断 (X) 和 (\text{logit}(P)) 是否满足线性关系。

    评论

报告相同问题?

问题事件

  • 创建了问题 5月29日