名字不能取太长 2024-05-30 14:09 采纳率: 76.7%
浏览 19
已结题

使用MPI广播数据遇到阻塞

以下代码使用了MPI广播数据,但是我发现它存在阻塞的问题,代码如下:


def load_state_dict(path, **kwargs):
    """
    Load a PyTorch file without redundant fetches across MPI ranks.
    """
    chunk_size = 2**30  # MPI has a relatively small size limit
    if MPI.COMM_WORLD.Get_rank() == 0:
        with bf.BlobFile(path, "rb") as f:
            data = f.read()
        num_chunks = len(data) // chunk_size
        if len(data) % chunk_size:
            num_chunks += 1
        MPI.COMM_WORLD.bcast(num_chunks)
        for i in range(0, len(data), chunk_size):
            MPI.COMM_WORLD.bcast(data[i : i + chunk_size])
    else:
        num_chunks = MPI.COMM_WORLD.bcast(None)
        data = bytes()
        for _ in range(num_chunks):
            data += MPI.COMM_WORLD.bcast(None)

    return th.load(io.BytesIO(data), **kwargs)

调试后发现阻塞位置在这个地方,在循环最后一次发生了阻塞:

for i in range(0, len(data), chunk_size):
            MPI.COMM_WORLD.bcast(data[i : i + chunk_size])

请问如何解决?为什么这里会发生阻塞?

  • 写回答

27条回答 默认 最新

  • 专家-赤兔[在线] 全栈领域优质创作者 2024-05-30 14:09
    关注
    获得0.40元问题酬金

    引自免费微信小程序:皆我百晓生

    为了解决这个问题,我们需要理解 MPI 的工作原理。MPI(Message Passing Interface)是一种用于并行计算的标准库,允许进程间通信。在您的代码中,所有非主进程都在等待最后一个块的数据广播完成,然后才会继续执行。这是因为 bcast 函数是同步的,这意味着发送方会一直等待直到所有接收方都接收到数据。

    要解决这个问题,我们可以稍微调整代码结构,使得非主进程在开始接收数据之前就开始处理已经接收到的块。这样可以避免在最后一轮广播时的阻塞。这里是一个修改后的版本:

    import torch as th
    from mpi4py import MPI
    
    def load_state_dict_mpichunked(path, **kwargs):
        """
        Load a PyTorch file without redundant fetches across MPI ranks using a chunk-based approach.
        """
        chunk_size = 2**30  # MPI has a relatively small size limit
        comm = MPI.COMM_WORLD
    
        if comm.Get_rank() == 0:
            with open(path, "rb") as f:
                data = f.read()
            num_chunks = len(data) // chunk_size
            if len(data) % chunk_size:
                num_chunks += 1
            comm.bcast(num_chunks)
    
            chunks_to_send = [data[i:i+chunk_size] for i in range(0, len(data), chunk_size)]
            for chunk in chunks_to_send:
                comm.bcast(chunk)
        else:
            num_chunks = comm.bcast(None)
            chunks_to_recv = [[] for _ in range(num_chunks)]
            for i in range(num_chunks):
                chunk = comm.bcast(None)
                chunks_to_recv[i].append(chunk)
    
            data = b''.join([b''.join(chunk) for chunk in chunks_to_recv])
    
        return th.load(io.BytesIO(data), **kwargs)
    
    # 使用新的函数替代原始的load_state_dict
    state_dict = load_state_dict_mpichunked("your_file_path.pth", **kwargs)
    

    这个修改版的代码将接收端的处理提前到每次接收到一个数据块时就进行。这样,即使是最小的数据块,其他进程也不会阻塞等待。同时,我们从文件读取数据的部分也进行了优化,直接使用 Python 的文件操作,而不是通过 bf.BlobFile

    请注意,您需要安装 mpi4py 库才能运行这段代码,可以使用 pip 安装:pip install mpi4py。此外,确保您的环境中安装了 OpenMPI 或者 MPICH 等 MPI 实现。

    评论 编辑记录

报告相同问题?

问题事件

  • 系统已结题 6月7日
  • 创建了问题 5月30日

悬赏问题

  • ¥15 网络科学导论,网络控制
  • ¥100 安卓tv程序连接SQLSERVER2008问题
  • ¥15 利用Sentinel-2和Landsat8做一个水库的长时序NDVI的对比,为什么Snetinel-2计算的结果最小值特别小,而Lansat8就很平均
  • ¥15 metadata提取的PDF元数据,如何转换为一个Excel
  • ¥15 关于arduino编程toCharArray()函数的使用
  • ¥100 vc++混合CEF采用CLR方式编译报错
  • ¥15 coze 的插件输入飞书多维表格 app_token 后一直显示错误,如何解决?
  • ¥15 vite+vue3+plyr播放本地public文件夹下视频无法加载
  • ¥15 c#逐行读取txt文本,但是每一行里面数据之间空格数量不同
  • ¥50 如何openEuler 22.03上安装配置drbd