From 88ca8a051ca51fe72516344db092a7852150cfdb Mon Sep 17 00:00:00 2001 From: yiz-liu <136800916+yiz-liu@users.noreply.github.com> Date: Tue, 16 Sep 2025 17:50:17 +0800 Subject: [PATCH] [Feat][Graph] Support DeepSeek with ACL Graph (#2707) ### What this PR does / why we need it? In memory of #677 , a long overdue milestone. Now DeepSeek V3/R1 should be OK with ACL Graph. ### Does this PR introduce _any_ user-facing change? None. ### How was this patch tested? Working on it. - vLLM version: v0.10.2 - vLLM main: https://github.com/vllm-project/vllm/commit/68dbde5dbb11b9250454d0c9f21a8b3da960b341 --------- Signed-off-by: Yizhou Liu --- tests/ut/models/test_deepseek_v2.py | 9 ++-- tests/ut/test_ascend_config.py | 15 ------- vllm_ascend/ascend_config.py | 6 --- vllm_ascend/models/deepseek_v2.py | 2 +- vllm_ascend/models/layers/mla.py | 61 +++++++++++++++++++++------ vllm_ascend/platform.py | 5 ++- vllm_ascend/worker/model_runner_v1.py | 8 +++- 7 files changed, 64 insertions(+), 42 deletions(-) diff --git a/tests/ut/models/test_deepseek_v2.py b/tests/ut/models/test_deepseek_v2.py index 2e3b5f3..693aea5 100644 --- a/tests/ut/models/test_deepseek_v2.py +++ b/tests/ut/models/test_deepseek_v2.py @@ -41,9 +41,10 @@ def test_row_parallel_linear(cls, mock_distributed): assert output[0].shape == (2, 4, 64) +@patch("torch.ops.vllm.mla_forward") @patch("torch_npu.npu_rms_norm") -def test_custom_deepseek_v2_mla_attention(mock_rms_norm, mock_distributed, - base_config): +def test_custom_deepseek_v2_mla_attention(mock_rms_norm, mock_mla_forward, + mock_distributed, base_config): mock_rms_norm.return_value = (torch.randn(2, 128), torch.randn(2, 128)) attn = CustomDeepseekV2MLAAttention(config=base_config, @@ -64,8 +65,8 @@ def test_custom_deepseek_v2_mla_attention(mock_rms_norm, mock_distributed, with patch.object(attn.mla_attn, "__call__", return_value=torch.randn(2, 4, 128)): - with pytest.raises(AssertionError): - attn(positions, x) + attn(positions, x) + 
mock_mla_forward.assert_called_once() attn = CustomDeepseekV2MLAAttention(config=base_config, hidden_size=128, diff --git a/tests/ut/test_ascend_config.py b/tests/ut/test_ascend_config.py index c8013fb..4abec5d 100644 --- a/tests/ut/test_ascend_config.py +++ b/tests/ut/test_ascend_config.py @@ -215,21 +215,6 @@ class TestAscendConfig(TestBase): test_vllm_config.model_config = fake_model_config init_ascend_config(test_vllm_config) check_ascend_config(test_vllm_config, False) - # aclgraph + deepseek model - with self.assertRaises(NotImplementedError): - test_vllm_config.additional_config = { - "torchair_graph_config": { - "enabled": False, - }, - "refresh": True - } - model_path = os.path.join(os.path.dirname(__file__), "fake_weight") - fake_model_config = ModelConfig(model=model_path) - fake_model_config.hf_config = PretrainedConfig() - fake_model_config.hf_config.model_type = "deepseek" - test_vllm_config.model_config = fake_model_config - init_ascend_config(test_vllm_config) - check_ascend_config(test_vllm_config, False) def test_check_torchair_supported(self): test_cases = [('deepseek_v3', True), ('PanguProMoE', True), diff --git a/vllm_ascend/ascend_config.py b/vllm_ascend/ascend_config.py index d053387..6a61cdd 100644 --- a/vllm_ascend/ascend_config.py +++ b/vllm_ascend/ascend_config.py @@ -218,14 +218,8 @@ def check_ascend_config(vllm_config, enforce_eager): "it has been disabled automatically.") # aclgraph case else: - # aclgraph doesn't work with deepseek model and only qwen model is well tested. if vllm_config.model_config: model_type = vllm_config.model_config.hf_config.model_type - if "deepseek" in model_type: - raise NotImplementedError( - "ACL Graph does not support deepseek. Please " - "try torchair graph mode to serve deepseek models on vllm-ascend." - " Or set `enforce_eager=True` to use eager mode.") if "qwen" not in model_type: logger.warning( "ACL Graph is currently experimental. 
Please " diff --git a/vllm_ascend/models/deepseek_v2.py b/vllm_ascend/models/deepseek_v2.py index 502542e..7d78a0b 100644 --- a/vllm_ascend/models/deepseek_v2.py +++ b/vllm_ascend/models/deepseek_v2.py @@ -50,7 +50,7 @@ from vllm.model_executor.model_loader.weight_utils import ( default_weight_loader, maybe_remap_kv_scale_name) from vllm.model_executor.models.deepseek_v2 import \ yarn_get_mscale # noqa: E501 -from vllm.model_executor.models.deepseek_v2 import ( # noqa: E501 +from vllm.model_executor.models.deepseek_v2 import ( DeepseekV2Attention, DeepseekV2DecoderLayer, DeepseekV2ForCausalLM, DeepseekV2MLAAttention, DeepseekV2MLP, DeepseekV2Model, DeepseekV2MoE, get_spec_layer_idx_from_weight_name) diff --git a/vllm_ascend/models/layers/mla.py b/vllm_ascend/models/layers/mla.py index 0f9adf4..fa5317c 100644 --- a/vllm_ascend/models/layers/mla.py +++ b/vllm_ascend/models/layers/mla.py @@ -25,10 +25,11 @@ from typing import Optional import torch from torch import nn from vllm.attention import Attention, AttentionMetadata -from vllm.config import CacheConfig -from vllm.forward_context import get_forward_context +from vllm.config import CacheConfig, get_current_vllm_config +from vllm.forward_context import ForwardContext, get_forward_context from vllm.model_executor.layers.mla import MultiHeadLatentAttention from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.utils import direct_register_custom_op @dataclass @@ -80,6 +81,7 @@ class AscendMultiHeadLatentAttention(MultiHeadLatentAttention): self.qk_nope_head_dim = qk_nope_head_dim self.qk_head_dim = qk_head_dim self.v_head_dim = v_head_dim + self.prefix = prefix self.mla_attn = Attention( num_heads=self.num_local_heads, @@ -107,15 +109,17 @@ class AscendMultiHeadLatentAttention(MultiHeadLatentAttention): o_proj=mla_modules.o_proj, ) + compilation_config = get_current_vllm_config().compilation_config + if prefix in compilation_config.static_forward_context: + raise ValueError(f"Duplicate 
layer name: {prefix}") + compilation_config.static_forward_context[prefix] = self + def forward( self, positions: torch.Tensor, hidden_states: torch.Tensor, kv_cache: Optional[torch.Tensor] = None, attn_metadata: Optional[AttentionMetadata] = None) -> torch.Tensor: - forward_context = get_forward_context() - if kv_cache is None: - kv_cache = self.mla_attn.kv_cache[forward_context.virtual_engine] num_tokens = hidden_states.shape[0] need_gather_q_kv = False if self.enable_shared_expert_dp and self.debug_layer_idx > self.first_k_dense_replace and self.debug_layer_idx < self.layers: @@ -129,16 +133,47 @@ class AscendMultiHeadLatentAttention(MultiHeadLatentAttention): if num_tokens % self.tp_size: rows += 1 output_shape = (rows, hidden_states.shape[1]) + # FIXME: This does not seem right, should make sure the buffer is fixed output = torch.empty(output_shape, dtype=hidden_states.dtype, device=hidden_states.device) - if forward_context.attn_metadata: - attn_metadata = forward_context.attn_metadata[ - self.mla_attn.layer_name] - else: - attn_metadata = forward_context.attn_metadata - output = self.mla_attn.impl.forward(hidden_states, kv_cache, - attn_metadata, need_gather_q_kv, - output) + torch.ops.vllm.mla_forward(hidden_states, need_gather_q_kv, output, + self.prefix) output = output.view(-1, output_shape[-1]) return output + + +def mla_forward( + hidden_states: torch.Tensor, + need_gather_q_kv: bool, + output: torch.Tensor, + layer_name: str, +) -> None: + forward_context: ForwardContext = get_forward_context() + self = forward_context.no_compile_layers[layer_name] + if forward_context.attn_metadata: + attn_metadata = forward_context.attn_metadata[self.mla_attn.layer_name] + else: + attn_metadata = forward_context.attn_metadata + kv_cache = self.mla_attn.kv_cache[forward_context.virtual_engine] + self.mla_attn.impl.forward(hidden_states, kv_cache, attn_metadata, + need_gather_q_kv, output) + return + + +def mla_forward_fake( + hidden_states: torch.Tensor, + 
need_gather_q_kv: bool, + output: torch.Tensor, + layer_name: str, +) -> None: + return + + +direct_register_custom_op( + op_name="mla_forward", + op_func=mla_forward, + mutates_args=["output"], + fake_impl=mla_forward_fake, + dispatch_key="PrivateUse1", +) diff --git a/vllm_ascend/platform.py b/vllm_ascend/platform.py index 130f0f4..8114f4e 100644 --- a/vllm_ascend/platform.py +++ b/vllm_ascend/platform.py @@ -227,8 +227,9 @@ class NPUPlatform(Platform): "When enabling piecewise aclgraph, please make sure compilation_config.level == CompilationLevel.PIECEWISE and compilation_config.cudagraph_mode == CUDAGraphMode.PIECEWISE" compilation_config.set_splitting_ops_for_v1() compilation_config.use_inductor = False - compilation_config.splitting_ops.extend( - ["vllm.unified_ascend_attention_with_output"]) + compilation_config.splitting_ops.extend([ + "vllm.unified_ascend_attention_with_output", "vllm.mla_forward" + ]) update_aclgraph_sizes(vllm_config) else: logger.info( diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py index a409bd3..c267879 100644 --- a/vllm_ascend/worker/model_runner_v1.py +++ b/vllm_ascend/worker/model_runner_v1.py @@ -93,6 +93,7 @@ from vllm_ascend.attention.attention_mask import AttentionMaskBuilder from vllm_ascend.attention.attention_v1 import AscendAttentionState from vllm_ascend.attention.utils import AscendCommonAttentionMetadata from vllm_ascend.compilation.acl_graph import ACLGraphWrapper +from vllm_ascend.models.layers.mla import AscendMultiHeadLatentAttention from vllm_ascend.multistream.ms_split import compute_split_seq_index from vllm_ascend.platform import NPUPlatform from vllm_ascend.sample.logits_processor import build_logitsprocs @@ -412,7 +413,10 @@ class NPUModelRunner(LoRAModelRunnerMixin): self.is_kv_producer = vllm_config.kv_transfer_config.is_kv_producer self.is_kv_consumer = vllm_config.kv_transfer_config.is_kv_consumer - self.mc2_tokens_capacity = 512 * 
self.parallel_config.tensor_parallel_size + # NOTE: Technically, MC2 can have 512 tokens per rank, but this will consume too much memory. The formula is: + # ((maxBs * tokenNeedSizeDispatch * ep_worldsize * localMoeExpertNum) + (maxBs * tokenNeedSizeCombine * (k + sharedExpertNum))) * 2 + # so we have to limit the MC2 tokens to save memory; this should be fixed in the future. + self.mc2_tokens_capacity = 512 self.reserved_mc2_mask = torch.zeros( self.mc2_tokens_capacity, dtype=torch.bool, @@ -2811,6 +2815,8 @@ class NPUModelRunner(LoRAModelRunnerMixin): # or enable more requests to be processed simultaneously. self.shared_kv_cache_layers[layer_name] = kv_tgt_layer continue + if isinstance(attn_module, AscendMultiHeadLatentAttention): + continue # TODO: Support other attention modules, e.g., cross-attention # TODO(lucas): move the attention specs into the model layers like