[Performance] Change the shape of kv_cache to avoid view of k_cache and v_cache. (#204)

This PR changes the shape of the kv cache to avoid creating views of k_cache
and v_cache.
Additionally, it caches the metadata of k_cache and v_cache to avoid
duplicate slice operations, improving performance.

Signed-off-by: hw_whx <wanghexiang7@huawei.com>
This commit is contained in:
whx
2025-03-05 10:51:07 +08:00
committed by GitHub
parent 562fa673e5
commit 0d3463400a
3 changed files with 28 additions and 23 deletions

View File

@@ -246,11 +246,11 @@ class AscendQKVQuantAttentionMethod(BaseKVCacheMethod):
self.quant_method.process_weights_after_loading(layer)
def apply(self, layer: torch.nn.Module, query: torch.Tensor,
key: torch.Tensor, value: torch.Tensor,
kv_cache: List[torch.Tensor], scale: torch.Tensor,
key: torch.Tensor, value: torch.Tensor, key_cache: torch.Tensor,
value_cache: torch.Tensor, scale: torch.Tensor,
seq_lens_tensor_cpu: int, block_tables: torch.Tensor,
isPrefill: bool, attn_metadata, output) -> torch.Tensor:
return self.quant_method.apply(layer, query, key, value, kv_cache,
scale, seq_lens_tensor_cpu,
return self.quant_method.apply(layer, query, key, value, key_cache,
value_cache, scale, seq_lens_tensor_cpu,
block_tables, isPrefill, attn_metadata,
output)