Support multistream of MLA vector operations (#1135)

### What this PR does / why we need it?
Move all vector operations to a secondary stream, with the expected
overlapping being (a minimal sketch of the underlying pattern follows the
diagram):
```
              | q_rmsnorm |                  | kv_norm_rope_cache |       | q_rope |
| matmul W_DQ | matmul W_DKV | index | index |    matmul W_UQ     | split | matmul W_KV_T |
```
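
The overlap relies on the two small helpers this diff pulls in from
`vllm_ascend.utils`. Below is a minimal sketch of the pattern for the first
pair of boxes (`matmul W_DQ` overlapped with `q_rmsnorm`); the function
wrapper and parameter names are illustrative, not part of the PR:
```python
import torch

from vllm_ascend.utils import npu_stream_switch, npu_wait_tensor


def lowrank_q_with_offloaded_rmsnorm(q_a_proj, q_a_layernorm,
                                     hidden_states: torch.Tensor,
                                     enabled: bool) -> torch.Tensor:
    # matmul W_DQ stays on the main stream.
    ckq = q_a_proj(hidden_states)[0]
    # Order later consumers of hidden_states after ckq, so neither stream
    # reads data the other has not produced yet.
    npu_wait_tensor(hidden_states, ckq, enabled=enabled)
    # q_rmsnorm is a cheap vector op; under the secondary-stream context it
    # overlaps with matmul W_DKV, which the caller issues next on the main
    # stream.
    with npu_stream_switch("mla_secondary", 0, enabled=enabled):
        return q_a_layernorm(ckq)
```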

Currently, the `IndexByTensor` operators introduced by the computation of
`cos` and `sin` cannot be offloaded to the secondary stream, due to a known
bug in the graph fusion optimization pass. We therefore keep them on the
main stream, only requiring that they be computed before `matmul W_UQ`, so
that they do not hinder the later overlapping. The problem may be resolved
by a follow-up optimization (#993), which hoists the computation of `cos`
and `sin` up to the first layer.
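
A hedged sketch of that ordering constraint, reusing `npu_wait_tensor` (the
tensor and projection names are assumed from the surrounding model code;
this hunk is not part of the excerpt below):
```python
# cos/sin (and the IndexByTensor ops producing them) remain on the main
# stream; we only fence the matmul W_UQ input on them, so work already
# dispatched to the secondary stream can still overlap freely.
npu_wait_tensor(hidden_states_or_q_c, cos, enabled=use_multistream_mla)
npu_wait_tensor(hidden_states_or_q_c, sin, enabled=use_multistream_mla)
q = self.q_b_proj(hidden_states_or_q_c)[0]  # matmul W_UQ
```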

### Does this PR introduce _any_ user-facing change?
Controlled by `torchair_graph_config.enable_multistream_mla`, which
defaults to `False`.
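
For example, a user could switch it on through vLLM's `additional_config`
(the exact nesting is assumed from vllm-ascend's configuration conventions;
the model path is a placeholder):
```python
from vllm import LLM

llm = LLM(
    model="path/to/deepseek-v2",  # placeholder checkpoint
    additional_config={
        "torchair_graph_config": {
            "enabled": True,                 # multistream MLA is a torchair graph feature
            "enable_multistream_mla": True,  # the switch added by this PR
        },
    },
)
```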

### How was this patch tested?
Tested on a 1x16 910 node, with a tailored 2-layer DSKv2.

Signed-off-by: sdmyzlp <lrwei2@petalmail.com>
```diff
@@ -71,7 +71,8 @@ from vllm_ascend.distributed.parallel_state import get_ep_group
 from vllm_ascend.ops.fused_moe import AscendFusedMoE
 from vllm_ascend.quantization.quant_config import AscendLinearMethod
 from vllm_ascend.quantization.w8a8_dynamic import AscendW8A8DynamicLinearMethod
-from vllm_ascend.utils import dispose_tensor
+from vllm_ascend.utils import (dispose_tensor, npu_stream_switch,
+                               npu_wait_tensor)
 
 VLLM_ENABLE_MC2: bool = envs_ascend.VLLM_ENABLE_MC2
@@ -496,6 +497,8 @@ class CustomDeepseekV2MLAAttention(DeepseekV2MLAAttention):
         ascend_config = get_ascend_config()
         self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled
+        self.enable_multistream_mla = \
+            ascend_config.torchair_graph_config.enable_multistream_mla
 
     def forward(
             self,
@@ -505,7 +508,14 @@ class CustomDeepseekV2MLAAttention(DeepseekV2MLAAttention):
             attn_metadata: Optional[AttentionMetadata] = None) -> torch.Tensor:
         if self.q_lora_rank is not None:
             ckq = self.q_a_proj(hidden_states)[0]
-            hidden_states_or_q_c = self.q_a_layernorm(ckq)
+            use_multistream_mla = (self.enable_multistream_mla
+                                   and attn_metadata is not None
+                                   and attn_metadata.num_decodes > 0)
+            npu_wait_tensor(hidden_states, ckq, enabled=use_multistream_mla)
+            with npu_stream_switch("mla_secondary",
+                                   0,
+                                   enabled=use_multistream_mla):
+                hidden_states_or_q_c = self.q_a_layernorm(ckq)
         else:
             hidden_states_or_q_c = hidden_states
         if self.torchair_graph_enabled:
```