From e72f94e38f4208d36a20df1ff85ad4abf430447d Mon Sep 17 00:00:00 2001 From: sdmyzlp <117554856+sdmyzlp@users.noreply.github.com> Date: Thu, 12 Jun 2025 21:42:09 +0800 Subject: [PATCH] Support multistream of MLA vector operations (#1135) ### What this PR does / why we need it? Move all vector operations to a secondary stream, with the expected overlapping being: ``` | q_rmsnorm | | kv_norm_rope_cache | | q_rope | | matmul W_DQ | matmul W_DKV | index | index | matmul W_UQ | split | matmul W_KV_T | ``` Currently, the `IndexByTensor` operators introduced by the computation of `cos` and `sin` can't be offloaded to the secondary stream due to a known bug in the graph fusion optimization pass. So we instead keep them in the main stream, requiring only that they be computed before `matmul W_UQ` to avoid hindering later overlapping. The problem may be solved by a later optimization (#993), which hoists the computation of `cos` and `sin` up to the first layer. ### Does this PR introduce _any_ user-facing change? Controlled by `torchair_graph_config.enable_multistream_mla`, which defaults to `False`. ### How was this patch tested? Tested on a 1x16 910 node with a tailored 2-layer DSKv2. 
Signed-off-by: sdmyzlp --- docs/source/user_guide/additional_config.md | 1 + tests/singlecard/test_ascend_config.py | 2 + vllm_ascend/ascend_config.py | 2 + vllm_ascend/attention/mla_v1.py | 56 ++++++++++++++------- vllm_ascend/models/deepseek_v2.py | 14 +++++- 5 files changed, 56 insertions(+), 19 deletions(-) diff --git a/docs/source/user_guide/additional_config.md b/docs/source/user_guide/additional_config.md index 90002db..4608326 100644 --- a/docs/source/user_guide/additional_config.md +++ b/docs/source/user_guide/additional_config.md @@ -39,6 +39,7 @@ The details of each config option are as follows: | Name | Type | Default | Description | | ---- | ---- | ------- | ----------- | | `enabled` | bool | `False` | Whether to enable torchair graph mode | +| `enable_multistream_mla`| bool | `False` | Whether to put vector ops of MLA to another stream | | `enable_multistream_moe`| bool | `False` | Whether to enable multistream shared expert | | `enable_view_optimize` | bool | `True` | Whether to enable torchair view optimization | | `use_cached_graph` | bool | `False` | Whether to use cached graph | diff --git a/tests/singlecard/test_ascend_config.py b/tests/singlecard/test_ascend_config.py index 818745f..63484d4 100644 --- a/tests/singlecard/test_ascend_config.py +++ b/tests/singlecard/test_ascend_config.py @@ -59,6 +59,7 @@ def test_run_with_ascend_config(): "graph_batch_sizes": [1, 2, 4, 8], "graph_batch_sizes_init": False, "enable_multistream_moe": True, + "enable_multistream_mla": True, }, "ascend_scheduler_config": { "enabled": True, @@ -79,6 +80,7 @@ def test_run_with_ascend_config(): 1, 2, 4, 8 ] assert not ascend_config.torchair_graph_config.graph_batch_sizes_init + assert ascend_config.torchair_graph_config.enable_multistream_mla assert ascend_config.torchair_graph_config.enable_multistream_moe assert ascend_config.ascend_scheduler_config.enabled assert ascend_config.ascend_scheduler_config.enable_chunked_prefill diff --git a/vllm_ascend/ascend_config.py 
b/vllm_ascend/ascend_config.py index f7441f5..2d34283 100644 --- a/vllm_ascend/ascend_config.py +++ b/vllm_ascend/ascend_config.py @@ -54,6 +54,8 @@ class TorchairGraphConfig: "graph_batch_sizes", []) self.graph_batch_sizes_init = torchair_graph_config.get( "graph_batch_sizes_init", False) + self.enable_multistream_mla = torchair_graph_config.get( + "enable_multistream_mla", False) self.enable_multistream_moe = torchair_graph_config.get( "enable_multistream_moe", False) self.enable_view_optimize = torchair_graph_config.get( diff --git a/vllm_ascend/attention/mla_v1.py b/vllm_ascend/attention/mla_v1.py index 9e10815..e07d59a 100644 --- a/vllm_ascend/attention/mla_v1.py +++ b/vllm_ascend/attention/mla_v1.py @@ -19,6 +19,7 @@ from vllm_ascend.multistream.base import MSAttentionMetadataSplitConfig from vllm_ascend.multistream.context import get_multistream_comm_context from vllm_ascend.multistream.ms_split import model_input_split_v1_mla_attn from vllm_ascend.ops.attention import vanilla_chunked_prefill_mla +from vllm_ascend.utils import npu_stream_switch, npu_wait_tensor if TYPE_CHECKING: from vllm.v1.core.sched.output import SchedulerOutput @@ -481,6 +482,9 @@ class AscendMLAImpl(MLAAttentionImpl): ascend_config = get_ascend_config() self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled self.enable_kv_nz = ascend_config.torchair_graph_config.enable_kv_nz + self.enable_multistream_mla = \ + ascend_config.torchair_graph_config.enable_multistream_mla + # Adapt torch air graph mode with spec decoding. 
speculative_config = get_current_vllm_config().speculative_config if speculative_config is not None: @@ -664,17 +668,20 @@ class AscendMLAImpl(MLAAttentionImpl): # npu_kv_rmsnorm_rope_cache needs [B, N, S, D] kv = kv.view(B, N, S, self.kv_lora_rank + self.qk_rope_head_dim) cache_mode = "PA_NZ" if self.enable_kv_nz else "PA" - k_pe, k_nope, _, _ = torch_npu.npu_kv_rmsnorm_rope_cache( - kv, - self.kv_a_layernorm.weight, - cos, - sin, - slots.to(torch.int64), - kv_cache[1], - kv_cache[0], - epsilon=self.kv_a_layernorm.variance_epsilon, - cache_mode=cache_mode, - ) + with npu_stream_switch("mla_secondary", + 0, + enabled=self.enable_multistream_mla): + k_pe, k_nope, _, _ = torch_npu.npu_kv_rmsnorm_rope_cache( + kv, + self.kv_a_layernorm.weight, + cos, + sin, + slots.to(torch.int64), + kv_cache[1], + kv_cache[0], + epsilon=self.kv_a_layernorm.variance_epsilon, + cache_mode=cache_mode, + ) return k_pe, k_nope def exec_kv_prefill( @@ -867,23 +874,38 @@ class AscendMLAImpl(MLAAttentionImpl): if has_decode: decode_k_nope = None assert attn_metadata.decode is not None - decode_ql_nope, decode_q_pe = \ - self._q_proj_and_k_up_proj(decode_hs_or_q_c) if self.running_in_graph: seq_len = self.rotary_emb.max_position_embeddings cos = self.rotary_emb.cos_cached[:seq_len].to( - dtype=decode_q_pe.dtype) + dtype=decode_hs_or_q_c.dtype) sin = self.rotary_emb.sin_cached[:seq_len].to( - dtype=decode_q_pe.dtype) + dtype=decode_hs_or_q_c.dtype) cos = cos[attn_metadata.decode.input_positions] sin = sin[attn_metadata.decode.input_positions] cos = cos[:, None, None, :] sin = sin[:, None, None, :] - - decode_q_pe = self.rope_single(decode_q_pe, cos, sin) + # Without explicitly controlling the order, IndexByTensor operations + # would be placed after `matmul W_KV_T` hindering the overlapping of + # KvRmsNormRopeCache and SingleRope. 
+ npu_wait_tensor(decode_hs_or_q_c, + cos, + enabled=self.enable_multistream_mla) + npu_wait_tensor(decode_hs_or_q_c, + sin, + enabled=self.enable_multistream_mla) + decode_ql_nope, decode_q_pe = \ + self._q_proj_and_k_up_proj(decode_hs_or_q_c) + if self.running_in_graph: decode_k_pe, decode_k_nope = self.exec_kv( hidden_states_or_kv_c_normed, cos, sin, kv_cache, attn_metadata.slot_mapping) + with npu_stream_switch("mla_secondary", + 0, + enabled=self.enable_multistream_mla): + npu_wait_tensor(decode_q_pe, + decode_k_pe, + enabled=self.enable_multistream_mla) + decode_q_pe = self.rope_single(decode_q_pe, cos, sin) else: decode_q_pe[...], decode_k_pe[...] = self.rotary_emb( attn_metadata.decode.input_positions, diff --git a/vllm_ascend/models/deepseek_v2.py b/vllm_ascend/models/deepseek_v2.py index a83ca47..0ae1142 100644 --- a/vllm_ascend/models/deepseek_v2.py +++ b/vllm_ascend/models/deepseek_v2.py @@ -71,7 +71,8 @@ from vllm_ascend.distributed.parallel_state import get_ep_group from vllm_ascend.ops.fused_moe import AscendFusedMoE from vllm_ascend.quantization.quant_config import AscendLinearMethod from vllm_ascend.quantization.w8a8_dynamic import AscendW8A8DynamicLinearMethod -from vllm_ascend.utils import dispose_tensor +from vllm_ascend.utils import (dispose_tensor, npu_stream_switch, + npu_wait_tensor) VLLM_ENABLE_MC2: bool = envs_ascend.VLLM_ENABLE_MC2 @@ -496,6 +497,8 @@ class CustomDeepseekV2MLAAttention(DeepseekV2MLAAttention): ascend_config = get_ascend_config() self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled + self.enable_multistream_mla = \ + ascend_config.torchair_graph_config.enable_multistream_mla def forward( self, @@ -505,7 +508,14 @@ class CustomDeepseekV2MLAAttention(DeepseekV2MLAAttention): attn_metadata: Optional[AttentionMetadata] = None) -> torch.Tensor: if self.q_lora_rank is not None: ckq = self.q_a_proj(hidden_states)[0] - hidden_states_or_q_c = self.q_a_layernorm(ckq) + use_multistream_mla = 
(self.enable_multistream_mla + and attn_metadata is not None + and attn_metadata.num_decodes > 0) + npu_wait_tensor(hidden_states, ckq, enabled=use_multistream_mla) + with npu_stream_switch("mla_secondary", + 0, + enabled=use_multistream_mla): + hidden_states_or_q_c = self.q_a_layernorm(ckq) else: hidden_states_or_q_c = hidden_states if self.torchair_graph_enabled: