[Feature] Support DeepSeek for A5 (#7232)
### What this PR does / why we need it?
Add A5 MLA operators to support running DeepSeek models on A5.
- vLLM version: v0.17.0
- vLLM main: 4034c3d32e
Signed-off-by: Li Jiahang <216526138+lijiahang226@users.noreply.github.com>
This commit is contained in:
@@ -40,6 +40,7 @@ from vllm_ascend.compilation.acl_graph import (
|
|||||||
update_draft_graph_params_workspaces,
|
update_draft_graph_params_workspaces,
|
||||||
update_graph_params_workspaces,
|
update_graph_params_workspaces,
|
||||||
)
|
)
|
||||||
|
from vllm_ascend.device.device_op import DeviceOperator
|
||||||
from vllm_ascend.ops.layer_shard_linear import (
|
from vllm_ascend.ops.layer_shard_linear import (
|
||||||
is_hidden_layer,
|
is_hidden_layer,
|
||||||
post_process_after_loading_for_shard_weight_series,
|
post_process_after_loading_for_shard_weight_series,
|
||||||
@@ -1075,12 +1076,12 @@ class AscendMLAImpl(MLAAttentionImpl):
|
|||||||
kv_c_normed = torch.empty(toks, num_heads, latent_kv_dim, dtype=q_nope.dtype, device=q_nope.device)
|
kv_c_normed = torch.empty(toks, num_heads, latent_kv_dim, dtype=q_nope.dtype, device=q_nope.device)
|
||||||
k_pe = torch.empty(toks, num_heads, rope_dim, dtype=q_nope.dtype, device=q_nope.device)
|
k_pe = torch.empty(toks, num_heads, rope_dim, dtype=q_nope.dtype, device=q_nope.device)
|
||||||
|
|
||||||
torch_npu.atb.npu_paged_cache_load(
|
DeviceOperator.mla_cache_load(
|
||||||
cache_kv_c,
|
cache_kv_c,
|
||||||
cache_k_pe,
|
cache_k_pe,
|
||||||
prefill_metadata.block_table,
|
prefill_metadata.block_table,
|
||||||
context_seq_len_npu,
|
context_seq_len_npu,
|
||||||
seq_starts=prefill_metadata.chunked_context.starts[i],
|
prefill_metadata.chunked_context.starts[i],
|
||||||
key=kv_c_normed,
|
key=kv_c_normed,
|
||||||
value=k_pe,
|
value=k_pe,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -171,6 +171,18 @@ class BaseDeviceAdaptor:
|
|||||||
output_dtype=fallback_output_dtype,
|
output_dtype=fallback_output_dtype,
|
||||||
)[0]
|
)[0]
|
||||||
|
|
||||||
|
@staticmethod
def mla_cache_load(cache_kv_c, cache_k_pe, block_table, context_seq_len_npu, seq_starts, key, value):
    """Default MLA cache load: forward to ``torch_npu.atb.npu_paged_cache_load``.

    The first four arguments are handed to the operator positionally and in
    the same order; ``seq_starts``, ``key`` and ``value`` are handed over as
    keyword arguments of the same names.

    Args:
        cache_kv_c: First cache tensor given to the operator.
        cache_k_pe: Second cache tensor given to the operator.
        block_table: Block table given to the operator.
        context_seq_len_npu: Context sequence lengths given to the operator.
        seq_starts: Forwarded as the operator's ``seq_starts`` keyword.
        key: Forwarded as the operator's ``key`` keyword.
            NOTE(review): presumably an output buffer written in place by
            the operator -- confirm against the torch_npu documentation.
        value: Forwarded as the operator's ``value`` keyword (same
            assumption as ``key``).

    Returns:
        None. Whatever the underlying operator returns is discarded.
    """
    paged_cache_load = torch_npu.atb.npu_paged_cache_load
    paged_cache_load(
        cache_kv_c,
        cache_k_pe,
        block_table,
        context_seq_len_npu,
        seq_starts=seq_starts,
        key=key,
        value=value,
    )
|
||||||
|
|
||||||
|
|
||||||
class A5DeviceAdaptor(BaseDeviceAdaptor):
|
class A5DeviceAdaptor(BaseDeviceAdaptor):
|
||||||
@classmethod
|
@classmethod
|
||||||
@@ -375,6 +387,18 @@ class A5DeviceAdaptor(BaseDeviceAdaptor):
|
|||||||
**gmm2_kwargs,
|
**gmm2_kwargs,
|
||||||
)[0]
|
)[0]
|
||||||
|
|
||||||
|
@staticmethod
def mla_cache_load(cache_kv_c, cache_k_pe, block_table, context_seq_len_npu, seq_offset, key, value):
    """A5 MLA cache load: forward to ``torch_npu.npu_gather_pa_kv_cache``.

    The first four arguments are handed to the operator positionally and in
    the same order; ``seq_offset``, ``key`` and ``value`` are handed over as
    keyword arguments of the same names.

    Note that the fifth parameter is named ``seq_offset`` here, whereas
    ``BaseDeviceAdaptor.mla_cache_load`` names it ``seq_starts``. Callers
    that must work with either adaptor should pass it positionally.

    Args:
        cache_kv_c: First cache tensor given to the operator.
        cache_k_pe: Second cache tensor given to the operator.
        block_table: Block table given to the operator.
        context_seq_len_npu: Context sequence lengths given to the operator.
        seq_offset: Forwarded as the operator's ``seq_offset`` keyword.
        key: Forwarded as the operator's ``key`` keyword.
            NOTE(review): presumably an output buffer written in place by
            the operator -- confirm against the torch_npu documentation.
        value: Forwarded as the operator's ``value`` keyword (same
            assumption as ``key``).

    Returns:
        None. Whatever the underlying operator returns is discarded.
    """
    gather_pa_kv_cache = torch_npu.npu_gather_pa_kv_cache
    gather_pa_kv_cache(
        cache_kv_c,
        cache_k_pe,
        block_table,
        context_seq_len_npu,
        seq_offset=seq_offset,
        key=key,
        value=value,
    )
|
||||||
|
|
||||||
|
|
||||||
def get_device_adaptor() -> type["BaseDeviceAdaptor"]:
|
def get_device_adaptor() -> type["BaseDeviceAdaptor"]:
|
||||||
ascend_device_type = get_ascend_device_type()
|
ascend_device_type = get_ascend_device_type()
|
||||||
|
|||||||
Reference in New Issue
Block a user