2021-01-08 11:59

create_supervised_trainer fails if the model's device differs from the device argument

Feature or bug? Up for discussion. If I modify the mnist.py example as follows:

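import torch
import torch.nn.functional as F
from torch.optim import SGD
from ignite.engine import create_supervised_trainer, create_supervised_evaluator
from ignite.metrics import CategoricalAccuracy, Loss

# get_data_loaders() and Net are defined elsewhere in mnist.py

def run(train_batch_size, val_batch_size, epochs, lr, momentum, log_interval):
    train_loader, val_loader = get_data_loaders(train_batch_size, val_batch_size)
    model = Net()
    device = 'cpu'

    if torch.cuda.is_available():
        device = 'cuda'
        # model = model.to(device)  # disabled on purpose: the model stays on the CPU

    optimizer = SGD(model.parameters(), lr=lr, momentum=momentum)
    trainer = create_supervised_trainer(model, optimizer, F.nll_loss, device=device)
    evaluator = create_supervised_evaluator(model,
                                            metrics={'accuracy': CategoricalAccuracy(),
                                                     'nll': Loss(F.nll_loss)},
                                            device=device)
    # ... rest of run() unchanged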

The call model(x) will then fail, because x is moved to device while the model stays on the CPU. Maybe we need to do one of the following:
  • check for this mismatch and assert,
  • remove device from the arguments and set it up internally according to the model, or
  • set the model on the provided device.

What do you all think?




  • weixin_40008969 (4 months ago)

    I don't think this is a huge issue, and I don't think we should take care of setting the correct device for the user.

    One thing we could perhaps do is add device='auto', which would set the device to that of the model parameters, provided all parameters live on the same device.
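A minimal sketch of how device='auto' could be resolved, assuming a hypothetical helper named resolve_device (it is not part of ignite). It relies only on model.parameters() and each parameter's .device attribute, both of which torch.nn.Module provides, so a factory could call it before moving any batches.

```python
def resolve_device(model, device="auto"):
    """Hypothetical helper (not part of ignite) sketching device='auto'."""
    # Anything other than the literal string "auto" (None, "cpu", "cuda",
    # a torch.device, ...) is passed through unchanged.
    if not (isinstance(device, str) and device == "auto"):
        return device
    # Collect the distinct devices the model's parameters live on.
    devices = {p.device for p in model.parameters()}
    if len(devices) != 1:
        found = sorted(str(d) for d in devices)
        raise ValueError(
            "device='auto' needs all model parameters on exactly one device, "
            f"found: {found}"
        )
    return next(iter(devices))
```

With this design the batches follow the model rather than the other way around: a freshly built Net() resolves to the CPU, and a model split across several devices (or one with no parameters) is rejected instead of guessed.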

  • weixin_39621235 (4 months ago)

    I agree that this is not a big deal. It's just that if someone reads the docs as written:

    device (optional): device type specification (default: None)

    they would wonder why the device is applied only to the batches and not to the model...

  • weixin_40008969 (4 months ago)

    Oh yeah, on second thought, we could actually go with

    • set the model on the provided device

    or rename the argument / clarify the docs.

  • weixin_39636540 (4 months ago)

    Yeah, I think we should set the model on the provided device. This would fix the bug, and it makes sense: these factory functions should make things as easy as possible.
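A minimal sketch of that behaviour, assuming a hypothetical factory named create_trainer_step (this is not ignite's actual create_supervised_trainer code). The model is moved once when the factory is called, and each batch is moved inside the update function, so both end up on the same device. It uses only standard nn.Module, optimizer and tensor methods, so no imports are needed.

```python
def create_trainer_step(model, optimizer, loss_fn, device=None):
    """Hypothetical sketch (not ignite's implementation).

    Returns an update(batch) function that trains on (x, y) batches and
    returns the loss as a Python float.
    """
    if device is not None:
        # nn.Module.to() moves the parameters in place, so the model now
        # lives on the same device the batches will be sent to.
        model.to(device)

    def update(batch):
        model.train()
        optimizer.zero_grad()
        x, y = batch
        if device is not None:
            # Tensor.to() returns a new tensor, so reassign x and y.
            x, y = x.to(device), y.to(device)
        loss = loss_fn(model(x), y)
        loss.backward()
        optimizer.step()
        return loss.item()

    return update
```

One caveat worth checking: the torch.optim documentation recommends moving a model to the GPU before constructing its optimizer. In the mnist.py example the optimizer is created before the factory is called, so this design relies on Module.to() updating the existing parameters in place.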
