Fix cutlass moe accuracy drop caused by attention UB from DP padding mode (#10414)

This commit is contained in:
fzyzcjy
2025-09-14 13:29:09 +08:00
committed by GitHub
parent 05b01ef4da
commit 72dfa96aeb
2 changed files with 9 additions and 2 deletions

View File

@@ -51,7 +51,12 @@ class DpPaddingMode(IntEnum):
         return self == DpPaddingMode.SUM_LEN
 
     @classmethod
-    def get_dp_padding_mode(cls, global_num_tokens: List[int]) -> DpPaddingMode:
+    def get_dp_padding_mode(
+        cls, is_extend_in_batch, global_num_tokens: List[int]
+    ) -> DpPaddingMode:
+        if is_extend_in_batch:
+            return DpPaddingMode.SUM_LEN
+
         # we choose the mode that minimizes the communication cost
         max_len = max(global_num_tokens)
         sum_len = sum(global_num_tokens)

View File

@@ -686,7 +686,9 @@ class ForwardBatch:
                     (global_num_tokens[i] - 1) // attn_tp_size + 1
                 ) * attn_tp_size
 
-        dp_padding_mode = DpPaddingMode.get_dp_padding_mode(global_num_tokens)
+        dp_padding_mode = DpPaddingMode.get_dp_padding_mode(
+            self.is_extend_in_batch, global_num_tokens
+        )
         self.dp_padding_mode = dp_padding_mode
 
         if dp_padding_mode.is_max_len():