DeerJM 2023-05-26 11:10 采纳率: 33.3%
浏览 14

吴恩达机器学习W3_Lab06_Gradient_Descent

在如题的optional lab中,有如下代码计算决策边界,但是按照w0x0+w1x1+b=0求解x0和x1的话,w_out的索引不是应该与x的下标对应吗?

fig,ax = plt.subplots(1,1,figsize=(5,4))
# plot the probability 
plt_prob(ax, w_out, b_out)

# Plot the original data
ax.set_ylabel(r'$x_1$')
ax.set_xlabel(r'$x_0$')   
ax.axis([0, 4, 0, 3.5])
plot_data(X_train,y_train,ax)

# Plot the decision boundary
x0 = -b_out/w_out[1] #为什么不是w_out[0]
x1 = -b_out/w_out[0] #为什么不是w_out[1]
ax.plot([0,x0],[x1,0], c=dlc["dlblue"], lw=1)
plt.show()
  • 写回答

1条回答 默认 最新

  • CSDN-Ada助手 CSDN-AI 官方账号 2023-05-26 20:19
    关注
    不知道你这个问题是否已经解决, 如果还没有解决的话:
    • 以下回答来自chatgpt:

      首先,需要明确在W3_Lab06_Gradient_Descent中,计算决策边界的代码是哪一段。根据题目描述,应该是在第二部分中的“Visualizing the decision boundary”部分。该部分的代码如下:

      w_out = sess.run(w)
      b_out = sess.run(b)
      
      x1_vals = [d[1] for d in data]
      plot_x = np.array([np.min(x1_vals) - 2, np.max(x1_vals) + 2])
      plot_y = 1 / w_out[1] * (-w_out[0] * plot_x - b_out)
      plot_y = [float(p) for p in plot_y]
      
      setosa_x = [d[1] for i,d in enumerate(data) if target[i]==0]
      setosa_y = [d[0] for i,d in enumerate(data) if target[i]==0]
      versicolor_x = [d[1] for i,d in enumerate(data) if target[i]==1]
      versicolor_y = [d[0] for i,d in enumerate(data) if target[i]==1]
      boundary_x = plot_x
      boundary_y = plot_y
      
      mpl.rcParams['legend.fontsize'] = 10
      fig = plt.figure()
      ax = fig.gca()
      ax.plot(setosa_x, setosa_y, 'o', label='I.setosa')
      ax.plot(versicolor_x, versicolor_y, 'x', label='I.versicolor')
      ax.plot(boundary_x, boundary_y, '-', label='Decision Boundary')
      ax.legend()
      ax.set_xlabel('Petal length')
      ax.set_ylabel('Sepal width')
      ax.set_title('Logistic Regression Example')
      plt.show()
      

      接下来,需要明确w_out和x的索引不对应是什么问题。根据代码,在计算决策边界时,我们需要求解w0x0+w1x1+b=0的x0和x1。这需要用到w_out和b_out,它们是在训练模型时得到的w和b的最终值。然而,作者在计算决策边界时,却使用了1/w_out[1]来计算x0和x1,这会导致x0和x1与w_out和b_out所代表的含义不一致,进而让代码难以理解。

      为了解决这个问题,我们需要重新推导w_out和b_out的含义,并正确地使用它们来计算决策边界的x0和x1。首先,我们回顾一下逻辑回归的决策边界的含义:当wx+b=0时,预测输出为1/2。这意味着w和b在决策边界上的取值必须满足wx+b=0。考虑到决策边界是一个直线,其中w的系数和b的系数可以代表该直线的斜率和截距。

      假设决策边界的直线方程是y=kx+b,其中k是直线的斜率,b是截距。我们知道,如果将决策边界画在x1-x2平面上,则有x1和x2代表了两个坐标轴,可以分别表示直线方程中的自变量x和因变量y。当x1和x2满足方程y=kx+b时,这些坐标点就位于决策边界上。因此,我们可以用x1和x2代替y和x,将直线方程y=kx+b变换为x2=kx1+b。进而,我们可以将逻辑回归模型中wx+b=0的公式变为w0x0+w1x1+b=0的公式,只需要将x0设为1,x1设为x2即可。这里的w0和w1则分别代表了直线方程中的截距b和斜率k。

      现在我们知道了如何得到决策边界在x1-x2平面上的方程,也明确了w_out中的w0和w1分别代表什么含义。因此,在计算决策边界的x0和x1时,我们应该写成:

      plot_y = - (b_out / w_out[1]) - (w_out[0] / w_out[1]) * plot_x
      

      至此,我们成功解决了w_out和x的索引不对应的问题。完整的Visualizing the decision boundary部分的代码为:

      w_out = sess.run(w)
      b_out = sess.run(b)
      
      x1_vals = [d[1] for d in data]
      plot_x = np.array([np.min(x1_vals) - 2, np.max(x1_vals) + 2])
      plot_y = - (b_out / w_out[1]) - (w_out[0] / w_out[1]) * plot_x
      
      setosa_x = [d[1] for i,d in enumerate(data) if target[i]==0]
      setosa_y = [d[0] for i,d in enumerate(data) if target[i]==0]
      versicolor_x = [d[1] for i,d in enumerate(data) if target[i]==1]
      versicolor_y = [d[0] for i,d in enumerate(data) if target[i]==1]
      boundary_x = plot_x
      boundary_y = plot_y
      
      mpl.rcParams['legend.fontsize'] = 10
      fig = plt.figure()
      ax = fig.gca()
      ax.plot(setosa_x, setosa_y, 'o', label='I.setosa')
      ax.plot(versicolor_x, versicolor_y, 'x', label='I.versicolor')
      ax.plot(boundary_x, boundary_y, '-', label='Decision Boundary')
      ax.legend()
      ax.set_xlabel('Petal length')
      ax.set_ylabel('Sepal width')
      ax.set_title('Logistic Regression Example')
      plt.show()
      

    如果你已经解决了该问题, 非常希望你能够分享一下解决方案, 写成博客, 将相关链接放在评论区, 以帮助更多的人 ^-^
    评论

报告相同问题?

问题事件

  • 创建了问题 5月26日

悬赏问题

  • ¥15 latex投稿显示click download
  • ¥15 请问读取环境变量文件失败是什么原因?
  • ¥15 在若依框架下实现人脸识别
  • ¥15 网络科学导论,网络控制
  • ¥100 安卓tv程序连接SQLSERVER2008问题
  • ¥15 利用Sentinel-2和Landsat8做一个水库的长时序NDVI的对比,为什么Snetinel-2计算的结果最小值特别小,而Lansat8就很平均
  • ¥15 metadata提取的PDF元数据,如何转换为一个Excel
  • ¥15 关于arduino编程toCharArray()函数的使用
  • ¥100 vc++混合CEF采用CLR方式编译报错
  • ¥15 coze 的插件输入飞书多维表格 app_token 后一直显示错误,如何解决?