Background
I am trying to implement the overlap between the forward pass of one model chunk and the backward pass of the other in dualpipe. As a first step, I put together a simple implementation using torch.cuda.stream:
import torch
from typing import Callable, List, Optional


def overlapped_forward_backward(
    module0: torch.nn.Module,
    inputs0: List[torch.Tensor],
    labels0: Optional[List[torch.Tensor]],
    loss_masks0: Optional[List[torch.Tensor]],
    loss1: Optional[torch.Tensor],
    outputs1: Optional[List[torch.Tensor]],
    output_grads1: Optional[List[torch.Tensor]],
    forward_step_func: Callable,
    is_last_stage0: bool,
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
    """
    You should implement a custom forward-backward overlap strategy.
    The code below is just an example.
    """
    device = inputs0[0].device

    # Lazily create one dedicated backward stream per device and cache it
    # as an attribute on the function itself.
    if not hasattr(overlapped_forward_backward, 'backward_streams'):
        overlapped_forward_backward.backward_streams = {}
    if device not in overlapped_forward_backward.backward_streams:
        overlapped_forward_backward.backward_streams[device] = torch.cuda.Stream(device=device)
    backward_stream = overlapped_forward_backward.backward_streams[device]

    # Issue the backward pass of chunk 1 on the side stream, hoping it will
    # overlap with the forward pass of chunk 0 issued below.
    with torch.cuda.stream(backward_stream):
        if loss1 is not None:
            loss1.backward()
            loss1.detach_()
        else:
            # run_backward() is a helper defined elsewhere in the codebase that
            # backpropagates output_grads1 through outputs1.
            run_backward(outputs1, output_grads1)

    # Forward pass of chunk 0 on the current (default) stream. Megatron
    # pipeline stages receive the activation from the previous stage via
    # set_input_tensor.
    if len(inputs0) == 1:
        from megatron.core.utils import get_attr_wrapped_model
        set_input_tensor = get_attr_wrapped_model(module0, "set_input_tensor")
        set_input_tensor(inputs0)
    if is_last_stage0:
        inputs0_with_labels_loss_masks = list(inputs0)
        inputs0_with_labels_loss_masks.append(labels0)
        inputs0_with_labels_loss_masks.append(loss_masks0)
        outputs0, loss_func = forward_step_func(inputs0_with_labels_loss_masks, module0)
    else:
        outputs0, loss_func = forward_step_func(inputs0, module0)
    outputs0 = [outputs0] if isinstance(outputs0, torch.Tensor) else outputs0
    if is_last_stage0:
        loss0 = loss_func(outputs0[0])[0]
    else:
        loss0 = None

    # Make the default stream wait for the backward stream before returning.
    torch.cuda.current_stream().wait_stream(backward_stream)
    return outputs0, loss0
I found that the forward and backward passes do not actually overlap: the backward portion of the code takes just as long to run as it does without the extra stream.
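One way to check this beyond wall-clock timing is to look at the kernel timeline. Below is a minimal, self-contained sketch using torch.profiler, with plain matmuls standing in for the real forward/backward work; if the two streams overlap, the exported trace will show their kernels running side by side:

import torch
from torch.profiler import profile, ProfilerActivity

side_stream = torch.cuda.Stream()
a = torch.randn(4096, 4096, device='cuda')
b = torch.randn(4096, 4096, device='cuda')
torch.cuda.synchronize()

with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof:
    with torch.cuda.stream(side_stream):
        for _ in range(10):
            a @ b              # stand-in for the backward work
    for _ in range(10):
        a @ b                  # stand-in for the forward work
    torch.cuda.synchronize()

# Open the trace in chrome://tracing or https://ui.perfetto.dev and check
# whether the two groups of kernels occupy the same time span.
prof.export_chrome_trace("overlap_trace.json")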
A small experiment
So I ran the following small experiment:
import torch
import time

# GPU warmup
a = torch.randn(10000, 10000, device='cuda')
b = torch.randn(10000, 10000, device='cuda')
c = torch.mm(a, b)
torch.cuda.synchronize()  # make sure the warmup work has actually finished

# Time the work on the default stream. The clock is read without a
# synchronize, so this is host-side time only.
calc_start = time.time()
a = torch.randn(10000, 10000, device='cuda')
b = torch.randn(10000, 10000, device='cuda')
for i in range(100):
    c = torch.mm(a, b)
calc_end = time.time()
print(f"calc time: {calc_end - calc_start}")

# Time the same work issued on a separate stream, again reading the clock
# before the final synchronize.
calc_stream = torch.cuda.Stream()
torch.cuda.synchronize()
stream_start = time.time()
with torch.cuda.stream(calc_stream):
    a = torch.randn(10000, 10000, device='cuda')
    b = torch.randn(10000, 10000, device='cuda')
    for i in range(100):
        c = torch.mm(a, b)
stream_end = time.time()
print(f"stream time: {stream_end - stream_start}")
torch.cuda.synchronize()
Comparing the runs with and without the stream, the measured times are about the same, and the stream version is actually slightly longer. If I increase the loop count, the time grows correspondingly in both cases.
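For what it's worth, host wall-clock time around asynchronous launches and actual GPU execution time are different quantities; CUDA events measure the latter. A minimal sketch of timing the same workload both ways:

import torch
import time

a = torch.randn(10000, 10000, device='cuda')
b = torch.randn(10000, 10000, device='cuda')
torch.cuda.synchronize()

start_evt = torch.cuda.Event(enable_timing=True)
end_evt = torch.cuda.Event(enable_timing=True)

t0 = time.time()
start_evt.record()
for _ in range(100):
    c = torch.mm(a, b)
end_evt.record()
t1 = time.time()              # host-side time spent issuing the launches
torch.cuda.synchronize()      # wait for the GPU to actually finish
print(f"host launch time: {t1 - t0:.3f} s")
print(f"GPU execution time: {start_evt.elapsed_time(end_evt) / 1000:.3f} s")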
Question
As I understand it, "with torch.cuda.stream()" should be asynchronous and non-blocking: the code block should only enqueue work, so its own execution time should be close to zero, which is what would allow it to run in parallel with subsequent computation or communication. But the experiment seems to show that it blocks, and execution only continues after all the computation has finished. Is there something wrong with my code, or is there another possible explanation? Any pointers would be appreciated!
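To state the expectation in code: with only a single kernel, exiting the "with" block should take near-zero time, and the real cost should only show up at an explicit synchronize. A minimal sketch of what I would expect:

import torch
import time

s = torch.cuda.Stream()
a = torch.randn(10000, 10000, device='cuda')
b = torch.randn(10000, 10000, device='cuda')
torch.cuda.synchronize()

t0 = time.time()
with torch.cuda.stream(s):
    c = torch.mm(a, b)        # a single asynchronous kernel launch
t1 = time.time()
torch.cuda.synchronize()
t2 = time.time()
print(f"exiting the with block took: {t1 - t0:.6f} s")  # expected: near zero
print(f"until synchronize:           {t2 - t0:.6f} s")  # includes the matmul itself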