return bool(1) 2023-02-07 10:47 采纳率: 85.7%
浏览 28
已结题

HalvingGridSearchCV迭代次数问题

HalvingGridSearchCV设置factor=1.5,min_resources=500,数据集大小为1400,参数空间大小为25,感觉理论上可以迭代3次,但实际输出看起来只迭代了2次是为什么

500*1.5**2=1125
500*1.5**1=750

感觉应该还有一次n_resources为1125的迭代

后面试了试,限制数据集大小为1000,其余不变,也迭代了两次,这个感觉是合理的

1400个样本

from sklearn.ensemble import RandomForestRegressor
from sklearn.experimental import enable_halving_search_cv
from sklearn.model_selection import HalvingGridSearchCV,KFold,GridSearchCV,cross_validate
import numpy as np

param_grid_simple = {'n_estimators': [*range(50,100,10)]
                     , 'max_depth': [*range(15,25,2)]
                    }

reg = RandomForestRegressor(random_state=110,n_jobs=8,verbose=True)
cv = KFold(random_state=110,shuffle=True)

search = HalvingGridSearchCV(estimator=reg
                            ,param_grid=param_grid_simple
                            ,factor=1.5
                            ,min_resources=500
                            ,verbose = True
                            ,random_state=110
                            ,cv = cv
                            ,n_jobs=8)

search.fit(X[:1400,:],y[:1400])

输出

n_iterations: 2
n_required_iterations: 8
n_possible_iterations: 2
min_resources_: 500
max_resources_: 1400
aggressive_elimination: False
factor: 1.5
----------
iter: 0
n_candidates: 25
n_resources: 500
Fitting 5 folds for each of 25 candidates, totalling 125 fits
----------
iter: 1
n_candidates: 17
n_resources: 750
Fitting 5 folds for each of 17 candidates, totalling 85 fits
[Parallel(n_jobs=8)]: Using backend ThreadingBackend with 8 concurrent workers.
[Parallel(n_jobs=8)]: Done  34 tasks      | elapsed:    0.0s
[Parallel(n_jobs=8)]: Done  60 out of  60 | elapsed:    0.1s finished
HalvingGridSearchCV(cv=KFold(n_splits=5, random_state=110, shuffle=True),
                    estimator=RandomForestRegressor(n_jobs=8, random_state=110,
                                                    verbose=True),
                    factor=1.5, min_resources=500, n_jobs=8,
                    param_grid={'max_depth': [15, 17, 19, 21, 23],
                                'n_estimators': [50, 60, 70, 80, 90]},
                    random_state=110, verbose=True)

1000个样本

search.fit(X[:1000,:],y[:1000])

输出

n_iterations: 2
n_required_iterations: 8
n_possible_iterations: 2
min_resources_: 500
max_resources_: 1000
aggressive_elimination: False
factor: 1.5
----------
iter: 0
n_candidates: 25
n_resources: 500
Fitting 5 folds for each of 25 candidates, totalling 125 fits
----------
iter: 1
n_candidates: 17
n_resources: 750
Fitting 5 folds for each of 17 candidates, totalling 85 fits
[Parallel(n_jobs=8)]: Using backend ThreadingBackend with 8 concurrent workers.
[Parallel(n_jobs=8)]: Done  34 tasks      | elapsed:    0.0s
[Parallel(n_jobs=8)]: Done  90 out of  90 | elapsed:    0.1s finished
HalvingGridSearchCV(cv=KFold(n_splits=5, random_state=110, shuffle=True),
                    estimator=RandomForestRegressor(n_jobs=8, random_state=110,
                                                    verbose=True),
                    factor=1.5, min_resources=500, n_jobs=8,
                    param_grid={'max_depth': [15, 17, 19, 21, 23],
                                'n_estimators': [50, 60, 70, 80, 90]},
                    random_state=110, verbose=True)
  • 写回答

1条回答 默认 最新

  • week@eight 2023-02-07 11:04
    关注

    该回答引用ChatGPT
    HalvingGridSearchCV 只迭代了 2 次的原因可能如下
    1、实际资源分配比预期少,导致没有达到最低资源限制,因此无法进行第三次迭代。
    2、参数空间大小 25 可能已经在第二次迭代中覆盖了所有可能的参数值,不再需要进行第三次迭代。
    3、可能已经找到了一个最优解,不再需要进一步的搜索,因此不再进行第三次迭代。
    通过输出可以看到,实际的迭代次数为 2,其中的第一次迭代的样本数为 500,第二次迭代的样本数为 750。这是因为 aggressive_elimination 默认为 False,所以 HalvingGridSearchCV 算法不进行激进的消除,而是保证每次迭代的样本数量在因子 factor 和 min_resources 之间取值。因此,可以看到第二次迭代的样本数为 min_resources * factor,这是比较保守的选择。
    另外,通过 n_required_iterations 和 n_possible_iterations 的值可以发现,理论上的迭代次数为 8 次,但实际的迭代次数为 2 次。原因是数据集的大小 1400 较小,因此不需要进行多次迭代。

    评论 编辑记录

报告相同问题?

问题事件

  • 已结题 (查看结题原因) 2月14日
  • 修改了问题 2月7日
  • 修改了问题 2月7日
  • 创建了问题 2月7日

悬赏问题

  • ¥15 网络分析设施点无法识别
  • ¥15 状态图的并发态问题咨询
  • ¥15 PFC3D,plot
  • ¥15 VAE模型编程报错无法解决
  • ¥100 基于SVM的信息粒化时序回归预测,有偿求解!
  • ¥15 物体组批优化问题-数学建模求解答
  • ¥15 微信原生小程序tabBar编译报错
  • ¥350 麦克风声源定位坐标不准
  • ¥15 apifox与swagger使用
  • ¥15 egg异步请求返回404的问题