[Performance] Change the shape of kv_cache to avoid view of k_cache and v_cache. (#204)

This PR changes the shape of kv cache to avoid the view of k_cache and
v_cache.
What's more, cache the metadata of k_cache and v_cache to avoid
duplicative slice operations to improve performance.

Signed-off-by: hw_whx <wanghexiang7@huawei.com>
This commit is contained in:
whx
2025-03-05 10:51:07 +08:00
committed by GitHub
parent 562fa673e5
commit 0d3463400a
3 changed files with 28 additions and 23 deletions

View File

@@ -121,7 +121,7 @@ class AscendAttentionBackend(AttentionBackend):
num_kv_heads: int,
head_size: int,
) -> Tuple[int, ...]:
return (2, num_blocks, block_size, num_kv_heads * head_size)
return (2, num_blocks, block_size, num_kv_heads, head_size)
@staticmethod
def swap_blocks(
@@ -512,6 +512,8 @@ class AscendAttentionBackendImpl(AttentionImpl):
assert self.num_heads % self.num_kv_heads == 0
self.num_queries_per_kv = self.num_heads // self.num_kv_heads
self.seq_len_cpu_tensor = None
self.key_cache = None
self.value_cache = None
def forward(
self,
@@ -555,6 +557,11 @@ class AscendAttentionBackendImpl(AttentionImpl):
dtype=query.dtype,
device=query.device)
if kv_cache.numel() > 0:
if self.key_cache is None:
self.key_cache, self.value_cache = kv_cache[0], kv_cache[1]
slots = attn_metadata.slot_mapping
if hasattr(layer, 'quant_method'):
isPrefill = True if attn_metadata.num_prefills > 0 else False
if isPrefill:
@@ -570,24 +577,16 @@ class AscendAttentionBackendImpl(AttentionImpl):
block_tables = attn_metadata.decode_metadata.block_tables if attn_metadata.decode_metadata else None
# Details of kv_cache arrangement in attention quantization
# are implemented by quant_method.
layer.quant_method.apply(layer, query, key, value, kv_cache,
self.scale, self.seq_lens_tensor_cpu,
block_tables, isPrefill, attn_metadata,
output)
layer.quant_method.apply(layer, query, key, value, self.key_cache,
self.value_cache, self.scale,
self.seq_lens_tensor_cpu, block_tables,
isPrefill, attn_metadata, output)
else:
if kv_cache.numel() > 0:
key_cache, value_cache = kv_cache[0], kv_cache[1]
num_blocks, block_size, _ = key_cache.shape
key_cache = key_cache.view(num_blocks, block_size,
self.num_kv_heads, self.head_size)
value_cache = value_cache.view(num_blocks, block_size,
self.num_kv_heads,
self.head_size)
slots = attn_metadata.slot_mapping
if self.key_cache is not None:
torch_npu._npu_reshape_and_cache(key=key,
value=value,
key_cache=key_cache,
value_cache=value_cache,
key_cache=self.key_cache,
value_cache=self.value_cache,
slot_indices=slots)
if attn_metadata.num_prefills > 0:
@@ -617,15 +616,15 @@ class AscendAttentionBackendImpl(AttentionImpl):
"Prefix cache and chunked prefill are currently not supported."
)
elif attn_metadata.decode_metadata:
assert kv_cache is not None
assert self.key_cache is not None
self.seq_lens_tensor_cpu = torch.from_numpy(
np.array(attn_metadata.decode_metadata.seq_lens).astype(
np.int32))
block_tables = attn_metadata.decode_metadata.block_tables
torch_npu._npu_paged_attention(
query=query,
key_cache=key_cache,
value_cache=value_cache,
key_cache=self.key_cache,
value_cache=self.value_cache,
num_kv_heads=self.num_kv_heads,
num_heads=self.num_heads,
scale_value=self.scale,

View File

@@ -246,11 +246,11 @@ class AscendQKVQuantAttentionMethod(BaseKVCacheMethod):
self.quant_method.process_weights_after_loading(layer)
def apply(self, layer: torch.nn.Module, query: torch.Tensor,
key: torch.Tensor, value: torch.Tensor,
kv_cache: List[torch.Tensor], scale: torch.Tensor,
key: torch.Tensor, value: torch.Tensor, key_cache: torch.Tensor,
value_cache: torch.Tensor, scale: torch.Tensor,
seq_lens_tensor_cpu: int, block_tables: torch.Tensor,
isPrefill: bool, attn_metadata, output) -> torch.Tensor:
return self.quant_method.apply(layer, query, key, value, kv_cache,
scale, seq_lens_tensor_cpu,
return self.quant_method.apply(layer, query, key, value, key_cache,
value_cache, scale, seq_lens_tensor_cpu,
block_tables, isPrefill, attn_metadata,
output)

View File

@@ -266,6 +266,12 @@ class NPUWorker(LocalOrDistributedWorkerBase):
self.parallel_config, self.device_config)
for _ in range(self.parallel_config.pipeline_parallel_size)
]
import torch_npu
for ve in range(self.parallel_config.pipeline_parallel_size):
num_layers = len(self.cache_engine[ve].gpu_cache)
for i in range(num_layers):
torch_npu.npu_format_cast(self.cache_engine[ve].gpu_cache[i],
2)
self.gpu_cache = [
self.cache_engine[ve].gpu_cache
for ve in range(self.parallel_config.pipeline_parallel_size)