jijkck 2023-11-13 00:30 采纳率: 66.7%
浏览 5

线性SVM算法在blobs数据集的三个二分类器的分类直线的可视化

这个是一本书上的线性SVM算法在blobs数据集的三个二分类器的分类直线的可视化代码,可以看到在画直线时y=-(x*w[0]+b)/w[1],这是为什么,线性SVM的分类置信方程不是y=w[0]*x[0]+w[1]*x[1]+b吗?求解答

import mglearn
import numpy as np
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
from sklearn.svm import LinearSVC
from sklearn.datasets import make_blobs

if __name__ == '__main__':
    X, y = make_blobs(random_state=42)
    X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)

    linear_svm = LinearSVC().fit(X_train, y_train)

    '''画数据点'''
    mglearn.discrete_scatter(X[:, 0], X[:, 1], y)

    '''画三个二分类器的直线'''
    line = np.linspace(-15, 15)
    for coef, intercept, color in zip(linear_svm.coef_, linear_svm.intercept_, ['b', 'r', 'g']):
        '''为什么y=-(x*w[0]+b)/w[1],分类置信方程不是y=w[0]*x[0]+w[1]*x[1]+b吗??'''
        plt.plot(line, -(line*coef[0]+intercept)/coef[1], c=color)
    
    '''绘图设置'''
    plt.ylim(-10, 15)
    plt.xlim(-10, 8)
    plt.xlabel("Feature 0")
    plt.ylabel("Feature 1")
    plt.legend(['Class 0', 'Class 1', 'Class 2', 'Lineclass 0', 'Lineclass 1', 'Lineclass 2'], loc=(1.01, 0.3))
    plt.show()

img

  • 写回答

1条回答 默认 最新

  • CSDN-Ada助手 CSDN-AI 官方账号 2023-11-13 02:13
    关注

    【以下回答由 GPT 生成】

    问题的根本在于,我们的SVM分类器是一个二维平面上的线性分类器,即将输入的特征空间映射到一个二维平面上。因此,在这个简化的情况下,我们只需要考虑两个特征,即x[0]和x[1]。而且,我们的目标是找到一个划分两个类别的直线。

    假设我们有一个向量w,表示我们在特征空间中找到的法向量。这个向量固定地与直线平行,而法向量的方向是决定直线在平面上方向的关键因素。我们还有一个常数b,用于偏置,决定了直线与原点的间距。

    当我们要在二维平面上画一条直线时,我们需要知道直线上的两个点,以便确定直线。为了找到这两个点,我们需要知道直线与x轴和y轴的交点。

    考虑到我们要将直线表示为y=mx+c的形式,其中m是斜率,c是截距,我们可以通过将x设置为0,求解y来找到直线与y轴的交点。同样地,我们可以通过将y设置为0,求解x来找到直线与x轴的交点。

    给定直线方程y=-(x*w[0]+b)/w[1],我们可以设置x=0并解y,这样我们就得到了直线与y轴的交点。类似地,我们可以设置y=0并解x,这样我们就得到了直线与x轴的交点。

    在代码实现中,我们可以通过以下方式画出分类直线:

    import numpy as np
    import matplotlib.pyplot as plt
    
    def plot_line(w, b):
        # 生成x轴上的点
        x = np.linspace(-1, 1, 100)
    
        # 计算对应的y轴坐标
        y = -(x*w[0] + b)/w[1]
    
        # 绘制直线
        plt.plot(x, y, '-r', label='Decision Boundary')
        plt.xlabel('x1')
        plt.ylabel('x2')
        plt.legend(loc='best')
    
    # 示例数据
    w = np.array([0.5, 0.5]) # SVM的权重向量
    b = 0.2 # SVM的偏差
    
    # 绘制分类直线
    plot_line(w, b)
    plt.show()
    

    这段代码将在-1到1之间的x轴上生成一系列点,并计算对应的y轴坐标。然后,通过绘制这些点的连线,我们就可以得到分类直线。同时也需要指定x轴和y轴的标签,并显示图例。

    希望这个解答能够帮助你理解为什么直线的公式是y=-(xw[0]+b)/w[1],而不是y=w[0]x[0]+w[1]*x[1]+b。如果还有其他问题,请随时提问。


    如果你已经解决了该问题, 非常希望你能够分享一下解决方案, 写成博客, 将相关链接放在评论区, 以帮助更多的人 ^-^
    评论

报告相同问题?

问题事件

  • 创建了问题 11月13日

悬赏问题

  • ¥15 悬赏Python-playwright部署在centos7上
  • ¥15 psoc creator软件有没有人能远程安装啊
  • ¥15 快速扫描算法求解Eikonal方程咨询
  • ¥20 我的是道格手机,重置后屏幕右上角出现红色字的未写入tee key 和未写入google key请问怎么去掉啊
  • ¥30 关于R语言运行分区生存模型中的问题!
  • ¥15 校内二手商品转让网站
  • ¥20 高德地图聚合图层MarkerCluster聚合多个点,但是ClusterData只有其中部分数据,原因应该是有经纬度重合的地方点,现在我想让ClusterData显示所有点的信息,如何实现?
  • ¥100 求Web版SPC控制图程序包调式
  • ¥20 指导如何跑通以下两个Github代码
  • ¥15 大家知道这个后备文件怎么删吗,为啥这些文件我只看到一份,没有后备呀