Handle with_prefill_across_dp for multistream mla (#1322)
### What this PR does / why we need it?
After #1094, decode may be executed in non-compiled mode even when `torchair_graph_config.enabled` is set. This broke multistream MLA, which assumed that decode always runs in torchair compiled mode whenever `torchair_graph_config.enabled == True`. This PR augments that assumption so multistream MLA is only enabled for pure-decode batches (no prefill across DP ranks).

### Does this PR introduce _any_ user-facing change?
No.

### How was this patch tested?
Tested both offline and via the graph-mode MLA e2e test case.

---------

Signed-off-by: sdmyzlp <lrwei2@petalmail.com>
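As a minimal sketch of the new gating condition, the following self-contained snippet mirrors the predicate introduced in the diff below. The `AttnMetadata` stand-in is hypothetical; only the field names `with_prefill_across_dp` and `num_decodes` are taken from the actual metadata used in the change:

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class AttnMetadata:
    """Hypothetical stand-in for the real attention metadata."""
    with_prefill_across_dp: bool  # some DP rank runs prefill this step
    num_decodes: int              # decode requests in the current batch


def should_enable_multistream_mla(flag: bool,
                                  attn_metadata: Optional[AttnMetadata]) -> bool:
    # Multistream MLA requires the torchair-compiled decode path; a prefill
    # on any DP rank forces non-compiled execution, so it must opt out.
    return (flag
            and attn_metadata is not None
            and not attn_metadata.with_prefill_across_dp
            and attn_metadata.num_decodes > 0)


# Pure-decode step: the optimization stays enabled.
assert should_enable_multistream_mla(True, AttnMetadata(False, 4))
# Prefill somewhere across DP: decode runs non-compiled, so disable it.
assert not should_enable_multistream_mla(True, AttnMetadata(True, 4))
```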
```diff
@@ -555,20 +555,21 @@ class CustomDeepseekV2MLAAttention(DeepseekV2MLAAttention):
             hidden_states: torch.Tensor,
             kv_cache: Optional[torch.Tensor] = None,
             attn_metadata: Optional[AttentionMetadata] = None) -> torch.Tensor:
+        enable_multistream_mla = (self.enable_multistream_mla
+                                  and attn_metadata is not None
+                                  and not attn_metadata.with_prefill_across_dp
+                                  and attn_metadata.num_decodes > 0)
+        forward_kwargs = {"enable_multistream_mla": enable_multistream_mla}
         if self.q_lora_rank is not None:
             ckq = self.q_a_proj(hidden_states)[0]
-            use_multistream_mla = (self.enable_multistream_mla
-                                   and attn_metadata is not None
-                                   and attn_metadata.num_decodes > 0)
-            npu_wait_tensor(hidden_states, ckq, enabled=use_multistream_mla)
+            npu_wait_tensor(hidden_states, ckq, enabled=enable_multistream_mla)
             with npu_stream_switch("mla_secondary",
                                    0,
-                                   enabled=use_multistream_mla):
+                                   enabled=enable_multistream_mla):
                 hidden_states_or_q_c = self.q_a_layernorm(ckq)
         else:
             hidden_states_or_q_c = hidden_states
         if self.torchair_graph_enabled:
-            forward_kwargs = {}
             if envs.VLLM_USE_V1:
                 output_shape = hidden_states.shape
                 output = torch.empty(output_shape,
```
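For context, the two helpers used in the diff come from vllm-ascend's multistream utilities. The sketch below paraphrases the overlap pattern with no-op stand-ins: the bodies are illustrative assumptions, not the real implementations, and only the call signatures mirror the usage above:

```python
import contextlib


@contextlib.contextmanager
def npu_stream_switch(name: str, priority: int, *, enabled: bool = True):
    # Stand-in: the real helper runs ops issued inside the block on a
    # secondary NPU stream when enabled; here it is a no-op.
    yield


def npu_wait_tensor(waiter, dependency, *, enabled: bool = True):
    # Stand-in: the real helper inserts a cross-stream dependency so that
    # work involving `waiter` does not start before `dependency` is ready.
    pass


def project_q(q_a_proj, q_a_layernorm, hidden_states, enable_multistream_mla):
    # Down-project the query on the default stream.
    ckq = q_a_proj(hidden_states)
    # Order subsequent use of hidden_states after ckq is materialized.
    npu_wait_tensor(hidden_states, ckq, enabled=enable_multistream_mla)
    # Run the q layernorm on the "mla_secondary" stream so it can overlap
    # with default-stream work; effectively sequential when the gate is off.
    with npu_stream_switch("mla_secondary", 0, enabled=enable_multistream_mla):
        return q_a_layernorm(ckq)


# Toy usage with identity "layers", just to exercise the control flow.
print(project_q(lambda x: x, lambda x: x, [1.0], enable_multistream_mla=True))
```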