# Baseline for comparison: DecisionTreeClassifier() with default parameters
# scored 0.7836742214851667.
import csv
import os

import numpy as np

# Output label for each predicted class value; indexed by the prediction, so
# this assumes the targets are encoded as 0 / 1 (TODO confirm against the
# encoding of train_target). NOTE(review): the leading spaces presumably
# mirror the raw label strings of the dataset - confirm before stripping them.
classes = [' <=50K', ' >50K']

# sklearn requires a 2-D (n_samples, n_features) input. A scalar here (for
# example NaN) means test_dataset was not built correctly upstream; failing
# early with this message is clearer than sklearn's generic check_array error.
if np.ndim(test_dataset) != 2:
    raise ValueError(
        "test_dataset must be 2-D (n_samples, n_features), got ndim="
        f"{np.ndim(test_dataset)} (type {type(test_dataset).__name__}). "
        "Check how the test data was loaded/preprocessed."
    )

clf = DecisionTreeClassifier(criterion='entropy', max_depth=8, min_samples_split=5)
clf = clf.fit(train_dataset, train_target)
pred = clf.predict(test_dataset)
print(pred)
score = clf.score(test_dataset, test_target)
print(score)

# Write one row per test sample: its position in test_dataset and the
# predicted label text. The output directory is created if it is missing.
os.makedirs('Predict', exist_ok=True)
with open('Predict/DecisionTree.csv', 'w', newline='', encoding='utf-8') as file:
    writer = csv.writer(file)
    writer.writerow(['id', 'result_pred'])
    # int() so float-typed labels (e.g. 0.0) still index the classes list.
    writer.writerows([i, classes[int(result)]] for i, result in enumerate(pred))
ValueError Traceback (most recent call last)
C:\Users\151354~1\AppData\Local\Temp/ipykernel_7032/68014145.py in <module>
3 clf = DecisionTreeClassifier(criterion = 'entropy', max_depth = 8, min_samples_split = 5)
4 clf = clf.fit(train_dataset, train_target)
----> 5 pred = clf.predict(test_dataset)
6 print(pred)
7 score = clf.score(test_dataset, test_target)
D:\anaconda\lib\site-packages\sklearn\tree\_classes.py in predict(self, X, check_input)
440 """
441 check_is_fitted(self)
--> 442 X = self._validate_X_predict(X, check_input)
443 proba = self.tree_.predict(X)
444 n_samples = X.shape[0]
D:\anaconda\lib\site-packages\sklearn\tree\_classes.py in _validate_X_predict(self, X, check_input)
405 """Validate the training data on predict (probabilities)."""
406 if check_input:
--> 407 X = self._validate_data(X, dtype=DTYPE, accept_sparse="csr",
408 reset=False)
409 if issparse(X) and (X.indices.dtype != np.intc or
D:\anaconda\lib\site-packages\sklearn\base.py in _validate_data(self, X, y, reset, validate_separately, **check_params)
419 out = X
420 elif isinstance(y, str) and y == 'no_validation':
--> 421 X = check_array(X, **check_params)
422 out = X
423 else:
D:\anaconda\lib\site-packages\sklearn\utils\validation.py in inner_f(*args, **kwargs)
61 extra_args = len(args) - len(all_args)
62 if extra_args <= 0:
---> 63 return f(*args, **kwargs)
64
65 # extra_args > 0
D:\anaconda\lib\site-packages\sklearn\utils\validation.py in check_array(array, accept_sparse, accept_large_sparse, dtype, order, copy, force_all_finite, ensure_2d, allow_nd, ensure_min_samples, ensure_min_features, estimator)
685 # If input is scalar raise error
686 if array.ndim == 0:
--> 687 raise ValueError(
688 "Expected 2D array, got scalar array instead:\narray={}.\n"
689 "Reshape your data either using array.reshape(-1, 1) if "
ValueError: Expected 2D array, got scalar array instead:
array=nan.
Reshape your data either using array.reshape(-1, 1) if your data has a single feature or array.reshape(1, -1) if it contains a single sample.
可以远程的