Fix allgather ops inside cuda graphs (#3709)

This commit is contained in:
Nicolas Castet
2025-02-25 10:39:10 -06:00
committed by GitHub
parent c0bb9eb3b3
commit 127998cc41
2 changed files with 41 additions and 7 deletions

View File

@@ -824,9 +824,7 @@ def all_gather(
input_tensor, (0, 0, 0, max_len - input_tensor.shape[0])
)
-        torch.distributed.all_gather_into_tensor(
-            forward_batch.gathered_buffer, padded_tensor, group=group
-        )
+        group.all_gather_into_tensor(forward_batch.gathered_buffer, padded_tensor)
         gathered_tensors = torch.concat(
             [
gathered_tensors = torch.concat(
[
@@ -862,7 +860,7 @@ class DeepseekV2DecoderLayer(nn.Module):
if self.enable_dp_attention:
self.tp_rank = get_tensor_model_parallel_rank()
self.tp_size = get_tensor_model_parallel_world_size()
-            self.tp_group = get_tp_group().device_group
+            self.tp_group = get_tp_group()
if not global_server_args_dict["disable_mla"]:
self.self_attn = DeepseekV2AttentionMLA(
config=config,