diff --git a/python/sglang/srt/layers/dp_attention.py b/python/sglang/srt/layers/dp_attention.py
index b7feccdb6..3bae7e0c3 100644
--- a/python/sglang/srt/layers/dp_attention.py
+++ b/python/sglang/srt/layers/dp_attention.py
@@ -51,7 +51,12 @@ class DpPaddingMode(IntEnum):
         return self == DpPaddingMode.SUM_LEN
 
     @classmethod
-    def get_dp_padding_mode(cls, global_num_tokens: List[int]) -> DpPaddingMode:
+    def get_dp_padding_mode(
+        cls, is_extend_in_batch, global_num_tokens: List[int]
+    ) -> DpPaddingMode:
+        if is_extend_in_batch:
+            return DpPaddingMode.SUM_LEN
+
         # we choose the mode that minimizes the communication cost
         max_len = max(global_num_tokens)
         sum_len = sum(global_num_tokens)
diff --git a/python/sglang/srt/model_executor/forward_batch_info.py b/python/sglang/srt/model_executor/forward_batch_info.py
index e343d6b4f..017b5863c 100644
--- a/python/sglang/srt/model_executor/forward_batch_info.py
+++ b/python/sglang/srt/model_executor/forward_batch_info.py
@@ -686,7 +686,9 @@ class ForwardBatch:
                 (global_num_tokens[i] - 1) // attn_tp_size + 1
             ) * attn_tp_size
 
-        dp_padding_mode = DpPaddingMode.get_dp_padding_mode(global_num_tokens)
+        dp_padding_mode = DpPaddingMode.get_dp_padding_mode(
+            self.is_extend_in_batch, global_num_tokens
+        )
         self.dp_padding_mode = dp_padding_mode
 
         if dp_padding_mode.is_max_len():
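
For context, here is a minimal standalone sketch of the decision the patched `get_dp_padding_mode` now makes. The `DpPaddingMode` values, the explicit `dp_size` parameter, and the cost comparison for the non-extend path are illustrative assumptions: the diff truncates the hunk before the original comparison, so only the new `is_extend_in_batch` guard is taken directly from the patch.

```python
from enum import IntEnum
from typing import List


class DpPaddingMode(IntEnum):
    # Stand-in for the two padding strategies in dp_attention.py.
    MAX_LEN = 0  # pad every DP rank up to the global max token count
    SUM_LEN = 1  # concatenate ranks; total size is the sum of token counts


def get_dp_padding_mode(
    is_extend_in_batch: bool, global_num_tokens: List[int], dp_size: int
) -> DpPaddingMode:
    # New behavior from this patch: extend (prefill) batches always use SUM_LEN.
    if is_extend_in_batch:
        return DpPaddingMode.SUM_LEN

    # Pre-existing path: pick the mode that minimizes the communication cost.
    # The real comparison sits below the truncated hunk; this one is hypothetical.
    max_len = max(global_num_tokens)
    sum_len = sum(global_num_tokens)
    return DpPaddingMode.SUM_LEN if sum_len < dp_size * max_len else DpPaddingMode.MAX_LEN


if __name__ == "__main__":
    # Prefill (extend) batch with skewed lengths: forced to SUM_LEN by the new guard.
    print(get_dp_padding_mode(True, [512, 8, 8, 8], dp_size=4))
    # Decode batch with one token per rank: falls through to the cost comparison.
    print(get_dp_padding_mode(False, [1, 1, 1, 1], dp_size=4))
```

The caller-side change in `ForwardBatch` mirrors this: it simply forwards `self.is_extend_in_batch` so that extend batches skip the max-length padding path entirely.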