以下代码使用了MPI广播数据,但是我发现它存在阻塞的问题,代码如下:
def load_state_dict(path, **kwargs):
"""
Load a PyTorch file without redundant fetches across MPI ranks.
"""
chunk_size = 2**30 # MPI has a relatively small size limit
if MPI.COMM_WORLD.Get_rank() == 0:
with bf.BlobFile(path, "rb") as f:
data = f.read()
num_chunks = len(data) // chunk_size
if len(data) % chunk_size:
num_chunks += 1
MPI.COMM_WORLD.bcast(num_chunks)
for i in range(0, len(data), chunk_size):
MPI.COMM_WORLD.bcast(data[i : i + chunk_size])
else:
num_chunks = MPI.COMM_WORLD.bcast(None)
data = bytes()
for _ in range(num_chunks):
data += MPI.COMM_WORLD.bcast(None)
return th.load(io.BytesIO(data), **kwargs)
调试后发现阻塞位置在这个地方,在循环最后一次发生了阻塞:
for i in range(0, len(data), chunk_size):
MPI.COMM_WORLD.bcast(data[i : i + chunk_size])
请问如何解决?为什么这里会发生阻塞?