[Bug fixed] Fixed the crash when enabling dp-attention on a single card (#3958)
This commit is contained in:
@@ -848,12 +848,12 @@ class DeepseekV2AttentionMLA(nn.Module):
|
|||||||
def all_gather(
|
def all_gather(
|
||||||
input_tensor: torch.Tensor, forward_batch: ForwardBatch, rank, world_size, group
|
input_tensor: torch.Tensor, forward_batch: ForwardBatch, rank, world_size, group
|
||||||
):
|
):
|
||||||
if world_size == 1:
|
|
||||||
return input_tensor
|
|
||||||
|
|
||||||
all_lens = forward_batch.global_num_tokens_cpu
|
all_lens = forward_batch.global_num_tokens_cpu
|
||||||
max_len = max(forward_batch.global_num_tokens_cpu)
|
max_len = max(forward_batch.global_num_tokens_cpu)
|
||||||
|
|
||||||
|
if world_size == 1:
|
||||||
|
return input_tensor, 0, all_lens[0]
|
||||||
|
|
||||||
padded_tensor = torch.nn.functional.pad(
|
padded_tensor = torch.nn.functional.pad(
|
||||||
input_tensor, (0, 0, 0, max_len - input_tensor.shape[0])
|
input_tensor, (0, 0, 0, max_len - input_tensor.shape[0])
|
||||||
)
|
)
|
||||||
|
|||||||
Reference in New Issue
Block a user