这段代码里明明定义了 task_type = 'multiclass'，为什么最后一行打印出来的 score 仍然是按下面这段（regression 分支）代码算出来的呢？
这个分数看起来是 task_type = 'regression' 时才会得到的分数：
# Excerpt quoted from evaluate() in the full source below: the final
# (regression) branch. It computes the RMSE of the predictions and multiplies
# it by y_std (presumably undoing target standardization done elsewhere - confirm).
else:
assert task_type == 'regression'
score = sklearn.metrics.mean_squared_error(target, prediction) ** 0.5 * y_std
return score
下面是完整的源代码：
# Selects the metric branch used by evaluate() below; it must be exactly one of
# "binclass", "multiclass" or "regression".
# NOTE(review): evaluate() takes argmax over dim 1 for "multiclass", so the
# model must emit one output column per class. The data / model / loss setup
# (not visible in this excerpt) presumably has to be built for the same task
# type - confirm it does not still assume regression.
task_type: str = "multiclass"
@torch.no_grad()
def evaluate(part, task_type):
    """Run the global ``model`` on split ``part`` and return its score.

    Predictions are computed in batches with gradients disabled, then scored
    against ``y[part]`` using the metric selected by ``task_type``.

    Args:
        part: Key into the global ``X`` and ``y`` mappings (e.g. ``"test"``).
        task_type: Metric selector; must be exactly one of ``'binclass'``,
            ``'multiclass'`` or ``'regression'``.

    Returns:
        For ``'binclass'`` / ``'multiclass'``: accuracy
        (``sklearn.metrics.accuracy_score``).
        For ``'regression'``: RMSE multiplied by the global ``y_std``.

    Raises:
        ValueError: if ``task_type`` is not one of the supported values, or if
            ``task_type == 'multiclass'`` but the predictions are not 2-D
            (one column per class).
    """
    # Explicit check, not `assert` (which `python -O` strips): an unrecognized
    # value such as a typo, stray whitespace or curly quotes must fail loudly
    # rather than fall through to the regression metric. The check runs before
    # any inference, and !r makes invisible/odd characters visible.
    supported = ('binclass', 'multiclass', 'regression')
    if task_type not in supported:
        raise ValueError(
            f'Unknown task_type {task_type!r}; expected one of {supported}'
        )

    model.eval()
    prediction = []
    for batch in delu.iter_batches(X[part], 1024):
        prediction.append(apply_model(batch))
    # squeeze(1) removes dim 1 only when its size is 1 (single-output model);
    # with several output columns it is a no-op.
    prediction = torch.cat(prediction).squeeze(1).cpu().numpy()
    target = y[part].cpu().numpy()

    if task_type == 'binclass':
        # Outputs are treated as logits: sigmoid, then round to a 0/1 label.
        prediction = np.round(scipy.special.expit(prediction))
        score = sklearn.metrics.accuracy_score(target, prediction)
    elif task_type == 'multiclass':
        # argmax over dim 1 needs one column per class. A single-output model
        # would already have been squeezed to 1-D above.
        if prediction.ndim != 2:
            raise ValueError(
                "task_type='multiclass' expects predictions of shape "
                f'(n_samples, n_classes), got shape {prediction.shape}'
            )
        prediction = prediction.argmax(1)
        score = sklearn.metrics.accuracy_score(target, prediction)
    else:  # task_type == 'regression', guaranteed by the check above
        # RMSE rescaled by y_std - presumably undoing target standardization
        # performed elsewhere in the script; confirm.
        score = sklearn.metrics.mean_squared_error(target, prediction) ** 0.5 * y_std
    return score
# Loader yielding mini-batches of *indices* into the training split, created
# on `device`.
# Docs (written for `zero`, the former name of the `delu` library):
# https://yura52.github.io/zero/reference/api/zero.data.IndexLoader.html
batch_size = 256
train_loader = delu.data.IndexLoader(
    len(X["train"]),
    batch_size,
    device=device,
)

# Progress tracker used for early stopping. patience=100 is presumably the
# number of non-improving updates tolerated before stopping - confirm in the
# delu docs. Docs (written for `zero`, now `delu`):
# https://yura52.github.io/zero/reference/api/zero.ProgressTracker.html
progress = delu.ProgressTracker(patience=100)

# Baseline test score of the model before any training step.
print(f'Test score before training: {evaluate("test", task_type):.4f}')