[MLA][Graph] Improve assertion on Graph mode with MLA (#933)

### What this PR does / why we need it?
Improve assertion on Graph mode with MLA.

When running deepseek with graph mode, the fused MLA op only support
`numHeads / numKvHeads ∈ {32, 64, 128}`, thus we improve the assertion
info here to avoid users confused with this.

### Does this PR introduce _any_ user-facing change?
Adjusting tp size is required when running deepseek-v3/r1 with graph
mode. deepseek-v2-lite is not supported in graph mode.

### How was this patch tested?
Test locally as the CI machine could not run V3 due to the HBM limits.

---------

Signed-off-by: MengqingCao <cmq0113@163.com>
This commit is contained in:
Mengqing Cao
2025-06-10 22:26:53 +08:00
committed by GitHub
parent 291c216898
commit 8dd686dfa2
4 changed files with 33 additions and 1 deletions

View File

@@ -40,6 +40,8 @@ from vllm_ascend.platform import CUSTOM_OP_ENABLED
from vllm_ascend.worker.model_runner import (
ModelInputForNPUBuilder, ModelInputForNPUWithSamplingMetadata)
_ALLOWED_NUM_QUERIES_PER_KV = [32, 64, 128]
def generate_attn_mask(max_seq_len: int, dtype=torch.float16, mask_value=None):
# Construct lower triangle matrix.
@@ -1005,6 +1007,15 @@ class AscendMLAAttentionBackendImpl(MLAAttentionImpl):
ascend_config = get_ascend_config()
self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled
# TODO: support numHeads / numKvHeads < 16 in MLA kernel
if self.torchair_graph_enabled:
assert self.num_queries_per_kv in _ALLOWED_NUM_QUERIES_PER_KV, \
("The allowed number of queries per kv when enabling both MLA and Graph mode"
" only support {32, 64, 128}, Thus this is not supported for DeepSeek-V2-Lite,"
" as it only has 16 attention heads. And if you're using DeepSeek-V3 or DeepSeek-R1,"
" please make sure after the tensor parallel split, num_heads / num_kv_heads in "
"{32, 64, 128}.")
def exec_kv(
self,
hidden_states: torch.Tensor,

View File

@@ -13,6 +13,7 @@ from vllm.model_executor.layers.linear import (LinearBase,
UnquantizedLinearMethod)
from vllm_ascend.ascend_config import get_ascend_config
from vllm_ascend.attention.attention import _ALLOWED_NUM_QUERIES_PER_KV
from vllm_ascend.attention.attention_v1 import AscendAttentionState
from vllm_ascend.multistream.base import MSAttentionMetadataSplitConfig
from vllm_ascend.multistream.context import get_multistream_comm_context
@@ -475,6 +476,7 @@ class AscendMLAImpl(MLAAttentionImpl):
self.o_proj = kwargs['o_proj']
self.kv_a_proj_with_mqa = kwargs.get('kv_a_proj_with_mqa', None)
self.kv_a_layernorm = kwargs.get('kv_a_layernorm', None)
self.num_queries_per_kv = self.num_heads // self.num_kv_heads
ascend_config = get_ascend_config()
self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled
@@ -484,6 +486,15 @@ class AscendMLAImpl(MLAAttentionImpl):
self.spec_token_num = speculative_config.num_speculative_tokens
assert self.spec_token_num > 0
# TODO: support numHeads / numKvHeads < 16 in MLA kernel
if self.torchair_graph_enabled:
assert self.num_queries_per_kv in _ALLOWED_NUM_QUERIES_PER_KV, \
("The allowed number of queries per kv when enabling both MLA and Graph mode"
" only support {32, 64, 128}, Thus this is not supported for DeepSeek-V2-Lite,"
" as it only has 16 attention heads. And if you're using DeepSeek-V3 or DeepSeek-R1,"
" please make sure after the tensor parallel split, num_heads / num_kv_heads in "
"{32, 64, 128}.")
def _v_up_proj_and_o_proj(self, x):
# Convert from (B, N, L) to (N, B, L)
x = x.view(-1, self.num_heads, self.kv_lora_rank).transpose(0, 1)

View File

@@ -119,7 +119,7 @@ class MultiStepWorker(NPUWorker):
# execute_model_req
assert execute_model_req.last_sampled_token_ids is not None
model_input.last_sampled_token_ids = (
execute_model_req.last_sampled_token_ids.cuda())
execute_model_req.last_sampled_token_ids.npu())
model_input.add_sampler_output(
SamplerOutput(outputs=[], sampled_token_ids=None),
model_input.last_sampled_token_ids)