Handle with_prefill_across_dp for multistream mla (#1322)

### What this PR does / why we need it?
After #1094, decode might be executed with non-compiled mode, despite of
`torchair_graph_config.enabled`, causing multistream mla to fail, which
assumes torchair compiled mode for decode when
`torchair_graph_config.enabled == True`.
Augment that assumption to fix this.

### Does this PR introduce _any_ user-facing change?
No.

### How was this patch tested?
Tested both offline, and by graph mode mla e2e testcase.

---------

Signed-off-by: sdmyzlp <lrwei2@petalmail.com>
This commit is contained in:
sdmyzlp
2025-06-26 09:32:07 +08:00
committed by GitHub
parent 2690697caa
commit 53c2d58ae1
3 changed files with 82 additions and 60 deletions

View File

@@ -555,20 +555,21 @@ class CustomDeepseekV2MLAAttention(DeepseekV2MLAAttention):
hidden_states: torch.Tensor,
kv_cache: Optional[torch.Tensor] = None,
attn_metadata: Optional[AttentionMetadata] = None) -> torch.Tensor:
enable_multistream_mla = (self.enable_multistream_mla
and attn_metadata is not None
and not attn_metadata.with_prefill_across_dp
and attn_metadata.num_decodes > 0)
forward_kwargs = {"enable_multistream_mla": enable_multistream_mla}
if self.q_lora_rank is not None:
ckq = self.q_a_proj(hidden_states)[0]
use_multistream_mla = (self.enable_multistream_mla
and attn_metadata is not None
and attn_metadata.num_decodes > 0)
npu_wait_tensor(hidden_states, ckq, enabled=use_multistream_mla)
npu_wait_tensor(hidden_states, ckq, enabled=enable_multistream_mla)
with npu_stream_switch("mla_secondary",
0,
enabled=use_multistream_mla):
enabled=enable_multistream_mla):
hidden_states_or_q_c = self.q_a_layernorm(ckq)
else:
hidden_states_or_q_c = hidden_states
if self.torchair_graph_enabled:
forward_kwargs = {}
if envs.VLLM_USE_V1:
output_shape = hidden_states.shape
output = torch.empty(output_shape,