Ampare1987 2021-11-09 09:59 采纳率: 55.6%
浏览 13
已结题

机器学习建立训练集合测试集的函数

我现在在阅读 Hands-On Machine Learning with Scikit-Learn & TensorFlow 书中在建立测试集中给出了如下代码:

def test_set_check(identifier, test_ratio, hash):
    return hash(np.int64(identifier)).digest()[-1] < 256 * test_ratio


def split_train_test_by_id(data, test_ratio, id_column, hash=hashlib.md5):
    ids = data[id_column]
    in_test_set = ids.apply(lambda id_: test_set_check(id_, test_ratio, hash))
    return data.loc[~in_test_set], data.loc[in_test_set]


housing_with_id = housing.reset_index()  # adds an 'index' column
train_set, test_set = split_train_test_by_id(housing_with_id, 0.2, "index")
print(len(train_set), "train+", len(test_set), "test")

我个人的感觉这段代码的目的是为了让读者理解是如何得到训练集合测试集的,实际工作中应该不用自行输入这一段代码。
如果我的感觉是正确的实际的工作中是如何做的呢?

  • 写回答

1条回答 默认 最新

  • Julian_0115 2021-11-09 12:20
    关注
    from sklearn.model_selection import train_test_split
    train_set, test_set = train_test_split(housing, test_size=0.2, random_state=42)
    
    本回答被题主选为最佳回答 , 对您是否有帮助呢?
    评论

报告相同问题?

问题事件

  • 系统已结题 11月17日
  • 已采纳回答 11月9日
  • 修改了问题 11月9日
  • 创建了问题 11月9日

悬赏问题

  • ¥15 Qt下使用tcp获取数据的详细操作
  • ¥15 idea右下角设置编码是灰色的
  • ¥15 全志H618ROM新增分区
  • ¥15 在grasshopper里DrawViewportWires更改预览后,禁用电池仍然显示
  • ¥15 NAO机器人的录音程序保存问题
  • ¥15 C#读写EXCEL文件,不同编译
  • ¥15 MapReduce结果输出到HBase,一直连接不上MySQL
  • ¥15 扩散模型sd.webui使用时报错“Nonetype”
  • ¥15 stm32流水灯+呼吸灯+外部中断按键
  • ¥15 将二维数组,按照假设的规定,如0/1/0 == "4",把对应列位置写成一个字符并打印输出该字符