[KV-Sharing] Support KV-Sharing feature in CLA models (#4138)
### What this PR does / why we need it?
Support the KV-Sharing feature in CLA (cross-layer attention) models, which share the KV cache across some layers.
- vLLM version: v0.12.0
- vLLM main: ad32e3e19c
---------
Signed-off-by: MengqingCao <cmq0113@163.com>
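
For context on what the feature means for the runner, here is a rough, illustrative sketch (the layer names below are made up, not taken from any real model): in a CLA model some attention layers declare a `kv_sharing_target_layer_name`, and the model runner collects that mapping into `shared_kv_cache_layers`, so a sharing layer ends up bound to the same physical KV cache as its target.

```python
# Illustrative only: in a CLA model, pairs of layers can share one KV cache.
# The real mapping is derived from each attention module's
# kv_sharing_target_layer_name by the model runner; names here are invented.
shared_kv_cache_layers = {
    "model.layers.1.self_attn.attn": "model.layers.0.self_attn.attn",
    "model.layers.3.self_attn.attn": "model.layers.2.self_attn.attn",
}

# Layers appearing as keys read and write through their target's cache,
# roughly halving the number of distinct KV cache tensors to allocate.
```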
```diff
@@ -307,6 +307,7 @@ class AscendAttentionBackendImpl(AttentionImpl):
                 device="npu")
         self.alibi_slopes = alibi_slopes
         self.attn_type = attn_type
+        self.kv_sharing_target_layer_name = kv_sharing_target_layer_name
 
         assert self.num_heads % self.num_kv_heads == 0
         self.num_queries_per_kv = self.num_heads // self.num_kv_heads
```
```diff
@@ -618,24 +619,26 @@ class AscendAttentionBackendImpl(AttentionImpl):
         if len(kv_cache) > 1:
             if self.key_cache is None:
                 self.key_cache, self.value_cache = kv_cache[0], kv_cache[1]
-            slots = attn_metadata.slot_mapping
-            if get_ascend_device_type() == AscendDeviceType.A5:
-                # TODO: Once eagle running to here, it may has error because of the 0 dim of slot_mapping.
-                # Should check if the 0 dim of slot_mapping must equal to the 0 dim of key.
-                # If it's necessary, the slots should be sliced.
-                torch_npu.npu_scatter_pa_kv_cache(
-                    key=key[:attn_metadata.num_actual_tokens],
-                    value=value[:attn_metadata.num_actual_tokens].contiguous(),
-                    key_cache=self.key_cache,
-                    value_cache=self.value_cache,
-                    slot_mapping=slots)
-            else:
-                torch_npu._npu_reshape_and_cache(
-                    key=key[:attn_metadata.num_actual_tokens],
-                    value=value[:attn_metadata.num_actual_tokens],
-                    key_cache=self.key_cache,
-                    value_cache=self.value_cache,
-                    slot_indices=slots[:attn_metadata.num_actual_tokens])
+            if self.kv_sharing_target_layer_name is None:
+                slots = attn_metadata.slot_mapping
+                if get_ascend_device_type() == AscendDeviceType.A5:
+                    # TODO: Once eagle running to here, it may has error because of the 0 dim of slot_mapping.
+                    # Should check if the 0 dim of slot_mapping must equal to the 0 dim of key.
+                    # If it's necessary, the slots should be sliced.
+                    torch_npu.npu_scatter_pa_kv_cache(
+                        key=key[:attn_metadata.num_actual_tokens],
+                        value=value[:attn_metadata.
+                                    num_actual_tokens].contiguous(),
+                        key_cache=self.key_cache,
+                        value_cache=self.value_cache,
+                        slot_mapping=slots)
+                else:
+                    torch_npu._npu_reshape_and_cache(
+                        key=key[:attn_metadata.num_actual_tokens],
+                        value=value[:attn_metadata.num_actual_tokens],
+                        key_cache=self.key_cache,
+                        value_cache=self.value_cache,
+                        slot_indices=slots[:attn_metadata.num_actual_tokens])
         return key, value
 
     def forward_impl(
```
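A minimal sketch of the guard pattern this hunk introduces, using plain torch tensors and a hypothetical helper name (`maybe_write_kv` is not part of the PR): only the layer that owns a KV cache scatters new keys and values into it, while a layer with a sharing target skips the write and reuses the target's tensors.

```python
import torch

def maybe_write_kv(key, value, key_cache, value_cache, slot_mapping,
                   kv_sharing_target_layer_name=None):
    """Hypothetical helper mirroring the guard above: a layer that shares
    another layer's KV cache must not overwrite it, so the scatter into the
    paged cache only runs for the cache-owning layer."""
    if kv_sharing_target_layer_name is None:
        key_cache.index_copy_(0, slot_mapping, key)
        value_cache.index_copy_(0, slot_mapping, value)
    return key, value

# Toy shapes: 16 cache slots, 2 KV heads, head_dim 8, 4 new tokens.
k_cache, v_cache = torch.zeros(16, 2, 8), torch.zeros(16, 2, 8)
k, v = torch.randn(4, 2, 8), torch.randn(4, 2, 8)
slots = torch.tensor([3, 7, 8, 12])

maybe_write_kv(k, v, k_cache, v_cache, slots)                   # owner layer: writes
maybe_write_kv(k, v, k_cache, v_cache, slots, "layers.0.attn")  # sharing layer: skips
```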
```diff
@@ -1195,6 +1195,10 @@ class NPUModelRunner(GPUModelRunner):
 
     def _build_attn_state(self, num_reqs, num_scheduled_tokens,
                           num_valid_tokens):
+        if self.shared_kv_cache_layers is not None:
+            # sharing kv across layers need to read the kvcache,
+            # directly return chunked prefill in this scenario
+            return AscendAttentionState.ChunkedPrefill
         if np.array_equal(self.seq_lens.np[:num_reqs], num_scheduled_tokens):
             attn_state = AscendAttentionState.PrefillNoCache
         # We assume it is the decode stage, where prefill occurs but only one token is not hit in cache.
```
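A condensed, standalone restatement of the decision added above (the enum and function below are illustrative stand-ins, not the PR's code): once any layer shares another layer's KV cache, a forward pass may need to read previously written cache entries, so the runner falls back to the chunked-prefill attention state instead of the write-only `PrefillNoCache` path.

```python
from enum import Enum, auto

class AttnState(Enum):  # stand-in for AscendAttentionState
    ChunkedPrefill = auto()
    PrefillNoCache = auto()

def build_attn_state(shared_kv_cache_layers, seq_lens, num_scheduled_tokens):
    # With cross-layer sharing, a layer may have to read KV written by its
    # target layer, which only the chunked-prefill path supports.
    # (Truthiness check here is a simplification of the `is not None` test.)
    if shared_kv_cache_layers:
        return AttnState.ChunkedPrefill
    # Otherwise a pure prefill (every scheduled token is new) can skip the cache.
    if seq_lens == num_scheduled_tokens:
        return AttnState.PrefillNoCache
    return AttnState.ChunkedPrefill

print(build_attn_state({"l1": "l0"}, [5], [5]))  # -> AttnState.ChunkedPrefill
print(build_attn_state({}, [5], [5]))            # -> AttnState.PrefillNoCache
```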
```diff
@@ -2243,6 +2247,7 @@ class NPUModelRunner(GPUModelRunner):
         kv_cache_config = deepcopy(kv_cache_config)
         self.kv_cache_config = kv_cache_config
         self.may_add_encoder_only_layers_to_kv_cache_config()
+        self.maybe_add_kv_sharing_layers_to_kv_cache_groups(kv_cache_config)
         # NOTE(cmq): initialize_attn_backend must before using self.attn_groups
         self.initialize_attn_backend(kv_cache_config)
         self.use_hybrid_blocks = (len(self.attn_groups) > 1)
```
```diff
@@ -2282,6 +2287,13 @@ class NPUModelRunner(GPUModelRunner):
         kv_caches = self._reshape_kv_cache_tensors(kv_cache_config,
                                                    kv_cache_raw_tensors)
 
+        # Set up cross-layer KV cache sharing
+        for layer_name, target_layer_name in self.shared_kv_cache_layers.items(
+        ):
+            logger.debug("%s reuses KV cache of %s", layer_name,
+                         target_layer_name)
+            kv_caches[layer_name] = kv_caches[target_layer_name]
+
         from vllm.v1.worker.utils import bind_kv_cache
         bind_kv_cache(kv_caches,
                       self.compilation_config.static_forward_context,
```
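The aliasing loop above works because dictionary assignment makes both layer names reference the same tensor object, so when `bind_kv_cache` wires the caches into the forward context, both layers see one shared buffer. A toy demonstration with made-up layer names:

```python
import torch

# Hypothetical layer names; only the aliasing behaviour is being illustrated.
kv_caches = {"layers.0.attn": torch.zeros(4, 2, 8)}
kv_caches["layers.1.attn"] = kv_caches["layers.0.attn"]

# Both entries point at the same storage ...
assert kv_caches["layers.1.attn"].data_ptr() == kv_caches["layers.0.attn"].data_ptr()

# ... so KV written by the target layer is immediately visible to the sharer.
kv_caches["layers.0.attn"][0] = 1.0
assert bool(kv_caches["layers.1.attn"][0].all())
```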