Fix allgather ops inside cuda graphs (#3709)
@@ -824,9 +824,7 @@ def all_gather(
         input_tensor, (0, 0, 0, max_len - input_tensor.shape[0])
     )

-    torch.distributed.all_gather_into_tensor(
-        forward_batch.gathered_buffer, padded_tensor, group=group
-    )
+    group.all_gather_into_tensor(forward_batch.gathered_buffer, padded_tensor)

     gathered_tensors = torch.concat(
         [
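This hunk routes the all-gather through the group object instead of calling torch.distributed directly, which per the commit title is the path that stays valid inside CUDA graphs. Below is a rough, standalone sketch of the padded-gather pattern the hunk touches; the helper name padded_all_gather and its scaffolding are assumptions for illustration, not code from this repo, and it expects a distributed backend to already be initialized (e.g. under torchrun).

import torch
import torch.distributed as dist


def padded_all_gather(input_tensor, gathered_buffer, global_num_tokens, group=None):
    # Hypothetical helper mirroring the pattern in the diff: every rank pads
    # its token dimension up to the global maximum so the collective has a
    # static, capture-friendly shape, then the padding is stripped back out.
    world_size = dist.get_world_size(group)
    max_len = max(global_num_tokens)
    padded = torch.nn.functional.pad(
        input_tensor, (0, 0, 0, max_len - input_tensor.shape[0])
    )
    # The commit replaces this raw call with group.all_gather_into_tensor(...),
    # i.e. the group wrapper's own method, which is the graph-safe path
    # according to the commit title.
    dist.all_gather_into_tensor(gathered_buffer, padded, group=group)
    return torch.concat(
        [
            gathered_buffer[i * max_len : i * max_len + global_num_tokens[i]]
            for i in range(world_size)
        ]
    )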
@@ -862,7 +860,7 @@ class DeepseekV2DecoderLayer(nn.Module):
         if self.enable_dp_attention:
             self.tp_rank = get_tensor_model_parallel_rank()
             self.tp_size = get_tensor_model_parallel_world_size()
-            self.tp_group = get_tp_group().device_group
+            self.tp_group = get_tp_group()
         if not global_server_args_dict["disable_mla"]:
             self.self_attn = DeepseekV2AttentionMLA(
                 config=config,
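The second hunk keeps the wrapper object returned by get_tp_group() instead of unwrapping it to the raw .device_group, so later collectives go through the wrapper's single entry point. The toy class below is only an illustration of why that distinction matters; TPGroupWrapper is a made-up name, not the repo's actual implementation.

import torch
import torch.distributed as dist


class TPGroupWrapper:
    # Illustrative stand-in for whatever get_tp_group() returns; the real
    # class lives in the repo's parallel-state code.
    def __init__(self, device_group: dist.ProcessGroup):
        self.device_group = device_group  # the raw torch.distributed group

    def all_gather_into_tensor(self, output: torch.Tensor, input_: torch.Tensor) -> None:
        # The real wrapper can swap in a CUDA-graph-safe communicator when
        # torch.cuda.is_current_stream_capturing() is True; this sketch only
        # shows the single entry point that callers hit after the change,
        # instead of reaching into .device_group themselves.
        dist.all_gather_into_tensor(output, input_, group=self.device_group)

With the two hunks combined, callers store self.tp_group = get_tp_group() and issue self.tp_group.all_gather_into_tensor(out, inp), rather than passing the raw process group to torch.distributed.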