Fix cutlass moe accuracy drop caused by attention UB from DP padding mode (#10414)
This commit is contained in:
@@ -51,7 +51,12 @@ class DpPaddingMode(IntEnum):
|
||||
return self == DpPaddingMode.SUM_LEN
|
||||
|
||||
@classmethod
|
||||
def get_dp_padding_mode(cls, global_num_tokens: List[int]) -> DpPaddingMode:
|
||||
def get_dp_padding_mode(
|
||||
cls, is_extend_in_batch, global_num_tokens: List[int]
|
||||
) -> DpPaddingMode:
|
||||
if is_extend_in_batch:
|
||||
return DpPaddingMode.SUM_LEN
|
||||
|
||||
# we choose the mode that minimizes the communication cost
|
||||
max_len = max(global_num_tokens)
|
||||
sum_len = sum(global_num_tokens)
|
||||
|
||||
@@ -686,7 +686,9 @@ class ForwardBatch:
|
||||
(global_num_tokens[i] - 1) // attn_tp_size + 1
|
||||
) * attn_tp_size
|
||||
|
||||
dp_padding_mode = DpPaddingMode.get_dp_padding_mode(global_num_tokens)
|
||||
dp_padding_mode = DpPaddingMode.get_dp_padding_mode(
|
||||
self.is_extend_in_batch, global_num_tokens
|
||||
)
|
||||
self.dp_padding_mode = dp_padding_mode
|
||||
|
||||
if dp_padding_mode.is_max_len():
|
||||
|
||||
Reference in New Issue
Block a user