From 0d3463400a8ae776fc637f4db3a464c0d0dc3da6 Mon Sep 17 00:00:00 2001
From: whx <56632993+whx-sjtu@users.noreply.github.com>
Date: Wed, 5 Mar 2025 10:51:07 +0800
Subject: [PATCH] [Performance] Change the shape of kv_cache to avoid view of
 k_cache and v_cache. (#204)

This PR changes the shape of the kv cache to avoid the view of k_cache and
v_cache. In addition, it caches the k_cache and v_cache tensors so they are
sliced out of kv_cache only once, avoiding duplicated slice operations and
improving performance.

Signed-off-by: hw_whx
---
 vllm_ascend/attention.py                 | 37 ++++++++++++------------
 vllm_ascend/quantization/quant_config.py |  8 ++---
 vllm_ascend/worker/worker.py             |  6 ++++
 3 files changed, 28 insertions(+), 23 deletions(-)

diff --git a/vllm_ascend/attention.py b/vllm_ascend/attention.py
index 3b1eb2b..4cc9301 100644
--- a/vllm_ascend/attention.py
+++ b/vllm_ascend/attention.py
@@ -121,7 +121,7 @@ class AscendAttentionBackend(AttentionBackend):
         num_kv_heads: int,
         head_size: int,
     ) -> Tuple[int, ...]:
-        return (2, num_blocks, block_size, num_kv_heads * head_size)
+        return (2, num_blocks, block_size, num_kv_heads, head_size)
 
     @staticmethod
     def swap_blocks(
@@ -512,6 +512,8 @@ class AscendAttentionBackendImpl(AttentionImpl):
         assert self.num_heads % self.num_kv_heads == 0
         self.num_queries_per_kv = self.num_heads // self.num_kv_heads
         self.seq_len_cpu_tensor = None
+        self.key_cache = None
+        self.value_cache = None
 
     def forward(
         self,
@@ -555,6 +557,11 @@ class AscendAttentionBackendImpl(AttentionImpl):
                              dtype=query.dtype,
                              device=query.device)
 
+        if kv_cache.numel() > 0:
+            if self.key_cache is None:
+                self.key_cache, self.value_cache = kv_cache[0], kv_cache[1]
+            slots = attn_metadata.slot_mapping
+
         if hasattr(layer, 'quant_method'):
             isPrefill = True if attn_metadata.num_prefills > 0 else False
             if isPrefill:
@@ -570,24 +577,16 @@ class AscendAttentionBackendImpl(AttentionImpl):
             block_tables = attn_metadata.decode_metadata.block_tables if attn_metadata.decode_metadata else None
             # Details of kv_cache arrangement in attention quantization
             # are implemented by quant_method.
-            layer.quant_method.apply(layer, query, key, value, kv_cache,
-                                     self.scale, self.seq_lens_tensor_cpu,
-                                     block_tables, isPrefill, attn_metadata,
-                                     output)
+            layer.quant_method.apply(layer, query, key, value, self.key_cache,
+                                     self.value_cache, self.scale,
+                                     self.seq_lens_tensor_cpu, block_tables,
+                                     isPrefill, attn_metadata, output)
         else:
-            if kv_cache.numel() > 0:
-                key_cache, value_cache = kv_cache[0], kv_cache[1]
-                num_blocks, block_size, _ = key_cache.shape
-                key_cache = key_cache.view(num_blocks, block_size,
-                                           self.num_kv_heads, self.head_size)
-                value_cache = value_cache.view(num_blocks, block_size,
-                                               self.num_kv_heads,
-                                               self.head_size)
-                slots = attn_metadata.slot_mapping
+            if self.key_cache is not None:
                 torch_npu._npu_reshape_and_cache(key=key,
                                                  value=value,
-                                                 key_cache=key_cache,
-                                                 value_cache=value_cache,
+                                                 key_cache=self.key_cache,
+                                                 value_cache=self.value_cache,
                                                  slot_indices=slots)
 
             if attn_metadata.num_prefills > 0:
@@ -617,15 +616,15 @@ class AscendAttentionBackendImpl(AttentionImpl):
                     "Prefix cache and chunked prefill are currently not supported."
                 )
             elif attn_metadata.decode_metadata:
-                assert kv_cache is not None
+                assert self.key_cache is not None
                 self.seq_lens_tensor_cpu = torch.from_numpy(
                     np.array(attn_metadata.decode_metadata.seq_lens).astype(
                         np.int32))
                 block_tables = attn_metadata.decode_metadata.block_tables
                 torch_npu._npu_paged_attention(
                     query=query,
-                    key_cache=key_cache,
-                    value_cache=value_cache,
+                    key_cache=self.key_cache,
+                    value_cache=self.value_cache,
                     num_kv_heads=self.num_kv_heads,
                     num_heads=self.num_heads,
                     scale_value=self.scale,
diff --git a/vllm_ascend/quantization/quant_config.py b/vllm_ascend/quantization/quant_config.py
index 35d767b..7fb2622 100644
--- a/vllm_ascend/quantization/quant_config.py
+++ b/vllm_ascend/quantization/quant_config.py
@@ -246,11 +246,11 @@ class AscendQKVQuantAttentionMethod(BaseKVCacheMethod):
         self.quant_method.process_weights_after_loading(layer)
 
     def apply(self, layer: torch.nn.Module, query: torch.Tensor,
-              key: torch.Tensor, value: torch.Tensor,
-              kv_cache: List[torch.Tensor], scale: torch.Tensor,
+              key: torch.Tensor, value: torch.Tensor, key_cache: torch.Tensor,
+              value_cache: torch.Tensor, scale: torch.Tensor,
               seq_lens_tensor_cpu: int, block_tables: torch.Tensor,
               isPrefill: bool, attn_metadata, output) -> torch.Tensor:
-        return self.quant_method.apply(layer, query, key, value, kv_cache,
-                                       scale, seq_lens_tensor_cpu,
+        return self.quant_method.apply(layer, query, key, value, key_cache,
+                                       value_cache, scale, seq_lens_tensor_cpu,
                                        block_tables, isPrefill, attn_metadata,
                                        output)
diff --git a/vllm_ascend/worker/worker.py b/vllm_ascend/worker/worker.py
index 829a2ec..0f7256b 100644
--- a/vllm_ascend/worker/worker.py
+++ b/vllm_ascend/worker/worker.py
@@ -266,6 +266,12 @@ class NPUWorker(LocalOrDistributedWorkerBase):
                         self.parallel_config, self.device_config)
             for _ in range(self.parallel_config.pipeline_parallel_size)
         ]
+        import torch_npu
+        for ve in range(self.parallel_config.pipeline_parallel_size):
+            num_layers = len(self.cache_engine[ve].gpu_cache)
+            for i in range(num_layers):
+                torch_npu.npu_format_cast(self.cache_engine[ve].gpu_cache[i],
+                                          2)
         self.gpu_cache = [
             self.cache_engine[ve].gpu_cache
             for ve in range(self.parallel_config.pipeline_parallel_size)
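
Below is a minimal, standalone sketch (not part of the patch above) of the idea behind this change, using plain torch on CPU: once the kv cache is allocated as (2, num_blocks, block_size, num_kv_heads, head_size), the key and value planes can be sliced out once and cached on the attention object, instead of being re-sliced and view()-ed on every forward call as the old (2, num_blocks, block_size, num_kv_heads * head_size) layout required. The ToyAttention class and the constants are illustrative only, not the vllm-ascend API.

# Illustrative sketch only: contrasts the old per-call slice + view() with the
# new "slice once and cache" pattern enabled by the 5-D kv cache shape.
import torch

NUM_BLOCKS, BLOCK_SIZE, NUM_KV_HEADS, HEAD_SIZE = 4, 16, 2, 8


def old_forward_step(kv_cache: torch.Tensor):
    # Old layout: (2, num_blocks, block_size, num_kv_heads * head_size).
    # Every forward call sliced both caches and reshaped them with view().
    key_cache, value_cache = kv_cache[0], kv_cache[1]
    num_blocks, block_size, _ = key_cache.shape
    key_cache = key_cache.view(num_blocks, block_size, NUM_KV_HEADS, HEAD_SIZE)
    value_cache = value_cache.view(num_blocks, block_size, NUM_KV_HEADS, HEAD_SIZE)
    return key_cache, value_cache


class ToyAttention:
    """Keeps the k/v slices across calls, mirroring self.key_cache / self.value_cache."""

    def __init__(self):
        self.key_cache = None
        self.value_cache = None

    def forward_step(self, kv_cache: torch.Tensor):
        # New layout: (2, num_blocks, block_size, num_kv_heads, head_size).
        # The slices are taken on the first call and reused afterwards; no view() needed.
        if kv_cache.numel() > 0 and self.key_cache is None:
            self.key_cache, self.value_cache = kv_cache[0], kv_cache[1]
        return self.key_cache, self.value_cache


if __name__ == "__main__":
    old_cache = torch.zeros(2, NUM_BLOCKS, BLOCK_SIZE, NUM_KV_HEADS * HEAD_SIZE)
    new_cache = torch.zeros(2, NUM_BLOCKS, BLOCK_SIZE, NUM_KV_HEADS, HEAD_SIZE)

    k_old, _ = old_forward_step(old_cache)   # re-sliced and viewed on every call
    attn = ToyAttention()
    k_new, _ = attn.forward_step(new_cache)  # sliced once, cached for later calls
    assert tuple(k_old.shape) == tuple(k_new.shape) == (NUM_BLOCKS, BLOCK_SIZE,
                                                        NUM_KV_HEADS, HEAD_SIZE)

The cached slices share storage with the original kv_cache tensor, so reusing them only removes the per-step Python-side slicing and reshaping overhead that the PR description refers to; it does not change what data the kernels read or write.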