weixin_4412 2023-12-09 17:18 Acceptance rate: 0%

Large model: Expected all tensors to be on the same device

I'm a beginner with large models. I'm running inference on four GPUs, and I load the model like this:


```python
model = AutoModelForCausalLM.from_pretrained(model_name_or_path, torch_dtype=load_type, device_map="auto").half().cuda()
```

I specified `device_map="auto"`, and the model was split across the four cards, which looks fine:

```bash
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 545.23.06              Driver Version: 545.23.06    CUDA Version: 12.3     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|=========================================+======================+======================|
|   0  NVIDIA A800 80GB PCIe          Off | 00000000:4F:00.0 Off |                    0 |
| N/A   38C    P0              66W / 300W |  65255MiB / 81920MiB |      0%      Default |
|                                         |                      |             Disabled |
+-----------------------------------------+----------------------+----------------------+
|   1  NVIDIA A800 80GB PCIe          Off | 00000000:50:00.0 Off |                    0 |
| N/A   39C    P0              66W / 300W |  17787MiB / 81920MiB |      0%      Default |
|                                         |                      |             Disabled |
+-----------------------------------------+----------------------+----------------------+
|   2  NVIDIA A800 80GB PCIe          Off | 00000000:53:00.0 Off |                    0 |
| N/A   40C    P0              68W / 300W |  17787MiB / 81920MiB |      0%      Default |
|                                         |                      |             Disabled |
+-----------------------------------------+----------------------+----------------------+
|   3  NVIDIA A800 80GB PCIe          Off | 00000000:57:00.0 Off |                    0 |
| N/A   39C    P0              69W / 300W |  14287MiB / 81920MiB |      0%      Default |
|                                         |                      |             Disabled |
+-----------------------------------------+----------------------+----------------------+

```

But when I call inference, it fails with:

ERROR:    Exception in ASGI application
Traceback (most recent call last):
  File "/usr/local/lib/python3.8/site-packages/uvicorn/protocols/http/h11_impl.py", line 429, in run_asgi
    result = await app(  # type: ignore[func-returns-value]
  File "/usr/local/lib/python3.8/site-packages/uvicorn/middleware/proxy_headers.py", line 78, in __call__
    return await self.app(scope, receive, send)
  File "/usr/local/lib/python3.8/site-packages/fastapi/applications.py", line 276, in __call__
    await super().__call__(scope, receive, send)
  File "/usr/local/lib/python3.8/site-packages/starlette/applications.py", line 122, in __call__
    await self.middleware_stack(scope, receive, send)
  File "/usr/local/lib/python3.8/site-packages/starlette/middleware/errors.py", line 184, in __call__
    raise exc
  File "/usr/local/lib/python3.8/site-packages/starlette/middleware/errors.py", line 162, in __call__
    await self.app(scope, receive, _send)
  File "/usr/local/lib/python3.8/site-packages/starlette/middleware/cors.py", line 84, in __call__
    await self.app(scope, receive, send)
  File "/usr/local/lib/python3.8/site-packages/starlette/middleware/exceptions.py", line 79, in __call__
    raise exc
  File "/usr/local/lib/python3.8/site-packages/starlette/middleware/exceptions.py", line 68, in __call__
    await self.app(scope, receive, sender)
  File "/usr/local/lib/python3.8/site-packages/fastapi/middleware/asyncexitstack.py", line 21, in __call__
    raise e
  File "/usr/local/lib/python3.8/site-packages/fastapi/middleware/asyncexitstack.py", line 18, in __call__
    await self.app(scope, receive, send)
  File "/usr/local/lib/python3.8/site-packages/starlette/routing.py", line 718, in __call__
    await route.handle(scope, receive, send)
  File "/usr/local/lib/python3.8/site-packages/starlette/routing.py", line 276, in handle
    await self.app(scope, receive, send)
  File "/usr/local/lib/python3.8/site-packages/starlette/routing.py", line 66, in app
    response = await func(request)
  File "/usr/local/lib/python3.8/site-packages/fastapi/routing.py", line 237, in app
    raw_response = await run_endpoint_function(
  File "/usr/local/lib/python3.8/site-packages/fastapi/routing.py", line 163, in run_endpoint_function
    return await dependant.call(**values)
  File "openai_api_codellama-34b-2.py", line 194, in create_chat_completion
    generation_output = model.generate(
  File "/usr/local/lib/python3.8/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/usr/local/lib/python3.8/site-packages/transformers-4.33.0-py3.8.egg/transformers/generation/utils.py", line 1681, in generate
    return self.beam_search(
  File "/usr/local/lib/python3.8/site-packages/transformers-4.33.0-py3.8.egg/transformers/generation/utils.py", line 3020, in beam_search
    outputs = self(
  File "/usr/local/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/lib/python3.8/site-packages/accelerate/hooks.py", line 165, in new_forward
    output = old_forward(*args, **kwargs)
  File "/usr/local/lib/python3.8/site-packages/transformers-4.33.0-py3.8.egg/transformers/models/llama/modeling_llama.py", line 820, in forward
    outputs = self.model(
  File "/usr/local/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/lib/python3.8/site-packages/transformers-4.33.0-py3.8.egg/transformers/models/llama/modeling_llama.py", line 708, in forward
    layer_outputs = decoder_layer(
  File "/usr/local/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/lib/python3.8/site-packages/accelerate/hooks.py", line 165, in new_forward
    output = old_forward(*args, **kwargs)
  File "/usr/local/lib/python3.8/site-packages/transformers-4.33.0-py3.8.egg/transformers/models/llama/modeling_llama.py", line 421, in forward
    hidden_states = self.input_layernorm(hidden_states)
  File "/usr/local/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/lib/python3.8/site-packages/accelerate/hooks.py", line 165, in new_forward
    output = old_forward(*args, **kwargs)
  File "/usr/local/lib/python3.8/site-packages/transformers-4.33.0-py3.8.egg/transformers/models/llama/modeling_llama.py", line 89, in forward
    return self.weight * hidden_states.to(input_dtype)
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cuda:1!

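I understand the error means the model and the data are not on the same card. But `device_map="auto"` is supposed to split the model automatically (I think), and the memory usage clearly differs from card to card. So I'm confused: the whole point is to spread the model across several cards, yet it now looks as if it won't run unless everything sits on one card? How should I change this?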
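
One likely cause, and the memory readings above are consistent with it (GPU 0 alone holds roughly the fp16 size of a 34B model), is the trailing `.half().cuda()`. `device_map="auto"` already places each layer on a card and installs hooks that move activations between cards; calling `.cuda()` afterwards pulls every weight onto `cuda:0` while the hooks keep sending activations to the originally planned cards. The sketch below illustrates this under that assumption and has not been run on this setup: `load_sharded_model` drops the two trailing calls and passes the dtype through `torch_dtype` instead, and `find_misplaced_parameters` is a hypothetical helper (not part of any library) that compares a planned device map with where the parameters actually are.

```python
from typing import Dict, List, Tuple, Union

Device = Union[int, str]


def normalize_device(device: Device) -> str:
    """Turn 0 into "cuda:0"; leave strings such as "cuda:1" or "cpu" unchanged."""
    if isinstance(device, int):
        return f"cuda:{device}"
    return str(device)


def find_misplaced_parameters(
    planned: Dict[str, Device], actual: Dict[str, str]
) -> List[Tuple[str, str, str]]:
    """Return (parameter, planned_device, actual_device) for every mismatch.

    `planned` maps module-name prefixes to devices (the shape of a device map);
    `actual` maps full parameter names to device strings.
    """
    mismatches = []
    for name, device in actual.items():
        # Match on module boundaries so "model.layers.1" does not claim
        # parameters that belong to "model.layers.10".
        owners = [
            prefix
            for prefix in planned
            if prefix == "" or name == prefix or name.startswith(prefix + ".")
        ]
        if not owners:
            continue  # parameter is not covered by the plan
        expected = normalize_device(planned[max(owners, key=len)])
        if expected != device:
            mismatches.append((name, expected, device))
    return mismatches


def load_sharded_model(model_name_or_path: str):
    """Load in fp16 and let the device map shard the layers; no .half()/.cuda() afterwards."""
    # Imported lazily so the helper above works without these packages installed.
    import torch
    from transformers import AutoModelForCausalLM

    return AutoModelForCausalLM.from_pretrained(
        model_name_or_path,
        torch_dtype=torch.float16,  # assumption: the original load_type was fp16
        device_map="auto",
    )


if __name__ == "__main__":
    planned = {"model.embed_tokens": 0, "model.layers.0": 0, "model.layers.1": 1}
    # Placement as it would look after .cuda() moved every weight to GPU 0.
    actual = {
        "model.embed_tokens.weight": "cuda:0",
        "model.layers.0.input_layernorm.weight": "cuda:0",
        "model.layers.1.input_layernorm.weight": "cuda:0",
    }
    for name, expected, found in find_misplaced_parameters(planned, actual):
        print(f"{name}: planned {expected}, found {found}")
```

With a real model, `planned` would be `model.hf_device_map` and `actual` would be `{n: str(p.device) for n, p in model.named_parameters()}`. The inputs passed to `generate` would then typically go to the device of the first shard, for example `input_ids.to(model.device)`.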
Also, does launching with torchrun simply copy an identical model onto each of the four cards, so that this error would not occur?
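
On the torchrun question: torchrun starts one process per GPU and sets `LOCAL_RANK`, `RANK` and `WORLD_SIZE` in each process's environment; it does not copy the model by itself. If every process loads the full model onto its own card, the result is four independent replicas, which avoids cross-device tensors but requires the whole model to fit on a single card. A minimal sketch of the per-process bookkeeping, assuming that replicated setup; `local_device` and `shard_for_rank` are hypothetical helper names:

```python
import os
from typing import List, Sequence


def local_device() -> str:
    """Device for this process; LOCAL_RANK is set by torchrun, default 0 otherwise."""
    return f"cuda:{int(os.environ.get('LOCAL_RANK', '0'))}"


def shard_for_rank(items: Sequence[str], rank: int, world_size: int) -> List[str]:
    """Give each replica every world_size-th item so the replicas split the workload."""
    return list(items[rank::world_size])


if __name__ == "__main__":
    rank = int(os.environ.get("RANK", "0"))
    world_size = int(os.environ.get("WORLD_SIZE", "1"))
    prompts = ["prompt A", "prompt B", "prompt C", "prompt D", "prompt E"]
    print(local_device(), shard_for_rank(prompts, rank, world_size))
    # In a real script each process would load its own full replica here and
    # move it to local_device(), without device_map="auto".
```

Launched as `torchrun --nproc_per_node=4 script.py` (the script name is a placeholder), each of the four processes sees a different `LOCAL_RANK`. A 34B model in fp16 needs roughly 68 GB for the weights alone, so a full replica per 80 GB card leaves little headroom for the KV cache.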

2 answers

  • CSDN-Ada Assistant (official CSDN-AI account) 2023-12-09 20:31

    [The following answer was generated by GPT]

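    Question: Could you add more information about your code, especially around model loading, such as the error message you want to discuss, so that the context is easier to understand? Also, have you tried manually specifying the GPU device when loading the model?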

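    Answer: Sorry, but the question as provided is rather vague and contains no clear code or error message. To understand the context, please provide the relevant code and error message so that I can help you solve the problem.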

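    As for manually specifying the GPU device: in many deep learning frameworks, such as TensorFlow and PyTorch, a model can be loaded onto a specific GPU by giving the device's index or name. Below is example PyTorch code that loads a model onto a specified GPU device: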

    import torch
    import torch.nn as nn

    # Placeholder model and input so the example runs; substitute your own.
    model = nn.Linear(4, 2)
    input_tensor = torch.randn(1, 4)

    # Use the default GPU if one is available, otherwise fall back to the CPU.
    # A specific card can be chosen with torch.device("cuda:0"),
    # torch.device("cuda:1"), and so on.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Move the model and its input to the same device before training or inference.
    model = model.to(device)
    output = model(input_tensor.to(device))
    

    Please provide more details so that I can better help you solve the problem.

