[Feature] Support DeepSeek for A5 (#7232)

### What this PR does / why we need it?

Add A5 MLA (Multi-head Latent Attention) operators to support running DeepSeek models on A5.

- vLLM version: v0.17.0
- vLLM main:
4034c3d32e

Signed-off-by: Li Jiahang <216526138+lijiahang226@users.noreply.github.com>
This commit is contained in:
lijiahang226
2026-03-23 20:28:26 +08:00
committed by GitHub
parent 13397e9cb7
commit 170dcbda62
2 changed files with 27 additions and 2 deletions

View File

@@ -40,6 +40,7 @@ from vllm_ascend.compilation.acl_graph import (
update_draft_graph_params_workspaces,
update_graph_params_workspaces,
)
from vllm_ascend.device.device_op import DeviceOperator
from vllm_ascend.ops.layer_shard_linear import (
is_hidden_layer,
post_process_after_loading_for_shard_weight_series,
@@ -1075,12 +1076,12 @@ class AscendMLAImpl(MLAAttentionImpl):
kv_c_normed = torch.empty(toks, num_heads, latent_kv_dim, dtype=q_nope.dtype, device=q_nope.device)
k_pe = torch.empty(toks, num_heads, rope_dim, dtype=q_nope.dtype, device=q_nope.device)
torch_npu.atb.npu_paged_cache_load(
DeviceOperator.mla_cache_load(
cache_kv_c,
cache_k_pe,
prefill_metadata.block_table,
context_seq_len_npu,
seq_starts=prefill_metadata.chunked_context.starts[i],
prefill_metadata.chunked_context.starts[i],
key=kv_c_normed,
value=k_pe,
)

View File

@@ -171,6 +171,18 @@ class BaseDeviceAdaptor:
output_dtype=fallback_output_dtype,
)[0]
@staticmethod
def mla_cache_load(cache_kv_c, cache_k_pe, block_table, context_seq_len_npu, seq_starts, key, value):
    """Gather MLA KV-cache entries for chunked-prefill context loading.

    Default implementation: delegates to the ATB ``npu_paged_cache_load``
    kernel, which reads the paged caches ``cache_kv_c`` / ``cache_k_pe``
    through ``block_table`` and fills the ``key`` / ``value`` tensors in
    place (the caller passes freshly allocated empty tensors for them).

    Args:
        cache_kv_c: paged cache holding the compressed latent KV (kv_c).
        cache_k_pe: paged cache holding the rotary positional key (k_pe).
        block_table: per-request mapping to cache block ids.
        context_seq_len_npu: per-request context lengths (device tensor).
        seq_starts: per-chunk start offsets into each sequence.
        key: output tensor receiving gathered kv_c rows (mutated in place).
        value: output tensor receiving gathered k_pe rows (mutated in place).

    Returns:
        None — results are written into ``key`` and ``value``.
    """
    torch_npu.atb.npu_paged_cache_load(
        cache_kv_c,
        cache_k_pe,
        block_table,
        context_seq_len_npu,
        seq_starts=seq_starts,
        key=key,
        value=value,
    )
class A5DeviceAdaptor(BaseDeviceAdaptor):
@classmethod
@@ -375,6 +387,18 @@ class A5DeviceAdaptor(BaseDeviceAdaptor):
**gmm2_kwargs,
)[0]
@staticmethod
def mla_cache_load(cache_kv_c, cache_k_pe, block_table, context_seq_len_npu, seq_starts, key, value):
    """A5 variant of ``BaseDeviceAdaptor.mla_cache_load``.

    Uses the A5 ``npu_gather_pa_kv_cache`` kernel instead of the ATB
    paged-cache-load op. The fifth parameter is named ``seq_starts`` to
    match the base-class signature, so callers that pass the argument by
    keyword work with either adaptor (the base override previously used
    ``seq_offset``, which would raise TypeError for such callers); the
    underlying A5 kernel still receives it via its ``seq_offset`` keyword.

    Args:
        cache_kv_c: paged cache holding the compressed latent KV (kv_c).
        cache_k_pe: paged cache holding the rotary positional key (k_pe).
        block_table: per-request mapping to cache block ids.
        context_seq_len_npu: per-request context lengths (device tensor).
        seq_starts: per-chunk start offsets into each sequence.
        key: output tensor receiving gathered kv_c rows (mutated in place).
        value: output tensor receiving gathered k_pe rows (mutated in place).

    Returns:
        None — results are written into ``key`` and ``value``.
    """
    torch_npu.npu_gather_pa_kv_cache(
        cache_kv_c,
        cache_k_pe,
        block_table,
        context_seq_len_npu,
        seq_offset=seq_starts,
        key=key,
        value=value,
    )
def get_device_adaptor() -> type["BaseDeviceAdaptor"]:
ascend_device_type = get_ascend_device_type()