[MLA][Graph] Improve assertion on Graph mode with MLA (#933)
### What this PR does / why we need it?
Improve assertion on Graph mode with MLA.
When running deepseek with graph mode, the fused MLA op only support
`numHeads / numKvHeads ∈ {32, 64, 128}`, thus we improve the assertion
info here to avoid users confused with this.
### Does this PR introduce _any_ user-facing change?
Adjusting tp size is required when running deepseek-v3/r1 with graph
mode. deepseek-v2-lite is not supported in graph mode.
### How was this patch tested?
Test locally as the CI machine could not run V3 due to the HBM limits.
---------
Signed-off-by: MengqingCao <cmq0113@163.com>
This commit is contained in:
@@ -113,3 +113,13 @@ In scenarios where NPUs have limited HBM (High Bandwidth Memory) capacity, dynam
|
||||
- **Adjust `--gpu-memory-utilization`**: If unspecified, will use the default value of `0.9`. You can decrease this param to reserve more memory to reduce fragmentation risks. See more note in: [vLLM - Inference and Serving - Engine Arguments](https://docs.vllm.ai/en/latest/serving/engine_args.html#vllm.engine.arg_utils-_engine_args_parser-cacheconfig).
|
||||
|
||||
- **Configure `PYTORCH_NPU_ALLOC_CONF`**: Set this environment variable to optimize NPU memory management. For example, you can `export PYTORCH_NPU_ALLOC_CONF=expandable_segments:True` to enable virtual memory feature to mitigate memory fragmentation caused by frequent dynamic memory size adjustments during runtime, see more note in: [PYTORCH_NPU_ALLOC_CONF](https://www.hiascend.com/document/detail/zh/Pytorch/700/comref/Envvariables/Envir_012.html).
|
||||
|
||||
### 15. Failed to enable NPU graph mode when running DeepSeek?
|
||||
You may encounter the following error if running DeepSeek with NPU graph mode enabled. The allowed number of queries per kv when enabling both MLA and Graph mode only support {32, 64, 128}, **Thus this is not supported for DeepSeek-V2-Lite**, as it only has 16 attention heads. The NPU graph mode support on DeepSeek-V2-Lite will be done in the future.
|
||||
|
||||
And if you're using DeepSeek-V3 or DeepSeek-R1, please make sure after the tensor parallel split, num_heads / num_kv_heads in {32, 64, 128}.
|
||||
|
||||
```bash
|
||||
[rank0]: RuntimeError: EZ9999: Inner Error!
|
||||
[rank0]: EZ9999: [PID: 62938] 2025-05-27-06:52:12.455.807 numHeads / numKvHeads = 8, MLA only support {32, 64, 128}.[FUNC:CheckMlaAttrs][FILE:incre_flash_attention_tiling_check.cc][LINE:1218]
|
||||
```
|
||||
|
||||
@@ -40,6 +40,8 @@ from vllm_ascend.platform import CUSTOM_OP_ENABLED
|
||||
from vllm_ascend.worker.model_runner import (
|
||||
ModelInputForNPUBuilder, ModelInputForNPUWithSamplingMetadata)
|
||||
|
||||
_ALLOWED_NUM_QUERIES_PER_KV = [32, 64, 128]
|
||||
|
||||
|
||||
def generate_attn_mask(max_seq_len: int, dtype=torch.float16, mask_value=None):
|
||||
# Construct lower triangle matrix.
|
||||
@@ -1005,6 +1007,15 @@ class AscendMLAAttentionBackendImpl(MLAAttentionImpl):
|
||||
ascend_config = get_ascend_config()
|
||||
self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled
|
||||
|
||||
# TODO: support numHeads / numKvHeads < 16 in MLA kernel
|
||||
if self.torchair_graph_enabled:
|
||||
assert self.num_queries_per_kv in _ALLOWED_NUM_QUERIES_PER_KV, \
|
||||
("The allowed number of queries per kv when enabling both MLA and Graph mode"
|
||||
" only support {32, 64, 128}, Thus this is not supported for DeepSeek-V2-Lite,"
|
||||
" as it only has 16 attention heads. And if you're using DeepSeek-V3 or DeepSeek-R1,"
|
||||
" please make sure after the tensor parallel split, num_heads / num_kv_heads in "
|
||||
"{32, 64, 128}.")
|
||||
|
||||
def exec_kv(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
|
||||
@@ -13,6 +13,7 @@ from vllm.model_executor.layers.linear import (LinearBase,
|
||||
UnquantizedLinearMethod)
|
||||
|
||||
from vllm_ascend.ascend_config import get_ascend_config
|
||||
from vllm_ascend.attention.attention import _ALLOWED_NUM_QUERIES_PER_KV
|
||||
from vllm_ascend.attention.attention_v1 import AscendAttentionState
|
||||
from vllm_ascend.multistream.base import MSAttentionMetadataSplitConfig
|
||||
from vllm_ascend.multistream.context import get_multistream_comm_context
|
||||
@@ -475,6 +476,7 @@ class AscendMLAImpl(MLAAttentionImpl):
|
||||
self.o_proj = kwargs['o_proj']
|
||||
self.kv_a_proj_with_mqa = kwargs.get('kv_a_proj_with_mqa', None)
|
||||
self.kv_a_layernorm = kwargs.get('kv_a_layernorm', None)
|
||||
self.num_queries_per_kv = self.num_heads // self.num_kv_heads
|
||||
|
||||
ascend_config = get_ascend_config()
|
||||
self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled
|
||||
@@ -484,6 +486,15 @@ class AscendMLAImpl(MLAAttentionImpl):
|
||||
self.spec_token_num = speculative_config.num_speculative_tokens
|
||||
assert self.spec_token_num > 0
|
||||
|
||||
# TODO: support numHeads / numKvHeads < 16 in MLA kernel
|
||||
if self.torchair_graph_enabled:
|
||||
assert self.num_queries_per_kv in _ALLOWED_NUM_QUERIES_PER_KV, \
|
||||
("The allowed number of queries per kv when enabling both MLA and Graph mode"
|
||||
" only support {32, 64, 128}, Thus this is not supported for DeepSeek-V2-Lite,"
|
||||
" as it only has 16 attention heads. And if you're using DeepSeek-V3 or DeepSeek-R1,"
|
||||
" please make sure after the tensor parallel split, num_heads / num_kv_heads in "
|
||||
"{32, 64, 128}.")
|
||||
|
||||
def _v_up_proj_and_o_proj(self, x):
|
||||
# Convert from (B, N, L) to (N, B, L)
|
||||
x = x.view(-1, self.num_heads, self.kv_lora_rank).transpose(0, 1)
|
||||
|
||||
@@ -119,7 +119,7 @@ class MultiStepWorker(NPUWorker):
|
||||
# execute_model_req
|
||||
assert execute_model_req.last_sampled_token_ids is not None
|
||||
model_input.last_sampled_token_ids = (
|
||||
execute_model_req.last_sampled_token_ids.cuda())
|
||||
execute_model_req.last_sampled_token_ids.npu())
|
||||
model_input.add_sampler_output(
|
||||
SamplerOutput(outputs=[], sampled_token_ids=None),
|
||||
model_input.last_sampled_token_ids)
|
||||
|
||||
Reference in New Issue
Block a user