Lukas00990 2022-10-25 17:17 采纳率: 40.8%
浏览 59
已结题

python程序问题,传递参数

这个代码,为什么定义了 task_type = ‘multiclass', 但是最后一行打印出来的score 还是以
这个得出来的呢?这个分数是 task_type =’regression‘时得出的分数
else:
assert task_type == 'regression'
score = sklearn.metrics.mean_squared_error(target, prediction) ** 0.5 * y_std
return score

img

下面是源代码


task_type = 'multiclass'

@torch.no_grad()
def evaluate(part, task_type):
    model.eval()
    prediction = []
    for batch in delu.iter_batches(X[part], 1024):
        prediction.append(apply_model(batch))
    prediction = torch.cat(prediction).squeeze(1).cpu().numpy()
    target = y[part].cpu().numpy()

    if task_type == 'binclass':
        prediction = np.round(scipy.special.expit(prediction))
        score = sklearn.metrics.accuracy_score(target, prediction)
    elif task_type == 'multiclass':
        prediction = prediction.argmax(1)
        score = sklearn.metrics.accuracy_score(target, prediction)
    else:
        assert task_type == 'regression'
        score = sklearn.metrics.mean_squared_error(target, prediction) ** 0.5 * y_std
    return score


# Create a dataloader for batches of indices
# Docs: https://yura52.github.io/zero/reference/api/zero.data.IndexLoader.html
batch_size = 256
train_loader = delu.data.IndexLoader(len(X['train']), batch_size, device=device)

# Create a progress tracker for early stopping
# Docs: https://yura52.github.io/zero/reference/api/zero.ProgressTracker.html
progress = delu.ProgressTracker(patience=100)
print(f'Test score before training: {evaluate("test", task_type):.4f}')

  • 写回答

3条回答 默认 最新

  • 快乐鹦鹉 2022-10-25 17:32
    关注

    你在13行前面输出一下task_type看看值是什么啊。

    本回答被题主选为最佳回答 , 对您是否有帮助呢?
    评论
查看更多回答(2条)

报告相同问题?

问题事件

  • 系统已结题 11月2日
  • 已采纳回答 10月25日
  • 修改了问题 10月25日
  • 创建了问题 10月25日

悬赏问题

  • ¥15 装 pytorch 的时候出了好多问题,遇到这种情况怎么处理?
  • ¥20 IOS游览器某宝手机网页版自动立即购买JavaScript脚本
  • ¥15 手机接入宽带网线,如何释放宽带全部速度
  • ¥30 关于#r语言#的问题:如何对R语言中mfgarch包中构建的garch-midas模型进行样本内长期波动率预测和样本外长期波动率预测
  • ¥15 ETLCloud 处理json多层级问题
  • ¥15 matlab中使用gurobi时报错
  • ¥15 这个主板怎么能扩出一两个sata口
  • ¥15 不是,这到底错哪儿了😭
  • ¥15 2020长安杯与连接网探
  • ¥15 关于#matlab#的问题:在模糊控制器中选出线路信息,在simulink中根据线路信息生成速度时间目标曲线(初速度为20m/s,15秒后减为0的速度时间图像)我想问线路信息是什么