名字不能取太长 2024-05-30 14:09 采纳率: 76.7%
浏览 11
已结题

使用MPI广播数据遇到阻塞

以下代码使用了MPI广播数据,但是我发现它存在阻塞的问题,代码如下:


def load_state_dict(path, **kwargs):
    """
    Load a PyTorch file without redundant fetches across MPI ranks.
    """
    chunk_size = 2**30  # MPI has a relatively small size limit
    if MPI.COMM_WORLD.Get_rank() == 0:
        with bf.BlobFile(path, "rb") as f:
            data = f.read()
        num_chunks = len(data) // chunk_size
        if len(data) % chunk_size:
            num_chunks += 1
        MPI.COMM_WORLD.bcast(num_chunks)
        for i in range(0, len(data), chunk_size):
            MPI.COMM_WORLD.bcast(data[i : i + chunk_size])
    else:
        num_chunks = MPI.COMM_WORLD.bcast(None)
        data = bytes()
        for _ in range(num_chunks):
            data += MPI.COMM_WORLD.bcast(None)

    return th.load(io.BytesIO(data), **kwargs)

调试后发现阻塞位置在这个地方,在循环最后一次发生了阻塞:

for i in range(0, len(data), chunk_size):
            MPI.COMM_WORLD.bcast(data[i : i + chunk_size])

请问如何解决?为什么这里会发生阻塞?

  • 写回答

27条回答 默认 最新

  • 专家-赤兔[在线] 全栈领域优质创作者 2024-05-30 14:09
    关注
    获得0.40元问题酬金

    引自免费微信小程序:皆我百晓生

    为了解决这个问题,我们需要理解 MPI 的工作原理。MPI(Message Passing Interface)是一种用于并行计算的标准库,允许进程间通信。在您的代码中,所有非主进程都在等待最后一个块的数据广播完成,然后才会继续执行。这是因为 bcast 函数是同步的,这意味着发送方会一直等待直到所有接收方都接收到数据。

    要解决这个问题,我们可以稍微调整代码结构,使得非主进程在开始接收数据之前就开始处理已经接收到的块。这样可以避免在最后一轮广播时的阻塞。这里是一个修改后的版本:

    import torch as th
    from mpi4py import MPI
    
    def load_state_dict_mpichunked(path, **kwargs):
        """
        Load a PyTorch file without redundant fetches across MPI ranks using a chunk-based approach.
        """
        chunk_size = 2**30  # MPI has a relatively small size limit
        comm = MPI.COMM_WORLD
    
        if comm.Get_rank() == 0:
            with open(path, "rb") as f:
                data = f.read()
            num_chunks = len(data) // chunk_size
            if len(data) % chunk_size:
                num_chunks += 1
            comm.bcast(num_chunks)
    
            chunks_to_send = [data[i:i+chunk_size] for i in range(0, len(data), chunk_size)]
            for chunk in chunks_to_send:
                comm.bcast(chunk)
        else:
            num_chunks = comm.bcast(None)
            chunks_to_recv = [[] for _ in range(num_chunks)]
            for i in range(num_chunks):
                chunk = comm.bcast(None)
                chunks_to_recv[i].append(chunk)
    
            data = b''.join([b''.join(chunk) for chunk in chunks_to_recv])
    
        return th.load(io.BytesIO(data), **kwargs)
    
    # 使用新的函数替代原始的load_state_dict
    state_dict = load_state_dict_mpichunked("your_file_path.pth", **kwargs)
    

    这个修改版的代码将接收端的处理提前到每次接收到一个数据块时就进行。这样,即使是最小的数据块,其他进程也不会阻塞等待。同时,我们从文件读取数据的部分也进行了优化,直接使用 Python 的文件操作,而不是通过 bf.BlobFile

    请注意,您需要安装 mpi4py 库才能运行这段代码,可以使用 pip 安装:pip install mpi4py。此外,确保您的环境中安装了 OpenMPI 或者 MPICH 等 MPI 实现。

    评论 编辑记录

报告相同问题?

问题事件

  • 系统已结题 6月7日
  • 创建了问题 5月30日

悬赏问题

  • ¥15 爬取网页内容并保存需要完整的python代码
  • ¥30 NIRfast软件使用指导
  • ¥20 matlab仿真问题,求功率谱密度
  • ¥15 求micropython modbus-RTU 从机的代码或库?
  • ¥15 铜与钢双金属板叠加在一起每种材料300mm长,18mm宽,4mm厚一端固定并加热至80℃,当加热端温度保持不变时另一端的稳态温度。ansys
  • ¥15 django5安装失败
  • ¥15 Java与Hbase相关问题
  • ¥15 后缀 crn 游戏文件提取资源
  • ¥15 ANSYS分析简单钎焊问题
  • ¥20 bash代码推送不上去 git fetch origin master #失败了