运行 sklearn 的 GBDT(GradientBoostingClassifier)时,调用 .fit() 函数总是出现问题:程序一直运行、没有中断,也始终得不到结果;手动停止后出现以下提示:
问题相关代码,请勿粘贴截图
import cv2
import numpy as np
from osgeo import gdal, gdalconst
import sys
import os
import pandas as pd
from osgeo import gdal
import matplotlib.pyplot as plt
from sklearn import ensemble
from sklearn import datasets
from sklearn.model_selection import train_test_split
from itertools import chain
from tqdm import tqdm
import time
# Minimal tqdm progress-bar demo: append each character to `text`,
# pausing half a second per step so the bar is visible.
# NOTE(review): the pasted source lost its indentation; the sleep is assumed
# to belong inside the loop -- confirm against the original script.
text = ""
for char in tqdm(["a", "b"]):
    text += char
    time.sleep(0.5)
# Read every raster with a given extension from a directory.
def read_tif_array(path, filetype):
    """Load all rasters in *path* whose extension equals *filetype*.

    Each matching file is read with ``cv2.imread(file, -1)``, flattened to
    one value per pixel (row by row), and stored as one column keyed by the
    file name without its extension.

    :param path: directory to scan (not recursive)
    :param filetype: extension to match, including the dot, e.g. ``'.tif'``
    :return: ``pandas.DataFrame(columns).values`` -- one row per pixel and
        one column per matching file; shape ``(0, 0)`` when nothing matches
    :raises OSError: if OpenCV cannot read a matching file
    """
    columns = {}
    # os.listdir() returns entries in arbitrary order; sorting keeps the
    # column order identical from run to run.
    for name in sorted(os.listdir(path)):
        dot = name.rfind('.')
        if name[dot:] != filetype:
            continue
        full_path = os.path.join(path, name)
        tif = cv2.imread(full_path, -1)
        if tif is None:
            # imread signals failure by returning None instead of raising.
            raise OSError("cv2.imread could not read " + full_path)
        if tif.ndim == 2:
            # Flatten inside NumPy rather than building a Python list with
            # one object per pixel, which is very slow for large rasters.
            pixels = tif.ravel()
        else:
            # Anything that is not a plain 2-D image keeps the original
            # row-chaining behaviour.
            pixels = list(chain.from_iterable(tif))
        columns[name[:dot]] = pixels
    return pd.DataFrame(columns).values
# Load the feature rasters (one column per file in the train folder).
feature = read_tif_array("D:/Personality/paper/GBDT/train", '.tif')
# Author's note from the original script: there are 15 features in total.
# Load the geological-class raster and flatten it into a 1-D label vector.
label = read_tif_array("D:/Personality/paper/GBDT/label", '.tif').ravel()
X = feature
y = label
# Re-encode the labels as 0..k-1; `labels` keeps the original class values.
labels, y = np.unique(y, return_inverse=True)
# test_size=0.8 holds out 80% of the rows, so only 20% are used for training.
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.8, random_state=0
)
# Base tree parameters shared by later computations.
original_params = dict(
    n_estimators=400,
    max_leaf_nodes=4,
    max_depth=None,
    random_state=2,
    min_samples_split=5,
)
setting = dict(learning_rate=0.2, subsample=1.0)
# Copy the base parameters and overlay the run-specific settings.
params = {**original_params, **setting}
clf = ensemble.GradientBoostingClassifier(**params)
# NOTE(review): this is the call reported as never finishing. It fits 400
# boosting stages on 20% of all pixel rows and prints nothing while running,
# so it is presumably just very slow rather than stuck -- passing verbose=1
# to the classifier would show progress. TODO confirm on the real data.
# fit() returns the estimator itself, so `blf` refers to the same object as `clf`.
blf = clf.fit(X_train, y_train)
# Read a raster file into an array.
def load_imgarray(img_file_path):
    """
    Open a raster with GDAL and return its georeferencing and pixel data.

    :param img_file_path: path of the raster file
    :return: tuple ``(projection, transform, img_array)`` holding the results
        of ``GetProjection()``, ``GetGeoTransform()`` and ``ReadAsArray()``.
        The original docstring cited an array of shape (5888, 5888, 10) --
        TODO confirm the actual shape and axis order.

    Prints a message and exits the process with status 1 when ``gdal.Open``
    returns ``None``.
    """
    dataset = gdal.Open(img_file_path)
    # Check the handle before touching any attribute: reading RasterCount
    # first would raise AttributeError on None and skip this message.
    if dataset is None:
        print('Unable to open *.tif')
        sys.exit(1)
    print('处理图像的栅格波段数总共有:', dataset.RasterCount)
    projection = dataset.GetProjection()  # projection
    transform = dataset.GetGeoTransform()  # geotransform
    # Read the whole dataset in one call.
    img_array = dataset.ReadAsArray()
    return projection, transform, img_array
# Write an array to a GeoTIFF file.
def Write_imgarray(tiff_file, im_proj, im_geotrans, data_array):
    """
    Write *data_array* to a GeoTIFF with the given georeferencing.

    :param tiff_file: output file path
    :param im_proj: projection passed to ``SetProjection``
    :param im_geotrans: geotransform passed to ``SetGeoTransform``
    :param data_array: array shaped ``(bands, height, width)`` or
        ``(height, width)``
    :raises RuntimeError: if the GTiff driver cannot create *tiff_file*

    Band type: ``uint16`` -> UInt16; ``int8``/``uint8``/``int16`` -> Int16;
    every other dtype -> Float32.
    """
    dtype_name = data_array.dtype.name
    if dtype_name == 'uint16':
        # The substring test below also matches 'uint16', which would put
        # unsigned values into a signed 16-bit band.
        datatype = gdal.GDT_UInt16
    elif 'int8' in dtype_name or 'int16' in dtype_name:
        datatype = gdal.GDT_Int16
    else:
        datatype = gdal.GDT_Float32

    if data_array.ndim == 3:
        im_bands, im_height, im_width = data_array.shape
    else:
        im_bands = 1
        im_height, im_width = data_array.shape

    driver = gdal.GetDriverByName("GTiff")
    dataset = driver.Create(tiff_file, im_width, im_height, im_bands, datatype)
    if dataset is None:
        raise RuntimeError("GDAL could not create " + str(tiff_file))
    dataset.SetGeoTransform(im_geotrans)
    dataset.SetProjection(im_proj)

    if data_array.ndim == 3:
        # Write one 2-D slice per band. Branching on ndim (not on the band
        # count) also covers a 3-D array that has a single band.
        for band_number, band_data in enumerate(data_array, start=1):
            dataset.GetRasterBand(band_number).WriteArray(band_data)
    else:
        dataset.GetRasterBand(1).WriteArray(data_array)
    # Drop the local reference to the dataset, as the original code does,
    # so GDAL can flush and close the file.
    del dataset
运行结果及报错内容
我的解答思路和尝试过的方法
起初以为是栈溢出的问题,尝试调大栈空间的大小,但无济于事。