[Feature] Use reshape_and_cache fused op (#706)

Replace the pure-torch cache write (`concat_and_cache_mla`) with the
`reshape_and_cache` fused op for better performance. The fused op wasn't
working before because it expected a torch.int32 slot-mapping tensor, but a
torch.int64 tensor was provided; the `.long()` cast on `slot_mapping` is
dropped accordingly.
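
For illustration, a minimal sketch of the dtype fix, assuming the host-side
slot mapping is already int32 (the tensor names and the "cpu" device here are
hypothetical stand-ins, not the actual call site):

    import torch

    # Hypothetical host-side slot mapping, assumed int32 (illustration only).
    slot_mapping_cpu = torch.tensor([5, 9, 12], dtype=torch.int32)

    # Before: the trailing .long() cast upgraded the dtype to int64,
    # which the fused op rejected.
    slot_mapping_old = slot_mapping_cpu.to("cpu", non_blocking=True).long()
    assert slot_mapping_old.dtype == torch.int64

    # After: dropping .long() keeps the int32 dtype the fused op expects.
    slot_mapping_new = slot_mapping_cpu.to("cpu", non_blocking=True)
    assert slot_mapping_new.dtype == torch.int32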

Signed-off-by: Jade Zheng <zheng.shoujian@outlook.com>

Author: Jade Zheng
Date: 2025-04-28 21:54:42 +08:00
Committed-by: GitHub
Parent: d39855b075
Commit: 40bd602485

@@ -13,7 +13,6 @@ from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding
 
 from vllm_ascend.attention.attention_v1 import AscendAttentionState
 from vllm_ascend.ops.attention import vanilla_chunked_prefill_mla
-from vllm_ascend.ops.cache import concat_and_cache_mla
 from vllm_ascend.worker.model_runner_v1 import NPUModelRunner
 
 if TYPE_CHECKING:
@@ -214,7 +213,7 @@ class AscendMLAMetadataBuilder:
         block_table = (
             self.runner.input_batch.block_table.get_device_tensor()[:num_reqs])
         slot_mapping = self.runner.slot_mapping_cpu[:num_actual_tokens].to(
-            device, non_blocking=True).long()
+            device, non_blocking=True)
         input_positions = self.runner.positions_cpu[:num_actual_tokens].to(
             device, non_blocking=True).long()
@@ -537,14 +536,14 @@ class AscendMLAImpl(MLAAttentionImpl):
                 prefill_q_pe.contiguous(), prefill_k_pe)
         if kv_cache.numel() > 0:
-            concat_and_cache_mla(k_c_normed, k_pe, kv_cache,
-                                 attn_metadata.slot_mapping.flatten())
-            # TODO: replaced back to ascend ops
-            # key = torch.cat([k_c_normed.view([num_actual_toks, self.num_kv_heads, -1]), k_pe], dim=2)
-            # torch_npu._npu_reshape_and_cache_siso(
-            #     key=key,
-            #     key_cache=kv_cache,
-            #     slot_indices=attn_metadata.slot_mapping.flatten())
+            key = torch.cat([
+                k_c_normed.view([num_actual_toks, self.num_kv_heads, -1]), k_pe
+            ],
+                            dim=2)
+            torch_npu._npu_reshape_and_cache_siso(
+                key=key,
+                key_cache=kv_cache,
+                slot_indices=attn_metadata.slot_mapping.flatten())
         if has_prefill:
             output[num_decode_tokens:] = self._forward_prefill(
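
For readers without NPU hardware, a rough pure-PyTorch sketch of the scatter
behavior that a reshape-and-cache style op performs; the flat
(num_slots, num_kv_heads, head_dim) cache layout and all shapes are
assumptions for illustration, not the actual `_npu_reshape_and_cache_siso`
contract:

    import torch

    def reshape_and_cache_ref(key: torch.Tensor, key_cache: torch.Tensor,
                              slot_indices: torch.Tensor) -> None:
        # key:          (num_tokens, num_kv_heads, head_dim)
        # key_cache:    (num_slots, num_kv_heads, head_dim), assumed flat layout
        # slot_indices: (num_tokens,), int32
        # Scatter each token's key vector into its cache slot, in place.
        key_cache[slot_indices.long()] = key

    # Toy usage: 3 tokens, 1 KV head, head_dim 4, cache with 8 slots.
    key = torch.randn(3, 1, 4)
    key_cache = torch.zeros(8, 1, 4)
    slots = torch.tensor([5, 0, 7], dtype=torch.int32)
    reshape_and_cache_ref(key, key_cache, slots)
    assert torch.equal(key_cache[5], key[0])

Writing the reshaped keys into the cache in a single fused NPU kernel replaces
the multi-op torch implementation in `concat_and_cache_mla`, which is
presumably where the performance win comes from.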