[Feat][Graph] Support DeepSeek with ACL Graph (#2707)

### What this PR does / why we need it?
In memory of #677, a long-overdue milestone. DeepSeek V3/R1 should now
work with ACL Graph.

### Does this PR introduce _any_ user-facing change?
None.

### How was this patch tested?
Working on it.

- vLLM version: v0.10.2
- vLLM main:
68dbde5dbb

---------

Signed-off-by: Yizhou Liu <liu_yizhou@outlook.com>
This commit is contained in:
yiz-liu
2025-09-16 17:50:17 +08:00
committed by GitHub
parent 3e60aa5483
commit 88ca8a051c
7 changed files with 64 additions and 42 deletions

View File

@@ -25,10 +25,11 @@ from typing import Optional
import torch
from torch import nn
from vllm.attention import Attention, AttentionMetadata
from vllm.config import CacheConfig
from vllm.forward_context import get_forward_context
from vllm.config import CacheConfig, get_current_vllm_config
from vllm.forward_context import ForwardContext, get_forward_context
from vllm.model_executor.layers.mla import MultiHeadLatentAttention
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.utils import direct_register_custom_op
@dataclass
@@ -80,6 +81,7 @@ class AscendMultiHeadLatentAttention(MultiHeadLatentAttention):
self.qk_nope_head_dim = qk_nope_head_dim
self.qk_head_dim = qk_head_dim
self.v_head_dim = v_head_dim
self.prefix = prefix
self.mla_attn = Attention(
num_heads=self.num_local_heads,
@@ -107,15 +109,17 @@ class AscendMultiHeadLatentAttention(MultiHeadLatentAttention):
o_proj=mla_modules.o_proj,
)
compilation_config = get_current_vllm_config().compilation_config
if prefix in compilation_config.static_forward_context:
raise ValueError(f"Duplicate layer name: {prefix}")
compilation_config.static_forward_context[prefix] = self
def forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
kv_cache: Optional[torch.Tensor] = None,
attn_metadata: Optional[AttentionMetadata] = None) -> torch.Tensor:
forward_context = get_forward_context()
if kv_cache is None:
kv_cache = self.mla_attn.kv_cache[forward_context.virtual_engine]
num_tokens = hidden_states.shape[0]
need_gather_q_kv = False
if self.enable_shared_expert_dp and self.debug_layer_idx > self.first_k_dense_replace and self.debug_layer_idx < self.layers:
@@ -129,16 +133,47 @@ class AscendMultiHeadLatentAttention(MultiHeadLatentAttention):
if num_tokens % self.tp_size:
rows += 1
output_shape = (rows, hidden_states.shape[1])
# FIXME: This does not seem right, should make sure the buffer is fixed
output = torch.empty(output_shape,
dtype=hidden_states.dtype,
device=hidden_states.device)
if forward_context.attn_metadata:
attn_metadata = forward_context.attn_metadata[
self.mla_attn.layer_name]
else:
attn_metadata = forward_context.attn_metadata
output = self.mla_attn.impl.forward(hidden_states, kv_cache,
attn_metadata, need_gather_q_kv,
output)
torch.ops.vllm.mla_forward(hidden_states, need_gather_q_kv, output,
self.prefix)
output = output.view(-1, output_shape[-1])
return output
def mla_forward(
    hidden_states: torch.Tensor,
    need_gather_q_kv: bool,
    output: torch.Tensor,
    layer_name: str,
) -> None:
    """Run the MLA attention implementation of the layer named ``layer_name``.

    The layer object is looked up in the current forward context's
    ``no_compile_layers`` mapping, so only its name has to be passed in.
    Nothing is returned; ``output`` is handed to the implementation's
    ``forward`` call, which is expected to fill it in place (the op is
    registered with ``output`` listed as a mutated argument).

    Args:
        hidden_states: Input tensor forwarded to the attention implementation.
        need_gather_q_kv: Flag forwarded unchanged to the implementation.
        output: Tensor forwarded to the implementation as its output buffer.
        layer_name: Key of the layer in ``forward_context.no_compile_layers``.
    """
    ctx: ForwardContext = get_forward_context()
    layer = ctx.no_compile_layers[layer_name]

    # A truthy ``attn_metadata`` is indexed by the inner attention layer's
    # name; a falsy value (e.g. None) is passed through untouched.
    metadata = ctx.attn_metadata
    if metadata:
        metadata = metadata[layer.mla_attn.layer_name]

    # Select the KV cache that belongs to the active virtual engine.
    cache = layer.mla_attn.kv_cache[ctx.virtual_engine]
    layer.mla_attn.impl.forward(hidden_states, cache, metadata,
                                need_gather_q_kv, output)
def mla_forward_fake(
    hidden_states: torch.Tensor,
    need_gather_q_kv: bool,
    output: torch.Tensor,
    layer_name: str,
) -> None:
    """Fake counterpart of ``mla_forward``: accept the same arguments, do nothing.

    It is supplied as the ``fake_impl`` when the ``mla_forward`` custom op is
    registered. No argument is read or modified and ``None`` is returned.
    """
    return None
# Register ``mla_forward`` as a custom op so that it can be invoked as
# ``torch.ops.vllm.mla_forward`` from the attention layer's forward pass.
direct_register_custom_op(
    op_name="mla_forward",
    # Real implementation and the no-op stand-in used as the fake impl.
    op_func=mla_forward,
    fake_impl=mla_forward_fake,
    # ``output`` is the only argument declared as mutated by the op.
    mutates_args=["output"],
    dispatch_key="PrivateUse1",
)