[Feature] Support DeepSeek for A5 (#7232)
### What this PR does / why we need it?
Add A5 mla operators to support running DeepSeek models on A5.
- vLLM version: v0.17.0
- vLLM main: 4034c3d32e
Signed-off-by: Li Jiahang <216526138+lijiahang226@users.noreply.github.com>
This commit is contained in:
@@ -40,6 +40,7 @@ from vllm_ascend.compilation.acl_graph import (
|
||||
update_draft_graph_params_workspaces,
|
||||
update_graph_params_workspaces,
|
||||
)
|
||||
from vllm_ascend.device.device_op import DeviceOperator
|
||||
from vllm_ascend.ops.layer_shard_linear import (
|
||||
is_hidden_layer,
|
||||
post_process_after_loading_for_shard_weight_series,
|
||||
@@ -1075,12 +1076,12 @@ class AscendMLAImpl(MLAAttentionImpl):
|
||||
kv_c_normed = torch.empty(toks, num_heads, latent_kv_dim, dtype=q_nope.dtype, device=q_nope.device)
|
||||
k_pe = torch.empty(toks, num_heads, rope_dim, dtype=q_nope.dtype, device=q_nope.device)
|
||||
|
||||
torch_npu.atb.npu_paged_cache_load(
|
||||
DeviceOperator.mla_cache_load(
|
||||
cache_kv_c,
|
||||
cache_k_pe,
|
||||
prefill_metadata.block_table,
|
||||
context_seq_len_npu,
|
||||
seq_starts=prefill_metadata.chunked_context.starts[i],
|
||||
prefill_metadata.chunked_context.starts[i],
|
||||
key=kv_c_normed,
|
||||
value=k_pe,
|
||||
)
|
||||
|
||||
@@ -171,6 +171,18 @@ class BaseDeviceAdaptor:
|
||||
output_dtype=fallback_output_dtype,
|
||||
)[0]
|
||||
|
||||
@staticmethod
def mla_cache_load(cache_kv_c, cache_k_pe, block_table, context_seq_len_npu, seq_starts, key, value):
    """Gather paged KV-cache blocks into dense MLA prefill tensors.

    Default implementation: delegates to the ATB paged-cache-load
    kernel.  Device adaptors (e.g. A5) override this with a
    device-specific gather op of the same signature.

    Args:
        cache_kv_c: paged cache holding the compressed latent KV.
        cache_k_pe: paged cache holding the rope key component.
        block_table: per-sequence block indices into the paged caches.
        context_seq_len_npu: per-sequence context lengths (device tensor).
        seq_starts: per-sequence start offsets into the cached context.
        key: output tensor receiving the gathered compressed KV.
        value: output tensor receiving the gathered rope keys.

    NOTE(review): argument semantics inferred from the kernel's keyword
    names and the MLA call site — confirm against the
    ``npu_paged_cache_load`` operator documentation.
    """
    torch_npu.atb.npu_paged_cache_load(cache_kv_c,
                                       cache_k_pe,
                                       block_table,
                                       context_seq_len_npu,
                                       seq_starts=seq_starts,
                                       key=key,
                                       value=value)
|
||||
|
||||
|
||||
class A5DeviceAdaptor(BaseDeviceAdaptor):
|
||||
@classmethod
|
||||
@@ -375,6 +387,18 @@ class A5DeviceAdaptor(BaseDeviceAdaptor):
|
||||
**gmm2_kwargs,
|
||||
)[0]
|
||||
|
||||
@staticmethod
def mla_cache_load(cache_kv_c, cache_k_pe, block_table, context_seq_len_npu, seq_starts, key, value):
    """A5 override: gather paged KV-cache blocks for MLA prefill.

    Same contract as ``BaseDeviceAdaptor.mla_cache_load`` but backed by
    the A5 ``npu_gather_pa_kv_cache`` kernel.

    Fix: the parameter was previously named ``seq_offset`` while the
    base class names it ``seq_starts``; a caller invoking the adaptor
    with the base signature's keyword (``seq_starts=...``) would raise
    ``TypeError`` only on A5.  Renamed to match the base-class override
    point — positional callers are unaffected.

    Args:
        cache_kv_c: paged cache holding the compressed latent KV.
        cache_k_pe: paged cache holding the rope key component.
        block_table: per-sequence block indices into the paged caches.
        context_seq_len_npu: per-sequence context lengths (device tensor).
        seq_starts: per-sequence start offsets into the cached context.
        key: output tensor receiving the gathered compressed KV.
        value: output tensor receiving the gathered rope keys.
    """
    torch_npu.npu_gather_pa_kv_cache(
        cache_kv_c,
        cache_k_pe,
        block_table,
        context_seq_len_npu,
        # The A5 kernel's keyword is ``seq_offset``; the method parameter
        # keeps the base-class name for signature compatibility.
        seq_offset=seq_starts,
        key=key,
        value=value,
    )
|
||||
|
||||
|
||||
def get_device_adaptor() -> type["BaseDeviceAdaptor"]:
|
||||
ascend_device_type = get_ascend_device_type()
|
||||
|
||||
Reference in New Issue
Block a user