Revert [KV-Sharing] Support KV-Sharing feature in CLA models (#4138) (#5317)

### What this PR does / why we need it?
Revert [KV-Sharing] Support KV-Sharing feature in CLA models (#4138) as
it causes deepseek v3.2 hang error


- vLLM version: release/v0.13.0
- vLLM main:
5fbfa8d9ef

---------

Signed-off-by: MengqingCao <cmq0113@163.com>
This commit is contained in:
Mengqing Cao
2025-12-24 22:24:17 +08:00
committed by GitHub
parent fb3d6ca08c
commit e54630e01c
5 changed files with 18 additions and 96 deletions

View File

@@ -307,7 +307,6 @@ class AscendAttentionBackendImpl(AttentionImpl):
device="npu")
self.alibi_slopes = alibi_slopes
self.attn_type = attn_type
self.kv_sharing_target_layer_name = kv_sharing_target_layer_name
assert self.num_heads % self.num_kv_heads == 0
self.num_queries_per_kv = self.num_heads // self.num_kv_heads
@@ -619,26 +618,24 @@ class AscendAttentionBackendImpl(AttentionImpl):
if len(kv_cache) > 1:
if self.key_cache is None:
self.key_cache, self.value_cache = kv_cache[0], kv_cache[1]
if self.kv_sharing_target_layer_name is None:
slots = attn_metadata.slot_mapping
if get_ascend_device_type() == AscendDeviceType.A5:
# TODO: Once eagle running to here, it may has error because of the 0 dim of slot_mapping.
# Should check if the 0 dim of slot_mapping must equal to the 0 dim of key.
# If it's necessary, the slots should be sliced.
torch_npu.npu_scatter_pa_kv_cache(
key=key[:attn_metadata.num_actual_tokens],
value=value[:attn_metadata.
num_actual_tokens].contiguous(),
key_cache=self.key_cache,
value_cache=self.value_cache,
slot_mapping=slots)
else:
torch_npu._npu_reshape_and_cache(
key=key[:attn_metadata.num_actual_tokens],
value=value[:attn_metadata.num_actual_tokens],
key_cache=self.key_cache,
value_cache=self.value_cache,
slot_indices=slots[:attn_metadata.num_actual_tokens])
slots = attn_metadata.slot_mapping
if get_ascend_device_type() == AscendDeviceType.A5:
# TODO: Once eagle running to here, it may has error because of the 0 dim of slot_mapping.
# Should check if the 0 dim of slot_mapping must equal to the 0 dim of key.
# If it's necessary, the slots should be sliced.
torch_npu.npu_scatter_pa_kv_cache(
key=key[:attn_metadata.num_actual_tokens],
value=value[:attn_metadata.num_actual_tokens].contiguous(),
key_cache=self.key_cache,
value_cache=self.value_cache,
slot_mapping=slots)
else:
torch_npu._npu_reshape_and_cache(
key=key[:attn_metadata.num_actual_tokens],
value=value[:attn_metadata.num_actual_tokens],
key_cache=self.key_cache,
value_cache=self.value_cache,
slot_indices=slots[:attn_metadata.num_actual_tokens])
return key, value
def forward_impl(