Lukas00990 2022-10-25 17:17 采纳率: 40.8%
浏览 59
已结题

python程序问题,传递参数

这个代码,为什么定义了 task_type = ‘multiclass', 但是最后一行打印出来的score 还是以
这个得出来的呢?这个分数是 task_type =’regression‘时得出的分数
else:
assert task_type == 'regression'
score = sklearn.metrics.mean_squared_error(target, prediction) ** 0.5 * y_std
return score

img

下面是源代码


task_type = 'multiclass'

@torch.no_grad()
def evaluate(part, task_type):
    model.eval()
    prediction = []
    for batch in delu.iter_batches(X[part], 1024):
        prediction.append(apply_model(batch))
    prediction = torch.cat(prediction).squeeze(1).cpu().numpy()
    target = y[part].cpu().numpy()

    if task_type == 'binclass':
        prediction = np.round(scipy.special.expit(prediction))
        score = sklearn.metrics.accuracy_score(target, prediction)
    elif task_type == 'multiclass':
        prediction = prediction.argmax(1)
        score = sklearn.metrics.accuracy_score(target, prediction)
    else:
        assert task_type == 'regression'
        score = sklearn.metrics.mean_squared_error(target, prediction) ** 0.5 * y_std
    return score


# Create a dataloader for batches of indices
# Docs: https://yura52.github.io/zero/reference/api/zero.data.IndexLoader.html
batch_size = 256
train_loader = delu.data.IndexLoader(len(X['train']), batch_size, device=device)

# Create a progress tracker for early stopping
# Docs: https://yura52.github.io/zero/reference/api/zero.ProgressTracker.html
progress = delu.ProgressTracker(patience=100)
print(f'Test score before training: {evaluate("test", task_type):.4f}')

  • 写回答

3条回答 默认 最新

  • 快乐鹦鹉 2022-10-25 17:32
    关注

    你在13行前面输出一下task_type看看值是什么啊。

    本回答被题主选为最佳回答 , 对您是否有帮助呢?
    评论
查看更多回答(2条)

报告相同问题?

问题事件

  • 系统已结题 11月2日
  • 已采纳回答 10月25日
  • 修改了问题 10月25日
  • 创建了问题 10月25日

悬赏问题

  • ¥20 西门子S7-Graph,S7-300,梯形图
  • ¥50 用易语言http 访问不了网页
  • ¥50 safari浏览器fetch提交数据后数据丢失问题
  • ¥15 matlab不知道怎么改,求解答!!
  • ¥15 永磁直线电机的电流环pi调不出来
  • ¥15 用stata实现聚类的代码
  • ¥15 请问paddlehub能支持移动端开发吗?在Android studio上该如何部署?
  • ¥20 docker里部署springboot项目,访问不到扬声器
  • ¥15 netty整合springboot之后自动重连失效
  • ¥15 悬赏!微信开发者工具报错,求帮改