[HybridKV][Bugfix] Fix Hybrid kvcache sharing bug in same attention type (#3760)
### What this PR does / why we need it?
Part of https://github.com/vllm-project/vllm-ascend/pull/3106
Fix a hybrid KV cache sharing bug between layers of the same attention type.
Change the `shared_by` logic so that layers with the same attention spec
share a single buffer instead of each allocating its own HBM (see the
sketch after the screenshot below).
After this PR, KV cache memory usage on Qwen3-Next drops by 50% compared
with before (`self_attn:linear_attn = 1:3` in an `attn_group`), and
`gpu_memory_utilization` can be raised to `0.8` on Qwen3-Next when
running on A2 64G/card with TP4.
![image](https://github.com/user-attachments/assets/2a91fa99-fb0f-447c-9e8b-acd587890fbe)
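For intuition, here is a minimal sketch of the new sharing rule; the layer names and the `shared_by` list are hypothetical stand-ins, not the real vllm-ascend structures:

```python
# Hypothetical sketch of the fixed shared_by rule: within one kv_cache_tensor
# group, layers of the same attention type reuse a single buffer.
shared_by = [
    "model.layers.0.self_attn",    # 1 full-attention layer
    "model.layers.1.linear_attn",  # 3 mamba/linear-attention layers
    "model.layers.2.linear_attn",
    "model.layers.3.linear_attn",
]

buffers: dict[str, int] = {}  # attention type -> buffer id
for name in shared_by:
    kind = "linear_attn" if "linear_attn" in name else "self_attn"
    # The first layer of a type allocates; later layers of that type share it.
    buffers.setdefault(kind, len(buffers))

# Before this PR the group made 4 allocations (one per layer); now it makes
# 2 (one per attention type), hence the ~50% KV cache saving at a 1:3 ratio.
print(buffers)  # {'self_attn': 0, 'linear_attn': 1}
```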
### Does this PR introduce _any_ user-facing change?
### How was this patch tested?
Tests pass with the latest e2e test case on Qwen3-Next.
- vLLM version: v0.11.0rc3
- vLLM main: c9461e05a4
---------
Signed-off-by: MengqingCao <cmq0113@163.com>
The e2e test replicates the prompts for TP4 and bumps `gpu_memory_utilization` to `0.8`:

```diff
@@ -27,12 +27,12 @@ from tests.e2e.conftest import VllmRunner
 def test_models_distributed_Qwen3_NEXT_TP4():
     example_prompts = [
         "Hello, my name is",
-    ]
+    ] * 4
     max_tokens = 5
     with VllmRunner("Qwen/Qwen3-Next-80B-A3B-Instruct",
                     tensor_parallel_size=4,
                     max_model_len=4096,
-                    gpu_memory_utilization=0.7,
+                    gpu_memory_utilization=0.8,
                     distributed_executor_backend="mp",
                     enforce_eager=True) as vllm_model:
         vllm_model.generate_greedy(example_prompts, max_tokens)
```
In `NPUModelRunner`, the linear-attention buffer is now allocated once per `shared_by` group and assigned to every `linear_attn` layer in that group, instead of being re-allocated per layer:

```diff
@@ -3225,25 +3225,26 @@ class NPUModelRunner(LoRAModelRunnerMixin):
         # TODO: REFACTOR ME to sharing hybrid cache
         for idx in range(len(kv_cache_tensor.shared_by)):
             layer_name = kv_cache_tensor.shared_by[idx]
-            if "linear_attn" in layer_name:
+            if "linear_attn" in layer_name and layer_name not in kv_cache_raw_tensors.keys(
+            ):
                 # for mamba linear attention
-                for layer_name_inner in kv_cache_tensor.shared_by:
-                    if ("attn" in layer_name_inner and "linear_attn" not in layer_name_inner) or \
-                        layer_name_inner in kv_cache_raw_tensors.keys():
-                        continue
-                    if self.vllm_config.kv_transfer_config is None:
-                        tensor = torch.zeros(kv_cache_tensor.size,
-                                             dtype=torch.int8,
-                                             device=self.device)
-                    else:
-                        cache_size_aligned = kv_cache_tensor.size + alignment
-                        tensor = torch.zeros(cache_size_aligned,
-                                             dtype=torch.int8,
-                                             device=self.device)
-                        tensor = self._align_memory(
-                            tensor, alignment)[:kv_cache_tensor.size]
-                    kv_cache_raw_tensors[layer_name_inner] = tensor
-            elif "attn" in layer_name:
+                if self.vllm_config.kv_transfer_config is None:
+                    tensor = torch.zeros(kv_cache_tensor.size,
+                                         dtype=torch.int8,
+                                         device=self.device)
+                else:
+                    cache_size_aligned = kv_cache_tensor.size + alignment
+                    tensor = torch.zeros(cache_size_aligned,
+                                         dtype=torch.int8,
+                                         device=self.device)
+                    tensor = self._align_memory(
+                        tensor, alignment)[:kv_cache_tensor.size]
+                for layer_name_inner in kv_cache_tensor.shared_by:
+                    # shared the kvcache between the linear_attn specs in the same group
+                    if "linear_attn" in layer_name_inner:
+                        kv_cache_raw_tensors[layer_name_inner] = tensor
+            elif "attn" in layer_name and layer_name not in kv_cache_raw_tensors.keys(
+            ):
                 # for other attentions, e.g., self_attn, sliding window attn
                 if self.vllm_config.kv_transfer_config is None:
                     k_tensor = torch.zeros(kv_cache_tensor.size // 2,
```
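When a KV connector is configured, the code above over-allocates by `alignment` bytes and trims to an aligned view. A minimal sketch of that pattern; the internals of the real `_align_memory` helper and the alignment value used here are assumptions:

```python
import torch

def align_memory(tensor: torch.Tensor, alignment: int) -> torch.Tensor:
    # Drop leading bytes so the returned view starts on an alignment-byte
    # boundary; a plausible reading of the over-allocate-then-trim pattern
    # in the diff (the real _align_memory internals may differ).
    offset = (-tensor.data_ptr()) % alignment
    return tensor[offset:]

alignment = 2 * 1024 * 1024  # assumed example value, not the real constant
size = 4096
raw = torch.zeros(size + alignment, dtype=torch.int8)  # over-allocate
aligned = align_memory(raw, alignment)[:size]          # trim to aligned view
assert aligned.data_ptr() % alignment == 0
assert aligned.numel() == size
```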
```diff
@@ -3265,7 +3266,12 @@ class NPUModelRunner(LoRAModelRunnerMixin):
                                                   alignment)[:cache_size]
                     v_tensor = self._align_memory(v_tensor,
                                                   alignment)[:cache_size]
-                kv_cache_raw_tensors[layer_name] = (k_tensor, v_tensor)
+                for layer_name_inner in kv_cache_tensor.shared_by:
+                    # shared the kvcache between the self_attn specs in the same group
+                    if ("attn" in layer_name_inner
+                            and "linear_attn" not in layer_name_inner):
+                        kv_cache_raw_tensors[layer_name_inner] = (k_tensor,
+                                                                  v_tensor)
 
         layer_names = set()
         for group in kv_cache_config.kv_cache_groups:
```
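The same sharing now applies to the full-attention layers: every `self_attn` layer in a `shared_by` group receives the same `(k_tensor, v_tensor)` pair. A small self-contained check of that invariant, with hypothetical layer names:

```python
import torch

# Toy check (not part of the PR): after the hunk above, all self_attn layers
# in one shared_by group alias a single K buffer and a single V buffer.
shared_by = ["layers.0.attn", "layers.4.attn", "layers.8.attn"]
kv_cache_raw_tensors: dict[str, tuple[torch.Tensor, torch.Tensor]] = {}

k_tensor = torch.zeros(512, dtype=torch.int8)  # allocated once per group
v_tensor = torch.zeros(512, dtype=torch.int8)
for layer_name_inner in shared_by:
    if "attn" in layer_name_inner and "linear_attn" not in layer_name_inner:
        kv_cache_raw_tensors[layer_name_inner] = (k_tensor, v_tensor)

# One K buffer and one V buffer back all three layers.
assert len({k.data_ptr() for k, _ in kv_cache_raw_tensors.values()}) == 1
assert len({v.data_ptr() for _, v in kv_cache_raw_tensors.values()}) == 1
```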