[Feature] Use reshape_and_cache fused op (#706)
Replace the torch function with the `reshape_and_cache` fused op for better performance. The `reshape_and_cache` function wasn't working because it expected a torch.int32 tensor, but a torch.int64 tensor was provided. Signed-off-by: Jade Zheng <zheng.shoujian@outlook.com>
This commit is contained in:
@@ -13,7 +13,6 @@ from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding
|
||||
|
||||
from vllm_ascend.attention.attention_v1 import AscendAttentionState
|
||||
from vllm_ascend.ops.attention import vanilla_chunked_prefill_mla
|
||||
from vllm_ascend.ops.cache import concat_and_cache_mla
|
||||
from vllm_ascend.worker.model_runner_v1 import NPUModelRunner
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@@ -214,7 +213,7 @@ class AscendMLAMetadataBuilder:
|
||||
block_table = (
|
||||
self.runner.input_batch.block_table.get_device_tensor()[:num_reqs])
|
||||
slot_mapping = self.runner.slot_mapping_cpu[:num_actual_tokens].to(
|
||||
device, non_blocking=True).long()
|
||||
device, non_blocking=True)
|
||||
input_positions = self.runner.positions_cpu[:num_actual_tokens].to(
|
||||
device, non_blocking=True).long()
|
||||
|
||||
@@ -537,14 +536,14 @@ class AscendMLAImpl(MLAAttentionImpl):
|
||||
prefill_q_pe.contiguous(), prefill_k_pe)
|
||||
|
||||
if kv_cache.numel() > 0:
|
||||
concat_and_cache_mla(k_c_normed, k_pe, kv_cache,
|
||||
attn_metadata.slot_mapping.flatten())
|
||||
# TODO: replaced back to ascend ops
|
||||
# key = torch.cat([k_c_normed.view([num_actual_toks, self.num_kv_heads, -1]), k_pe], dim=2)
|
||||
# torch_npu._npu_reshape_and_cache_siso(
|
||||
# key=key,
|
||||
# key_cache=kv_cache,
|
||||
# slot_indices=attn_metadata.slot_mapping.flatten())
|
||||
key = torch.cat([
|
||||
k_c_normed.view([num_actual_toks, self.num_kv_heads, -1]), k_pe
|
||||
],
|
||||
dim=2)
|
||||
torch_npu._npu_reshape_and_cache_siso(
|
||||
key=key,
|
||||
key_cache=kv_cache,
|
||||
slot_indices=attn_metadata.slot_mapping.flatten())
|
||||
|
||||
if has_prefill:
|
||||
output[num_decode_tokens:] = self._forward_prefill(
|
||||
|
||||
Reference in New Issue
Block a user