[Performance] Change the shape of kv_cache to avoid view of k_cache and v_cache. (#204)
This PR changes the shape of the KV cache to avoid creating views of k_cache and v_cache. In addition, it caches the k_cache and v_cache tensors (slicing them from kv_cache only once) to avoid repeated slice operations and improve performance. Signed-off-by: hw_whx <wanghexiang7@huawei.com>
This commit is contained in:
@@ -121,7 +121,7 @@ class AscendAttentionBackend(AttentionBackend):
|
|||||||
num_kv_heads: int,
|
num_kv_heads: int,
|
||||||
head_size: int,
|
head_size: int,
|
||||||
) -> Tuple[int, ...]:
|
) -> Tuple[int, ...]:
|
||||||
return (2, num_blocks, block_size, num_kv_heads * head_size)
|
return (2, num_blocks, block_size, num_kv_heads, head_size)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def swap_blocks(
|
def swap_blocks(
|
||||||
@@ -512,6 +512,8 @@ class AscendAttentionBackendImpl(AttentionImpl):
|
|||||||
assert self.num_heads % self.num_kv_heads == 0
|
assert self.num_heads % self.num_kv_heads == 0
|
||||||
self.num_queries_per_kv = self.num_heads // self.num_kv_heads
|
self.num_queries_per_kv = self.num_heads // self.num_kv_heads
|
||||||
self.seq_len_cpu_tensor = None
|
self.seq_len_cpu_tensor = None
|
||||||
|
self.key_cache = None
|
||||||
|
self.value_cache = None
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
@@ -555,6 +557,11 @@ class AscendAttentionBackendImpl(AttentionImpl):
|
|||||||
dtype=query.dtype,
|
dtype=query.dtype,
|
||||||
device=query.device)
|
device=query.device)
|
||||||
|
|
||||||
|
if kv_cache.numel() > 0:
|
||||||
|
if self.key_cache is None:
|
||||||
|
self.key_cache, self.value_cache = kv_cache[0], kv_cache[1]
|
||||||
|
slots = attn_metadata.slot_mapping
|
||||||
|
|
||||||
if hasattr(layer, 'quant_method'):
|
if hasattr(layer, 'quant_method'):
|
||||||
isPrefill = True if attn_metadata.num_prefills > 0 else False
|
isPrefill = True if attn_metadata.num_prefills > 0 else False
|
||||||
if isPrefill:
|
if isPrefill:
|
||||||
@@ -570,24 +577,16 @@ class AscendAttentionBackendImpl(AttentionImpl):
|
|||||||
block_tables = attn_metadata.decode_metadata.block_tables if attn_metadata.decode_metadata else None
|
block_tables = attn_metadata.decode_metadata.block_tables if attn_metadata.decode_metadata else None
|
||||||
# Details of kv_cache arrangement in attention quantization
|
# Details of kv_cache arrangement in attention quantization
|
||||||
# are implemented by quant_method.
|
# are implemented by quant_method.
|
||||||
layer.quant_method.apply(layer, query, key, value, kv_cache,
|
layer.quant_method.apply(layer, query, key, value, self.key_cache,
|
||||||
self.scale, self.seq_lens_tensor_cpu,
|
self.value_cache, self.scale,
|
||||||
block_tables, isPrefill, attn_metadata,
|
self.seq_lens_tensor_cpu, block_tables,
|
||||||
output)
|
isPrefill, attn_metadata, output)
|
||||||
else:
|
else:
|
||||||
if kv_cache.numel() > 0:
|
if self.key_cache is not None:
|
||||||
key_cache, value_cache = kv_cache[0], kv_cache[1]
|
|
||||||
num_blocks, block_size, _ = key_cache.shape
|
|
||||||
key_cache = key_cache.view(num_blocks, block_size,
|
|
||||||
self.num_kv_heads, self.head_size)
|
|
||||||
value_cache = value_cache.view(num_blocks, block_size,
|
|
||||||
self.num_kv_heads,
|
|
||||||
self.head_size)
|
|
||||||
slots = attn_metadata.slot_mapping
|
|
||||||
torch_npu._npu_reshape_and_cache(key=key,
|
torch_npu._npu_reshape_and_cache(key=key,
|
||||||
value=value,
|
value=value,
|
||||||
key_cache=key_cache,
|
key_cache=self.key_cache,
|
||||||
value_cache=value_cache,
|
value_cache=self.value_cache,
|
||||||
slot_indices=slots)
|
slot_indices=slots)
|
||||||
|
|
||||||
if attn_metadata.num_prefills > 0:
|
if attn_metadata.num_prefills > 0:
|
||||||
@@ -617,15 +616,15 @@ class AscendAttentionBackendImpl(AttentionImpl):
|
|||||||
"Prefix cache and chunked prefill are currently not supported."
|
"Prefix cache and chunked prefill are currently not supported."
|
||||||
)
|
)
|
||||||
elif attn_metadata.decode_metadata:
|
elif attn_metadata.decode_metadata:
|
||||||
assert kv_cache is not None
|
assert self.key_cache is not None
|
||||||
self.seq_lens_tensor_cpu = torch.from_numpy(
|
self.seq_lens_tensor_cpu = torch.from_numpy(
|
||||||
np.array(attn_metadata.decode_metadata.seq_lens).astype(
|
np.array(attn_metadata.decode_metadata.seq_lens).astype(
|
||||||
np.int32))
|
np.int32))
|
||||||
block_tables = attn_metadata.decode_metadata.block_tables
|
block_tables = attn_metadata.decode_metadata.block_tables
|
||||||
torch_npu._npu_paged_attention(
|
torch_npu._npu_paged_attention(
|
||||||
query=query,
|
query=query,
|
||||||
key_cache=key_cache,
|
key_cache=self.key_cache,
|
||||||
value_cache=value_cache,
|
value_cache=self.value_cache,
|
||||||
num_kv_heads=self.num_kv_heads,
|
num_kv_heads=self.num_kv_heads,
|
||||||
num_heads=self.num_heads,
|
num_heads=self.num_heads,
|
||||||
scale_value=self.scale,
|
scale_value=self.scale,
|
||||||
|
|||||||
@@ -246,11 +246,11 @@ class AscendQKVQuantAttentionMethod(BaseKVCacheMethod):
|
|||||||
self.quant_method.process_weights_after_loading(layer)
|
self.quant_method.process_weights_after_loading(layer)
|
||||||
|
|
||||||
def apply(self, layer: torch.nn.Module, query: torch.Tensor,
|
def apply(self, layer: torch.nn.Module, query: torch.Tensor,
|
||||||
key: torch.Tensor, value: torch.Tensor,
|
key: torch.Tensor, value: torch.Tensor, key_cache: torch.Tensor,
|
||||||
kv_cache: List[torch.Tensor], scale: torch.Tensor,
|
value_cache: torch.Tensor, scale: torch.Tensor,
|
||||||
seq_lens_tensor_cpu: int, block_tables: torch.Tensor,
|
seq_lens_tensor_cpu: int, block_tables: torch.Tensor,
|
||||||
isPrefill: bool, attn_metadata, output) -> torch.Tensor:
|
isPrefill: bool, attn_metadata, output) -> torch.Tensor:
|
||||||
return self.quant_method.apply(layer, query, key, value, kv_cache,
|
return self.quant_method.apply(layer, query, key, value, key_cache,
|
||||||
scale, seq_lens_tensor_cpu,
|
value_cache, scale, seq_lens_tensor_cpu,
|
||||||
block_tables, isPrefill, attn_metadata,
|
block_tables, isPrefill, attn_metadata,
|
||||||
output)
|
output)
|
||||||
|
|||||||
@@ -266,6 +266,12 @@ class NPUWorker(LocalOrDistributedWorkerBase):
|
|||||||
self.parallel_config, self.device_config)
|
self.parallel_config, self.device_config)
|
||||||
for _ in range(self.parallel_config.pipeline_parallel_size)
|
for _ in range(self.parallel_config.pipeline_parallel_size)
|
||||||
]
|
]
|
||||||
|
import torch_npu
|
||||||
|
for ve in range(self.parallel_config.pipeline_parallel_size):
|
||||||
|
num_layers = len(self.cache_engine[ve].gpu_cache)
|
||||||
|
for i in range(num_layers):
|
||||||
|
torch_npu.npu_format_cast(self.cache_engine[ve].gpu_cache[i],
|
||||||
|
2)
|
||||||
self.gpu_cache = [
|
self.gpu_cache = [
|
||||||
self.cache_engine[ve].gpu_cache
|
self.cache_engine[ve].gpu_cache
|
||||||
for ve in range(self.parallel_config.pipeline_parallel_size)
|
for ve in range(self.parallel_config.pipeline_parallel_size)
|
||||||
|
|||||||
Reference in New Issue
Block a user