[Performance] Change the shape of kv_cache to avoid view of k_cache and v_cache. (#204)
This PR changes the shape of kv cache to avoid the view of k_cache and v_cache. What's more, cache the metadata of k_cache and v_cache to avoid duplicative slice operations to improve performance. Signed-off-by: hw_whx <wanghexiang7@huawei.com>
This commit is contained in:
@@ -121,7 +121,7 @@ class AscendAttentionBackend(AttentionBackend):
|
||||
num_kv_heads: int,
|
||||
head_size: int,
|
||||
) -> Tuple[int, ...]:
|
||||
return (2, num_blocks, block_size, num_kv_heads * head_size)
|
||||
return (2, num_blocks, block_size, num_kv_heads, head_size)
|
||||
|
||||
@staticmethod
|
||||
def swap_blocks(
|
||||
@@ -512,6 +512,8 @@ class AscendAttentionBackendImpl(AttentionImpl):
|
||||
assert self.num_heads % self.num_kv_heads == 0
|
||||
self.num_queries_per_kv = self.num_heads // self.num_kv_heads
|
||||
self.seq_len_cpu_tensor = None
|
||||
self.key_cache = None
|
||||
self.value_cache = None
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@@ -555,6 +557,11 @@ class AscendAttentionBackendImpl(AttentionImpl):
|
||||
dtype=query.dtype,
|
||||
device=query.device)
|
||||
|
||||
if kv_cache.numel() > 0:
|
||||
if self.key_cache is None:
|
||||
self.key_cache, self.value_cache = kv_cache[0], kv_cache[1]
|
||||
slots = attn_metadata.slot_mapping
|
||||
|
||||
if hasattr(layer, 'quant_method'):
|
||||
isPrefill = True if attn_metadata.num_prefills > 0 else False
|
||||
if isPrefill:
|
||||
@@ -570,24 +577,16 @@ class AscendAttentionBackendImpl(AttentionImpl):
|
||||
block_tables = attn_metadata.decode_metadata.block_tables if attn_metadata.decode_metadata else None
|
||||
# Details of kv_cache arrangement in attention quantization
|
||||
# are implemented by quant_method.
|
||||
layer.quant_method.apply(layer, query, key, value, kv_cache,
|
||||
self.scale, self.seq_lens_tensor_cpu,
|
||||
block_tables, isPrefill, attn_metadata,
|
||||
output)
|
||||
layer.quant_method.apply(layer, query, key, value, self.key_cache,
|
||||
self.value_cache, self.scale,
|
||||
self.seq_lens_tensor_cpu, block_tables,
|
||||
isPrefill, attn_metadata, output)
|
||||
else:
|
||||
if kv_cache.numel() > 0:
|
||||
key_cache, value_cache = kv_cache[0], kv_cache[1]
|
||||
num_blocks, block_size, _ = key_cache.shape
|
||||
key_cache = key_cache.view(num_blocks, block_size,
|
||||
self.num_kv_heads, self.head_size)
|
||||
value_cache = value_cache.view(num_blocks, block_size,
|
||||
self.num_kv_heads,
|
||||
self.head_size)
|
||||
slots = attn_metadata.slot_mapping
|
||||
if self.key_cache is not None:
|
||||
torch_npu._npu_reshape_and_cache(key=key,
|
||||
value=value,
|
||||
key_cache=key_cache,
|
||||
value_cache=value_cache,
|
||||
key_cache=self.key_cache,
|
||||
value_cache=self.value_cache,
|
||||
slot_indices=slots)
|
||||
|
||||
if attn_metadata.num_prefills > 0:
|
||||
@@ -617,15 +616,15 @@ class AscendAttentionBackendImpl(AttentionImpl):
|
||||
"Prefix cache and chunked prefill are currently not supported."
|
||||
)
|
||||
elif attn_metadata.decode_metadata:
|
||||
assert kv_cache is not None
|
||||
assert self.key_cache is not None
|
||||
self.seq_lens_tensor_cpu = torch.from_numpy(
|
||||
np.array(attn_metadata.decode_metadata.seq_lens).astype(
|
||||
np.int32))
|
||||
block_tables = attn_metadata.decode_metadata.block_tables
|
||||
torch_npu._npu_paged_attention(
|
||||
query=query,
|
||||
key_cache=key_cache,
|
||||
value_cache=value_cache,
|
||||
key_cache=self.key_cache,
|
||||
value_cache=self.value_cache,
|
||||
num_kv_heads=self.num_kv_heads,
|
||||
num_heads=self.num_heads,
|
||||
scale_value=self.scale,
|
||||
|
||||
@@ -246,11 +246,11 @@ class AscendQKVQuantAttentionMethod(BaseKVCacheMethod):
|
||||
self.quant_method.process_weights_after_loading(layer)
|
||||
|
||||
def apply(self, layer: torch.nn.Module, query: torch.Tensor,
|
||||
key: torch.Tensor, value: torch.Tensor,
|
||||
kv_cache: List[torch.Tensor], scale: torch.Tensor,
|
||||
key: torch.Tensor, value: torch.Tensor, key_cache: torch.Tensor,
|
||||
value_cache: torch.Tensor, scale: torch.Tensor,
|
||||
seq_lens_tensor_cpu: int, block_tables: torch.Tensor,
|
||||
isPrefill: bool, attn_metadata, output) -> torch.Tensor:
|
||||
return self.quant_method.apply(layer, query, key, value, kv_cache,
|
||||
scale, seq_lens_tensor_cpu,
|
||||
return self.quant_method.apply(layer, query, key, value, key_cache,
|
||||
value_cache, scale, seq_lens_tensor_cpu,
|
||||
block_tables, isPrefill, attn_metadata,
|
||||
output)
|
||||
|
||||
@@ -266,6 +266,12 @@ class NPUWorker(LocalOrDistributedWorkerBase):
|
||||
self.parallel_config, self.device_config)
|
||||
for _ in range(self.parallel_config.pipeline_parallel_size)
|
||||
]
|
||||
import torch_npu
|
||||
for ve in range(self.parallel_config.pipeline_parallel_size):
|
||||
num_layers = len(self.cache_engine[ve].gpu_cache)
|
||||
for i in range(num_layers):
|
||||
torch_npu.npu_format_cast(self.cache_engine[ve].gpu_cache[i],
|
||||
2)
|
||||
self.gpu_cache = [
|
||||
self.cache_engine[ve].gpu_cache
|
||||
for ve in range(self.parallel_config.pipeline_parallel_size)
|
||||
|
||||
Reference in New Issue
Block a user