Support multistream of MLA vector operations (#1135)

### What this PR does / why we need it?
Move all vector operations to a secondary stream, with the expected
overlaping being:
```
              | q_rmsnorm |                  | kv_norm_rope_cache |       | q_rope |
| matmul W_DQ | matmul W_DKV | index | index |    matmul W_UQ     | split | matmul W_KV_T |
```

Currently, the `IndexByTensor` operators introduced by computation of
`cos` and `sin` can't be offloaded to the secondary stream due to a
known bug of graph fusion optimization pass. So we instead keep it in
the main stream, only requires it be computed before `matmul W_UQ` to
avoid hindering later overlapping. The problem may be solved by later
optimization (#993), which hoists the computation of `cos` and `sin` up
to the first layer.

### Does this PR introduce _any_ user-facing change?
Controlled by `torchair_graph_config.enable_multistream_mla`, defaulted
to False.

### How was this patch tested?
Tested on 1x16 910 node, with tailored 2 layer DSKv2.

Signed-off-by: sdmyzlp <lrwei2@petalmail.com>
This commit is contained in:
sdmyzlp
2025-06-12 21:42:09 +08:00
committed by GitHub
parent 55c0e68883
commit e72f94e38f
5 changed files with 56 additions and 19 deletions

View File

@@ -39,6 +39,7 @@ The details of each config option are as follows:
| Name | Type | Default | Description |
| ---- | ---- | ------- | ----------- |
| `enabled` | bool | `False` | Whether to enable torchair graph mode |
| `enable_multistream_mla`| bool | `False` | Whether to put vector ops of MLA to another stream |
| `enable_multistream_moe`| bool | `False` | Whether to enable multistream shared expert |
| `enable_view_optimize` | bool | `True` | Whether to enable torchair view optimization |
| `use_cached_graph` | bool | `False` | Whether to use cached graph |

View File

@@ -59,6 +59,7 @@ def test_run_with_ascend_config():
"graph_batch_sizes": [1, 2, 4, 8],
"graph_batch_sizes_init": False,
"enable_multistream_moe": True,
"enable_multistream_mla": True,
},
"ascend_scheduler_config": {
"enabled": True,
@@ -79,6 +80,7 @@ def test_run_with_ascend_config():
1, 2, 4, 8
]
assert not ascend_config.torchair_graph_config.graph_batch_sizes_init
assert ascend_config.torchair_graph_config.enable_multistream_mla
assert ascend_config.torchair_graph_config.enable_multistream_moe
assert ascend_config.ascend_scheduler_config.enabled
assert ascend_config.ascend_scheduler_config.enable_chunked_prefill

View File

@@ -54,6 +54,8 @@ class TorchairGraphConfig:
"graph_batch_sizes", [])
self.graph_batch_sizes_init = torchair_graph_config.get(
"graph_batch_sizes_init", False)
self.enable_multistream_mla = torchair_graph_config.get(
"enable_multistream_mla", False)
self.enable_multistream_moe = torchair_graph_config.get(
"enable_multistream_moe", False)
self.enable_view_optimize = torchair_graph_config.get(

View File

@@ -19,6 +19,7 @@ from vllm_ascend.multistream.base import MSAttentionMetadataSplitConfig
from vllm_ascend.multistream.context import get_multistream_comm_context
from vllm_ascend.multistream.ms_split import model_input_split_v1_mla_attn
from vllm_ascend.ops.attention import vanilla_chunked_prefill_mla
from vllm_ascend.utils import npu_stream_switch, npu_wait_tensor
if TYPE_CHECKING:
from vllm.v1.core.sched.output import SchedulerOutput
@@ -481,6 +482,9 @@ class AscendMLAImpl(MLAAttentionImpl):
ascend_config = get_ascend_config()
self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled
self.enable_kv_nz = ascend_config.torchair_graph_config.enable_kv_nz
self.enable_multistream_mla = \
ascend_config.torchair_graph_config.enable_multistream_mla
# Adapt torch air graph mode with spec decoding.
speculative_config = get_current_vllm_config().speculative_config
if speculative_config is not None:
@@ -664,17 +668,20 @@ class AscendMLAImpl(MLAAttentionImpl):
# npu_kv_rmsnorm_rope_cache needs [B, N, S, D]
kv = kv.view(B, N, S, self.kv_lora_rank + self.qk_rope_head_dim)
cache_mode = "PA_NZ" if self.enable_kv_nz else "PA"
k_pe, k_nope, _, _ = torch_npu.npu_kv_rmsnorm_rope_cache(
kv,
self.kv_a_layernorm.weight,
cos,
sin,
slots.to(torch.int64),
kv_cache[1],
kv_cache[0],
epsilon=self.kv_a_layernorm.variance_epsilon,
cache_mode=cache_mode,
)
with npu_stream_switch("mla_secondary",
0,
enabled=self.enable_multistream_mla):
k_pe, k_nope, _, _ = torch_npu.npu_kv_rmsnorm_rope_cache(
kv,
self.kv_a_layernorm.weight,
cos,
sin,
slots.to(torch.int64),
kv_cache[1],
kv_cache[0],
epsilon=self.kv_a_layernorm.variance_epsilon,
cache_mode=cache_mode,
)
return k_pe, k_nope
def exec_kv_prefill(
@@ -867,23 +874,38 @@ class AscendMLAImpl(MLAAttentionImpl):
if has_decode:
decode_k_nope = None
assert attn_metadata.decode is not None
decode_ql_nope, decode_q_pe = \
self._q_proj_and_k_up_proj(decode_hs_or_q_c)
if self.running_in_graph:
seq_len = self.rotary_emb.max_position_embeddings
cos = self.rotary_emb.cos_cached[:seq_len].to(
dtype=decode_q_pe.dtype)
dtype=decode_hs_or_q_c.dtype)
sin = self.rotary_emb.sin_cached[:seq_len].to(
dtype=decode_q_pe.dtype)
dtype=decode_hs_or_q_c.dtype)
cos = cos[attn_metadata.decode.input_positions]
sin = sin[attn_metadata.decode.input_positions]
cos = cos[:, None, None, :]
sin = sin[:, None, None, :]
decode_q_pe = self.rope_single(decode_q_pe, cos, sin)
# Without explicitly controlling the order, IndexByTensor operations
# would be placed after `matmul W_KV_T` hindering the overlapping of
# KvRmsNormRopeCache and SingleRope.
npu_wait_tensor(decode_hs_or_q_c,
cos,
enabled=self.enable_multistream_mla)
npu_wait_tensor(decode_hs_or_q_c,
sin,
enabled=self.enable_multistream_mla)
decode_ql_nope, decode_q_pe = \
self._q_proj_and_k_up_proj(decode_hs_or_q_c)
if self.running_in_graph:
decode_k_pe, decode_k_nope = self.exec_kv(
hidden_states_or_kv_c_normed, cos, sin, kv_cache,
attn_metadata.slot_mapping)
with npu_stream_switch("mla_secondary",
0,
enabled=self.enable_multistream_mla):
npu_wait_tensor(decode_q_pe,
decode_k_pe,
enabled=self.enable_multistream_mla)
decode_q_pe = self.rope_single(decode_q_pe, cos, sin)
else:
decode_q_pe[...], decode_k_pe[...] = self.rotary_emb(
attn_metadata.decode.input_positions,

View File

@@ -71,7 +71,8 @@ from vllm_ascend.distributed.parallel_state import get_ep_group
from vllm_ascend.ops.fused_moe import AscendFusedMoE
from vllm_ascend.quantization.quant_config import AscendLinearMethod
from vllm_ascend.quantization.w8a8_dynamic import AscendW8A8DynamicLinearMethod
from vllm_ascend.utils import dispose_tensor
from vllm_ascend.utils import (dispose_tensor, npu_stream_switch,
npu_wait_tensor)
VLLM_ENABLE_MC2: bool = envs_ascend.VLLM_ENABLE_MC2
@@ -496,6 +497,8 @@ class CustomDeepseekV2MLAAttention(DeepseekV2MLAAttention):
ascend_config = get_ascend_config()
self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled
self.enable_multistream_mla = \
ascend_config.torchair_graph_config.enable_multistream_mla
def forward(
self,
@@ -505,7 +508,14 @@ class CustomDeepseekV2MLAAttention(DeepseekV2MLAAttention):
attn_metadata: Optional[AttentionMetadata] = None) -> torch.Tensor:
if self.q_lora_rank is not None:
ckq = self.q_a_proj(hidden_states)[0]
hidden_states_or_q_c = self.q_a_layernorm(ckq)
use_multistream_mla = (self.enable_multistream_mla
and attn_metadata is not None
and attn_metadata.num_decodes > 0)
npu_wait_tensor(hidden_states, ckq, enabled=use_multistream_mla)
with npu_stream_switch("mla_secondary",
0,
enabled=use_multistream_mla):
hidden_states_or_q_c = self.q_a_layernorm(ckq)
else:
hidden_states_or_q_c = hidden_states
if self.torchair_graph_enabled: