support deepseek quant & mix-parallel with graphmode (#585)

### What this PR does / why we need it?
1. Support DeepSeek with W8A8 quantization;
2. Support DeepSeek with mixed parallelism (multi-DP, EP+TP);
3. Support DeepSeek with graph mode.
---------

Signed-off-by: wen-jie666 <wenjie39@huawei.com>
Signed-off-by: Yizhou Liu <liuyizhou5@h-partners.com>
Signed-off-by: libaokui <libaokui@huawei.com>
Signed-off-by: linfeng-yuan <1102311262@qq.com>
Co-authored-by: wen-jie666 <wenjie39@huawei.com>
This commit is contained in:
zzzzwwjj
2025-04-23 16:23:25 +08:00
committed by GitHub
parent e74331a1ed
commit 5c6d05a59e
13 changed files with 520 additions and 221 deletions

View File

@@ -27,6 +27,7 @@ try:
except ImportError:
print("Failed to import torch_npu.")
import torchair._contrib.custom_torch_ops # type: ignore # noqa: F401
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
AttentionLayer,
AttentionMetadata, AttentionType,
@@ -36,9 +37,9 @@ from vllm.attention.backends.utils import (PAD_SLOT_ID, CommonAttentionState,
compute_slot_mapping,
compute_slot_mapping_start_idx,
is_block_tables_empty)
from vllm.config import get_current_vllm_config
from vllm.utils import async_tensor_h2d, make_tensor_with_pad
from vllm_ascend.utils import VLLM_ENABLE_GRAPH_MODE
from vllm_ascend.worker.model_runner import (
ModelInputForNPUBuilder, ModelInputForNPUWithSamplingMetadata)
@@ -913,6 +914,12 @@ class AscendMLAAttentionBackendImpl(MLAAttentionImpl):
self.w_kc = None
self.w_vc = None
self.enable_graph_mode = False
additional_config = get_current_vllm_config().additional_config
if additional_config:
self.enable_graph_mode = additional_config.get(
"enable_graph_mode", False)
def exec_kv(
self,
hidden_states: torch.Tensor,
@@ -1084,7 +1091,7 @@ class AscendMLAAttentionBackendImpl(MLAAttentionImpl):
self.num_heads, -1)
# TODO: Replace the env with more flexible expressions
if VLLM_ENABLE_GRAPH_MODE == '1':
if self.enable_graph_mode:
if len(kv_cache) > 0 and kv_cache[0].numel(
) > 0 and attn_metadata.num_prefills > 0:
slots = attn_metadata.slot_mapping
@@ -1141,7 +1148,7 @@ class AscendMLAAttentionBackendImpl(MLAAttentionImpl):
)
elif attn_metadata.decode_metadata:
assert kv_cache is not None
if VLLM_ENABLE_GRAPH_MODE == '1':
if self.enable_graph_mode:
# TorchAir's shape is [bs, num_heads_per_rank, seq_len, dim]
q_nope = q_nope.view(num_tokens, self.num_heads, 1, -1)
q_pe = q_pe.view(num_tokens, self.num_heads, 1, -1)