[Feat][Graph] Support DeepSeek with ACL Graph (#2707)

### What this PR does / why we need it?
In memory of #677, a long overdue milestone. DeepSeek V3/R1 should now work
with ACL Graph.
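
Below is a minimal offline-inference sketch of what this enables. It is not part of the PR; the model id, parallel size, and `additional_config` values are illustrative, and it assumes the usual vllm-ascend setup where ACL Graph is the default graph mode once torchair graph is disabled and eager mode is not forced.

```python
# Hedged sketch: serving DeepSeek with ACL Graph on vllm-ascend.
# Model name and parallel size are placeholders; adjust to your NPU cluster.
from vllm import LLM, SamplingParams

llm = LLM(
    model="deepseek-ai/DeepSeek-V3",   # or a DeepSeek-R1 checkpoint
    tensor_parallel_size=16,           # illustrative
    enforce_eager=False,               # keep graph mode on
    additional_config={"torchair_graph_config": {"enabled": False}},  # stay on ACL Graph
)
outputs = llm.generate(["Hello"], SamplingParams(max_tokens=32))
print(outputs[0].outputs[0].text)
```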

### Does this PR introduce _any_ user-facing change?
None.

### How was this patch tested?
Working on it.

- vLLM version: v0.10.2
- vLLM main: 68dbde5dbb

---------

Signed-off-by: Yizhou Liu <liu_yizhou@outlook.com>
Author: yiz-liu
Date: 2025-09-16 17:50:17 +08:00
Committed by: GitHub
Parent: 3e60aa5483
Commit: 88ca8a051c

7 changed files with 64 additions and 42 deletions

View File

@@ -41,9 +41,10 @@ def test_row_parallel_linear(cls, mock_distributed):
     assert output[0].shape == (2, 4, 64)


+@patch("torch.ops.vllm.mla_forward")
 @patch("torch_npu.npu_rms_norm")
-def test_custom_deepseek_v2_mla_attention(mock_rms_norm, mock_distributed,
-                                          base_config):
+def test_custom_deepseek_v2_mla_attention(mock_rms_norm, mock_mla_forward,
+                                          mock_distributed, base_config):
     mock_rms_norm.return_value = (torch.randn(2, 128), torch.randn(2, 128))

     attn = CustomDeepseekV2MLAAttention(config=base_config,
@@ -64,8 +65,8 @@ def test_custom_deepseek_v2_mla_attention(mock_rms_norm, mock_distributed,
     with patch.object(attn.mla_attn,
                       "__call__",
                       return_value=torch.randn(2, 4, 128)):
-        with pytest.raises(AssertionError):
-            attn(positions, x)
+        attn(positions, x)
+        mock_mla_forward.assert_called_once()

     attn = CustomDeepseekV2MLAAttention(config=base_config,
                                         hidden_size=128,

View File

@@ -215,21 +215,6 @@ class TestAscendConfig(TestBase):
         test_vllm_config.model_config = fake_model_config
         init_ascend_config(test_vllm_config)
         check_ascend_config(test_vllm_config, False)

-        # aclgraph + deepseek model
-        with self.assertRaises(NotImplementedError):
-            test_vllm_config.additional_config = {
-                "torchair_graph_config": {
-                    "enabled": False,
-                },
-                "refresh": True
-            }
-            model_path = os.path.join(os.path.dirname(__file__), "fake_weight")
-            fake_model_config = ModelConfig(model=model_path)
-            fake_model_config.hf_config = PretrainedConfig()
-            fake_model_config.hf_config.model_type = "deepseek"
-            test_vllm_config.model_config = fake_model_config
-            init_ascend_config(test_vllm_config)
-            check_ascend_config(test_vllm_config, False)

     def test_check_torchair_supported(self):
         test_cases = [('deepseek_v3', True), ('PanguProMoE', True),

View File

@@ -218,14 +218,8 @@ def check_ascend_config(vllm_config, enforce_eager):
                         "it has been disabled automatically.")
             # aclgraph case
             else:
-                # aclgraph doesn't work with deepseek model and only qwen model is well tested.
                 if vllm_config.model_config:
                     model_type = vllm_config.model_config.hf_config.model_type
-                    if "deepseek" in model_type:
-                        raise NotImplementedError(
-                            "ACL Graph does not support deepseek. Please "
-                            "try torchair graph mode to serve deepseek models on vllm-ascend."
-                            " Or set `enforce_eager=True` to use eager mode.")
                     if "qwen" not in model_type:
                         logger.warning(
                             "ACL Graph is currently experimental. Please "

View File

@@ -50,7 +50,7 @@ from vllm.model_executor.model_loader.weight_utils import (
     default_weight_loader, maybe_remap_kv_scale_name)
 from vllm.model_executor.models.deepseek_v2 import \
     yarn_get_mscale  # noqa: E501
-from vllm.model_executor.models.deepseek_v2 import (  # noqa: E501
+from vllm.model_executor.models.deepseek_v2 import (
     DeepseekV2Attention, DeepseekV2DecoderLayer, DeepseekV2ForCausalLM,
     DeepseekV2MLAAttention, DeepseekV2MLP, DeepseekV2Model, DeepseekV2MoE,
     get_spec_layer_idx_from_weight_name)

View File

@@ -25,10 +25,11 @@ from typing import Optional

 import torch
 from torch import nn
 from vllm.attention import Attention, AttentionMetadata
-from vllm.config import CacheConfig
-from vllm.forward_context import get_forward_context
+from vllm.config import CacheConfig, get_current_vllm_config
+from vllm.forward_context import ForwardContext, get_forward_context
 from vllm.model_executor.layers.mla import MultiHeadLatentAttention
 from vllm.model_executor.layers.quantization import QuantizationConfig
+from vllm.utils import direct_register_custom_op


 @dataclass
@@ -80,6 +81,7 @@ class AscendMultiHeadLatentAttention(MultiHeadLatentAttention):
         self.qk_nope_head_dim = qk_nope_head_dim
         self.qk_head_dim = qk_head_dim
         self.v_head_dim = v_head_dim
+        self.prefix = prefix

         self.mla_attn = Attention(
             num_heads=self.num_local_heads,
@@ -107,15 +109,17 @@ class AscendMultiHeadLatentAttention(MultiHeadLatentAttention):
             o_proj=mla_modules.o_proj,
         )

+        compilation_config = get_current_vllm_config().compilation_config
+        if prefix in compilation_config.static_forward_context:
+            raise ValueError(f"Duplicate layer name: {prefix}")
+        compilation_config.static_forward_context[prefix] = self
+
     def forward(
             self,
             positions: torch.Tensor,
             hidden_states: torch.Tensor,
             kv_cache: Optional[torch.Tensor] = None,
             attn_metadata: Optional[AttentionMetadata] = None) -> torch.Tensor:
-        forward_context = get_forward_context()
-        if kv_cache is None:
-            kv_cache = self.mla_attn.kv_cache[forward_context.virtual_engine]
         num_tokens = hidden_states.shape[0]
         need_gather_q_kv = False
         if self.enable_shared_expert_dp and self.debug_layer_idx > self.first_k_dense_replace and self.debug_layer_idx < self.layers:
@@ -129,16 +133,47 @@ class AscendMultiHeadLatentAttention(MultiHeadLatentAttention):
             if num_tokens % self.tp_size:
                 rows += 1
             output_shape = (rows, hidden_states.shape[1])
+        # FIXME: This does not seem right, should make sure the buffer is fixed
         output = torch.empty(output_shape,
                              dtype=hidden_states.dtype,
                              device=hidden_states.device)
-        if forward_context.attn_metadata:
-            attn_metadata = forward_context.attn_metadata[
-                self.mla_attn.layer_name]
-        else:
-            attn_metadata = forward_context.attn_metadata
-        output = self.mla_attn.impl.forward(hidden_states, kv_cache,
-                                            attn_metadata, need_gather_q_kv,
-                                            output)
+        torch.ops.vllm.mla_forward(hidden_states, need_gather_q_kv, output,
+                                   self.prefix)
         output = output.view(-1, output_shape[-1])
         return output
+
+
+def mla_forward(
+    hidden_states: torch.Tensor,
+    need_gather_q_kv: bool,
+    output: torch.Tensor,
+    layer_name: str,
+) -> None:
+    forward_context: ForwardContext = get_forward_context()
+    self = forward_context.no_compile_layers[layer_name]
+    if forward_context.attn_metadata:
+        attn_metadata = forward_context.attn_metadata[self.mla_attn.layer_name]
+    else:
+        attn_metadata = forward_context.attn_metadata
+    kv_cache = self.mla_attn.kv_cache[forward_context.virtual_engine]
+    self.mla_attn.impl.forward(hidden_states, kv_cache, attn_metadata,
+                               need_gather_q_kv, output)
+    return
+
+
+def mla_forward_fake(
+    hidden_states: torch.Tensor,
+    need_gather_q_kv: bool,
+    output: torch.Tensor,
+    layer_name: str,
+) -> None:
+    return
+
+
+direct_register_custom_op(
+    op_name="mla_forward",
+    op_func=mla_forward,
+    mutates_args=["output"],
+    fake_impl=mla_forward_fake,
+    dispatch_key="PrivateUse1",
+)
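
For context on the hunk above: the MLA forward pass is now routed through a registered custom op (`torch.ops.vllm.mla_forward`) that mutates a preallocated output buffer and looks the layer up through the forward context, and `vllm.mla_forward` is added to the compilation split points so piecewise ACL Graph capture treats it as an opaque boundary. The sketch below illustrates the same mutating-op-plus-fake-impl pattern with plain `torch.library`; the op name and math are hypothetical, and the PR itself uses vLLM's `direct_register_custom_op` helper with `dispatch_key="PrivateUse1"` instead.

```python
# Generic, hedged illustration of the pattern (not the PR's code): an op that
# writes into a caller-provided buffer, plus a no-op fake impl so tracing and
# graph capture can treat the call as an opaque node.
import torch

@torch.library.custom_op("demo::attn_forward", mutates_args=("output",))
def attn_forward(hidden_states: torch.Tensor, output: torch.Tensor) -> None:
    # Stand-in for the real attention kernel that fills `output` in place.
    output.copy_(hidden_states * 2.0)

@attn_forward.register_fake
def _(hidden_states: torch.Tensor, output: torch.Tensor) -> None:
    # Fake impl: no computation; only used while tracing/compiling.
    return

x = torch.randn(4, 8)
out = torch.empty_like(x)
torch.ops.demo.attn_forward(x, out)   # callable like any other torch op
print(torch.allclose(out, x * 2.0))   # True
```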

View File

@@ -227,8 +227,9 @@ class NPUPlatform(Platform):
                 "When enabling piecewise aclgraph, please make sure compilation_config.level == CompilationLevel.PIECEWISE and compilation_config.cudagraph_mode == CUDAGraphMode.PIECEWISE"
             compilation_config.set_splitting_ops_for_v1()
             compilation_config.use_inductor = False
-            compilation_config.splitting_ops.extend(
-                ["vllm.unified_ascend_attention_with_output"])
+            compilation_config.splitting_ops.extend([
+                "vllm.unified_ascend_attention_with_output", "vllm.mla_forward"
+            ])
             update_aclgraph_sizes(vllm_config)
         else:
             logger.info(

View File

@@ -93,6 +93,7 @@ from vllm_ascend.attention.attention_mask import AttentionMaskBuilder
 from vllm_ascend.attention.attention_v1 import AscendAttentionState
 from vllm_ascend.attention.utils import AscendCommonAttentionMetadata
 from vllm_ascend.compilation.acl_graph import ACLGraphWrapper
+from vllm_ascend.models.layers.mla import AscendMultiHeadLatentAttention
 from vllm_ascend.multistream.ms_split import compute_split_seq_index
 from vllm_ascend.platform import NPUPlatform
 from vllm_ascend.sample.logits_processor import build_logitsprocs
@@ -412,7 +413,10 @@ class NPUModelRunner(LoRAModelRunnerMixin):
             self.is_kv_producer = vllm_config.kv_transfer_config.is_kv_producer
             self.is_kv_consumer = vllm_config.kv_transfer_config.is_kv_consumer

-        self.mc2_tokens_capacity = 512 * self.parallel_config.tensor_parallel_size
+        # NOTE: Technically, MC2 can have 512 tokens each rank, but this will consume too much memory. The formula is:
+        # ((maxBs * tokenNeedSizeDispatch * ep_worldsize * localMoeExpertNum) + (maxBs * tokenNeedSizeCombine * (k + sharedExpertNum))) * 2
+        # so we have to limit the MC2 tokens to save memory, should fix this in the future.
+        self.mc2_tokens_capacity = 512
         self.reserved_mc2_mask = torch.zeros(
             self.mc2_tokens_capacity,
             dtype=torch.bool,
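
The NOTE in this hunk justifies capping `mc2_tokens_capacity` with a buffer-size formula. A rough, hedged back-of-the-envelope check of the scaling argument follows; every value below is a hypothetical placeholder, and only the linear dependence on the token cap matters.

```python
# Hedged sketch of the NOTE's formula. All parameter values are hypothetical;
# the takeaway is that the MC2 buffer grows linearly with the token capacity,
# so capping it at 512 instead of 512 * tp_size shrinks it by that factor.
def mc2_buffer_units(max_bs, dispatch_size, combine_size, ep_worldsize,
                     local_moe_experts, topk, shared_experts):
    return ((max_bs * dispatch_size * ep_worldsize * local_moe_experts)
            + (max_bs * combine_size * (topk + shared_experts))) * 2

tp_size = 16  # hypothetical
capped = mc2_buffer_units(512, 1, 1, 16, 16, 8, 1)
uncapped = mc2_buffer_units(512 * tp_size, 1, 1, 16, 16, 8, 1)
print(uncapped // capped)  # == tp_size: the old cap reserved tp_size times more
```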
@@ -2811,6 +2815,8 @@ class NPUModelRunner(LoRAModelRunnerMixin):
                 # or enable more requests to be processed simultaneously.
                 self.shared_kv_cache_layers[layer_name] = kv_tgt_layer
                 continue
+            if isinstance(attn_module, AscendMultiHeadLatentAttention):
+                continue

             # TODO: Support other attention modules, e.g., cross-attention
             # TODO(lucas): move the attention specs into the model layers like