[Feat][Graph] Support DeepSeek with ACL Graph (#2707)

### What this PR does / why we need it?
In memory of #677, a long-overdue milestone. DeepSeek V3/R1 should now
work with ACL Graph.

### Does this PR introduce _any_ user-facing change?
None.

### How was this patch tested?
Working on it.

- vLLM version: v0.10.2
- vLLM main:
68dbde5dbb

---------

Signed-off-by: Yizhou Liu <liu_yizhou@outlook.com>
This commit is contained in:
yiz-liu
2025-09-16 17:50:17 +08:00
committed by GitHub
parent 3e60aa5483
commit 88ca8a051c
7 changed files with 64 additions and 42 deletions

View File

@@ -41,9 +41,10 @@ def test_row_parallel_linear(cls, mock_distributed):
assert output[0].shape == (2, 4, 64)
@patch("torch.ops.vllm.mla_forward")
@patch("torch_npu.npu_rms_norm")
def test_custom_deepseek_v2_mla_attention(mock_rms_norm, mock_distributed,
base_config):
def test_custom_deepseek_v2_mla_attention(mock_rms_norm, mock_mla_forward,
mock_distributed, base_config):
mock_rms_norm.return_value = (torch.randn(2, 128), torch.randn(2, 128))
attn = CustomDeepseekV2MLAAttention(config=base_config,
@@ -64,8 +65,8 @@ def test_custom_deepseek_v2_mla_attention(mock_rms_norm, mock_distributed,
with patch.object(attn.mla_attn,
"__call__",
return_value=torch.randn(2, 4, 128)):
with pytest.raises(AssertionError):
attn(positions, x)
attn(positions, x)
mock_mla_forward.assert_called_once()
attn = CustomDeepseekV2MLAAttention(config=base_config,
hidden_size=128,

View File

@@ -215,21 +215,6 @@ class TestAscendConfig(TestBase):
test_vllm_config.model_config = fake_model_config
init_ascend_config(test_vllm_config)
check_ascend_config(test_vllm_config, False)
# aclgraph + deepseek model
with self.assertRaises(NotImplementedError):
test_vllm_config.additional_config = {
"torchair_graph_config": {
"enabled": False,
},
"refresh": True
}
model_path = os.path.join(os.path.dirname(__file__), "fake_weight")
fake_model_config = ModelConfig(model=model_path)
fake_model_config.hf_config = PretrainedConfig()
fake_model_config.hf_config.model_type = "deepseek"
test_vllm_config.model_config = fake_model_config
init_ascend_config(test_vllm_config)
check_ascend_config(test_vllm_config, False)
def test_check_torchair_supported(self):
test_cases = [('deepseek_v3', True), ('PanguProMoE', True),

View File

@@ -218,14 +218,8 @@ def check_ascend_config(vllm_config, enforce_eager):
"it has been disabled automatically.")
# aclgraph case
else:
# aclgraph doesn't work with deepseek model and only qwen model is well tested.
if vllm_config.model_config:
model_type = vllm_config.model_config.hf_config.model_type
if "deepseek" in model_type:
raise NotImplementedError(
"ACL Graph does not support deepseek. Please "
"try torchair graph mode to serve deepseek models on vllm-ascend."
" Or set `enforce_eager=True` to use eager mode.")
if "qwen" not in model_type:
logger.warning(
"ACL Graph is currently experimental. Please "

View File

@@ -50,7 +50,7 @@ from vllm.model_executor.model_loader.weight_utils import (
default_weight_loader, maybe_remap_kv_scale_name)
from vllm.model_executor.models.deepseek_v2 import \
yarn_get_mscale # noqa: E501
from vllm.model_executor.models.deepseek_v2 import ( # noqa: E501
from vllm.model_executor.models.deepseek_v2 import (
DeepseekV2Attention, DeepseekV2DecoderLayer, DeepseekV2ForCausalLM,
DeepseekV2MLAAttention, DeepseekV2MLP, DeepseekV2Model, DeepseekV2MoE,
get_spec_layer_idx_from_weight_name)

View File

@@ -25,10 +25,11 @@ from typing import Optional
import torch
from torch import nn
from vllm.attention import Attention, AttentionMetadata
from vllm.config import CacheConfig
from vllm.forward_context import get_forward_context
from vllm.config import CacheConfig, get_current_vllm_config
from vllm.forward_context import ForwardContext, get_forward_context
from vllm.model_executor.layers.mla import MultiHeadLatentAttention
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.utils import direct_register_custom_op
@dataclass
@@ -80,6 +81,7 @@ class AscendMultiHeadLatentAttention(MultiHeadLatentAttention):
self.qk_nope_head_dim = qk_nope_head_dim
self.qk_head_dim = qk_head_dim
self.v_head_dim = v_head_dim
self.prefix = prefix
self.mla_attn = Attention(
num_heads=self.num_local_heads,
@@ -107,15 +109,17 @@ class AscendMultiHeadLatentAttention(MultiHeadLatentAttention):
o_proj=mla_modules.o_proj,
)
compilation_config = get_current_vllm_config().compilation_config
if prefix in compilation_config.static_forward_context:
raise ValueError(f"Duplicate layer name: {prefix}")
compilation_config.static_forward_context[prefix] = self
def forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
kv_cache: Optional[torch.Tensor] = None,
attn_metadata: Optional[AttentionMetadata] = None) -> torch.Tensor:
forward_context = get_forward_context()
if kv_cache is None:
kv_cache = self.mla_attn.kv_cache[forward_context.virtual_engine]
num_tokens = hidden_states.shape[0]
need_gather_q_kv = False
if self.enable_shared_expert_dp and self.debug_layer_idx > self.first_k_dense_replace and self.debug_layer_idx < self.layers:
@@ -129,16 +133,47 @@ class AscendMultiHeadLatentAttention(MultiHeadLatentAttention):
if num_tokens % self.tp_size:
rows += 1
output_shape = (rows, hidden_states.shape[1])
# FIXME: This does not seem right, should make sure the buffer is fixed
output = torch.empty(output_shape,
dtype=hidden_states.dtype,
device=hidden_states.device)
if forward_context.attn_metadata:
attn_metadata = forward_context.attn_metadata[
self.mla_attn.layer_name]
else:
attn_metadata = forward_context.attn_metadata
output = self.mla_attn.impl.forward(hidden_states, kv_cache,
attn_metadata, need_gather_q_kv,
output)
torch.ops.vllm.mla_forward(hidden_states, need_gather_q_kv, output,
self.prefix)
output = output.view(-1, output_shape[-1])
return output
def mla_forward(
    hidden_states: torch.Tensor,
    need_gather_q_kv: bool,
    output: torch.Tensor,
    layer_name: str,
) -> None:
    """Invoke the MLA attention implementation of the layer named ``layer_name``.

    The layer object is looked up by name in the current forward context's
    ``no_compile_layers`` mapping, and its KV cache and attention metadata
    are taken from that same context rather than passed in as arguments.

    Args:
        hidden_states: Input activations forwarded to the attention impl.
        need_gather_q_kv: Flag forwarded unchanged to the attention impl.
        output: Buffer handed to the attention impl as its last argument;
            the custom-op registration for this function lists it as the
            mutated argument.
        layer_name: Key of the owning layer in ``no_compile_layers``.

    Returns:
        None. The return value of ``impl.forward`` is discarded.
    """
    ctx: ForwardContext = get_forward_context()
    layer = ctx.no_compile_layers[layer_name]
    mla_attn = layer.mla_attn

    # When metadata is present (truthy) it is indexed by the attention
    # layer's own name; otherwise the falsy value is passed through as-is.
    metadata = ctx.attn_metadata
    if metadata:
        metadata = metadata[mla_attn.layer_name]

    layer_kv_cache = mla_attn.kv_cache[ctx.virtual_engine]
    mla_attn.impl.forward(
        hidden_states,
        layer_kv_cache,
        metadata,
        need_gather_q_kv,
        output,
    )
def mla_forward_fake(
    hidden_states: torch.Tensor,
    need_gather_q_kv: bool,
    output: torch.Tensor,
    layer_name: str,
) -> None:
    """No-op counterpart of ``mla_forward`` with an identical signature.

    It is supplied as the ``fake_impl`` when the ``mla_forward`` custom op
    is registered. All arguments are ignored and ``output`` is left
    untouched.
    """
    return None
# Register `mla_forward` as a custom op named "mla_forward"; the attention
# layer's forward path invokes it as `torch.ops.vllm.mla_forward`.
# `output` is declared as the argument the op mutates, `mla_forward_fake`
# is supplied as the fake implementation, and the op is registered under
# the "PrivateUse1" dispatch key.
direct_register_custom_op(
op_name="mla_forward",
op_func=mla_forward,
mutates_args=["output"],
fake_impl=mla_forward_fake,
dispatch_key="PrivateUse1",
)

View File

@@ -227,8 +227,9 @@ class NPUPlatform(Platform):
"When enabling piecewise aclgraph, please make sure compilation_config.level == CompilationLevel.PIECEWISE and compilation_config.cudagraph_mode == CUDAGraphMode.PIECEWISE"
compilation_config.set_splitting_ops_for_v1()
compilation_config.use_inductor = False
compilation_config.splitting_ops.extend(
["vllm.unified_ascend_attention_with_output"])
compilation_config.splitting_ops.extend([
"vllm.unified_ascend_attention_with_output", "vllm.mla_forward"
])
update_aclgraph_sizes(vllm_config)
else:
logger.info(

View File

@@ -93,6 +93,7 @@ from vllm_ascend.attention.attention_mask import AttentionMaskBuilder
from vllm_ascend.attention.attention_v1 import AscendAttentionState
from vllm_ascend.attention.utils import AscendCommonAttentionMetadata
from vllm_ascend.compilation.acl_graph import ACLGraphWrapper
from vllm_ascend.models.layers.mla import AscendMultiHeadLatentAttention
from vllm_ascend.multistream.ms_split import compute_split_seq_index
from vllm_ascend.platform import NPUPlatform
from vllm_ascend.sample.logits_processor import build_logitsprocs
@@ -412,7 +413,10 @@ class NPUModelRunner(LoRAModelRunnerMixin):
self.is_kv_producer = vllm_config.kv_transfer_config.is_kv_producer
self.is_kv_consumer = vllm_config.kv_transfer_config.is_kv_consumer
self.mc2_tokens_capacity = 512 * self.parallel_config.tensor_parallel_size
# NOTE: Technically, MC2 can have 512 tokens each rank, but this will consume too much memory. The formula is:
# ((maxBs * tokenNeedSizeDispatch * ep_worldsize * localMoeExpertNum) + (maxBs * tokenNeedSizeCombine * (k + sharedExpertNum))) * 2
# so we have to limit the MC2 tokens to save memory, should fix this in the future.
self.mc2_tokens_capacity = 512
self.reserved_mc2_mask = torch.zeros(
self.mc2_tokens_capacity,
dtype=torch.bool,
@@ -2811,6 +2815,8 @@ class NPUModelRunner(LoRAModelRunnerMixin):
# or enable more requests to be processed simultaneously.
self.shared_kv_cache_layers[layer_name] = kv_tgt_layer
continue
if isinstance(attn_module, AscendMultiHeadLatentAttention):
continue
# TODO: Support other attention modules, e.g., cross-attention
# TODO(lucas): move the attention specs into the model layers like