[bugfix] fix accuracy prolem for deepseek V3/R1 models with torchair graph in long sequence predictions (#1331)
### What this PR does / why we need it? Fix the issue of insufficient cached cosine and sine length in MLA's TorchAir graph mode, which causes accuracy deviation during long-sequence inference. ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? We tested the accuracy of this patch with DeepSeek R1 e2e becnhmark serving, and get 83.33 sore for AIME2024 dataset with DP4TP4EP16 setting. Signed-off-by: linfeng-yuan <1102311262@qq.com>
This commit is contained in:
@@ -1077,7 +1077,7 @@ class AscendMLAImpl(MLAAttentionImpl):
|
|||||||
decode_k_nope = None
|
decode_k_nope = None
|
||||||
assert attn_metadata.decode is not None
|
assert attn_metadata.decode is not None
|
||||||
if self.running_in_graph:
|
if self.running_in_graph:
|
||||||
seq_len = self.rotary_emb.max_position_embeddings
|
seq_len = self.rotary_emb.max_position_embeddings * self.rotary_emb.scaling_factor
|
||||||
cos = self.rotary_emb.cos_cached[:seq_len].to(
|
cos = self.rotary_emb.cos_cached[:seq_len].to(
|
||||||
dtype=decode_hs_or_q_c.dtype)
|
dtype=decode_hs_or_q_c.dtype)
|
||||||
sin = self.rotary_emb.sin_cached[:seq_len].to(
|
sin = self.rotary_emb.sin_cached[:seq_len].to(
|
||||||
@@ -1122,7 +1122,7 @@ class AscendMLAImpl(MLAAttentionImpl):
|
|||||||
prefill_q_nope = prefill_q[..., :self.qk_nope_head_dim]
|
prefill_q_nope = prefill_q[..., :self.qk_nope_head_dim]
|
||||||
if self.torchair_graph_enabled:
|
if self.torchair_graph_enabled:
|
||||||
num_tokens = prefill_hs_or_q_c.shape[0]
|
num_tokens = prefill_hs_or_q_c.shape[0]
|
||||||
seq_len = self.rotary_emb.max_position_embeddings
|
seq_len = self.rotary_emb.max_position_embeddings * self.rotary_emb.scaling_factor
|
||||||
cos = self.rotary_emb.cos_cached[:seq_len].to(
|
cos = self.rotary_emb.cos_cached[:seq_len].to(
|
||||||
dtype=prefill_q_pe.dtype)
|
dtype=prefill_q_pe.dtype)
|
||||||
sin = self.rotary_emb.sin_cached[:seq_len].to(
|
sin = self.rotary_emb.sin_cached[:seq_len].to(
|
||||||
|
|||||||
Reference in New Issue
Block a user