diff --git a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py index 82c73ec94..40f6799a1 100755 --- a/python/sglang/srt/models/deepseek_v2.py +++ b/python/sglang/srt/models/deepseek_v2.py @@ -848,12 +848,12 @@ class DeepseekV2AttentionMLA(nn.Module): def all_gather( input_tensor: torch.Tensor, forward_batch: ForwardBatch, rank, world_size, group ): - if world_size == 1: - return input_tensor - all_lens = forward_batch.global_num_tokens_cpu max_len = max(forward_batch.global_num_tokens_cpu) + if world_size == 1: + return input_tensor, 0, all_lens[0] + padded_tensor = torch.nn.functional.pad( input_tensor, (0, 0, 0, max_len - input_tensor.shape[0]) )