[HybridKV] Fix prefill disaggregation kvcache addr alignment & use hybrid kv cache only when running qwen3_next (#3007)

### What this PR does / why we need it?
This PR fixes a few issues in prefill disaggregation:
1. Fix the prefill disaggregation KV cache address alignment issue:
llmdatadist requires the addresses of the cache tensors to be aligned
to 2 MiB (see the sketch after this list).
2. Fix the prefill disaggregation KV cache shape error: llmdatadist
requires separate k/v tensors of shape [num_blocks, ...], but the
implementation before this PR used a single [2, num_blocks, ...]
tensor, which breaks prefill disaggregation.
3. Use the hybrid KV cache only when running qwen3_next, to fix an
accuracy issue in prefill disaggregation.
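
For context, the alignment fix over-allocates each cache buffer by one alignment unit and then slices off a 2 MiB-aligned view. Below is a minimal standalone sketch of the `_align_memory` helper this PR adds, applied to separately allocated k/v tensors as in the new `initialize_kv_cache_tensors`; the byte sizes and CPU allocation here are illustrative assumptions, not the runner's actual configuration:

```python
import torch

ALIGNMENT = 2 * 1024 * 1024  # llmdatadist needs 2 MiB-aligned addresses


def align_memory(tensor: torch.Tensor, alignment: int) -> torch.Tensor:
    """Return a view of `tensor` whose start address is aligned to `alignment`."""
    data_ptr = tensor.data_ptr()
    # round the start address up to the next multiple of `alignment`
    aligned_addr = (data_ptr + alignment - 1) // alignment * alignment
    offset = (aligned_addr - data_ptr) // tensor.element_size()
    return tensor[int(offset):]


# Over-allocate by one alignment unit so an aligned window always fits.
# k and v are allocated separately so each cache is [num_blocks, ...]
# rather than a single [2, num_blocks, ...] tensor.
cache_size = 4 * 1024 * 1024  # bytes per k (and per v) cache; illustrative
k_tensor = align_memory(
    torch.zeros(cache_size + ALIGNMENT, dtype=torch.int8),
    ALIGNMENT)[:cache_size]
v_tensor = align_memory(
    torch.zeros(cache_size + ALIGNMENT, dtype=torch.int8),
    ALIGNMENT)[:cache_size]
assert k_tensor.data_ptr() % ALIGNMENT == 0
assert v_tensor.data_ptr() % ALIGNMENT == 0
```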

### Does this PR introduce _any_ user-facing change?
N/A

### How was this patch tested?
Tested locally by @liziyu179.

- vLLM version: v0.10.2
- vLLM main: 4f02b77de4

---------

Signed-off-by: MengqingCao <cmq0113@163.com>
Commit: 367edff5af (parent: acb46f303f)
Author: Mengqing Cao
Date: 2025-09-18 21:43:22 +08:00
Committed by: GitHub
3 changed files with 95 additions and 46 deletions


```diff
@@ -25,7 +25,6 @@ from torch.distributed import ProcessGroup
 from torch.distributed.distributed_c10d import PrefixStore
 from vllm.logger import logger
 from vllm.platforms import Platform, PlatformEnum
-from vllm.utils import cdiv
 from vllm_ascend.ascend_config import (check_ascend_config, get_ascend_config,
                                        init_ascend_config)
@@ -247,10 +246,6 @@ class NPUPlatform(Platform):
         if cache_config:
             if cache_config.block_size is None:
                 cache_config.block_size = 128
-            else:
-                if not vllm_config.model_config.is_deepseek_mla:
-                    cache_config.block_size = cdiv(cache_config.block_size,
-                                                   64) * 64
             if cache_config.enable_prefix_caching and cache_config.block_size != 128:
                 logger.warning(
```
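
For reference, the deleted branch rounded non-MLA block sizes up to the next multiple of 64 (`cdiv` in `vllm.utils` is ceiling division). A standalone illustration of the rounding being removed, with `cdiv` inlined for self-containment:

```python
def cdiv(a: int, b: int) -> int:
    # ceiling division, equivalent to vllm.utils.cdiv
    return -(a // -b)

# the removed behavior: round block_size up to a multiple of 64
for block_size in (100, 128, 130):
    print(block_size, "->", cdiv(block_size, 64) * 64)
# 100 -> 128, 128 -> 128, 130 -> 192
```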


```diff
@@ -167,8 +167,7 @@ class BlockTable:
                     mask, slot_mapping, -1)
         else:
             assert self.kernel_sizes is not None
-            if self.block_size == self.kernel_sizes[0] or self.kernel_sizes[
-                    0] == 0:
+            if self.block_size == self.kernel_sizes[0]:
                 # IMPORTANT: In hybrid mode, positions are in logical block space,
                 # but we need to map them to the correct logical block table indices
                 logical_block_idx = positions // self.block_size
```
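
For intuition about the mapping this branch performs: positions are integer-divided by the block size to recover logical block indices. A tiny standalone example (`block_size = 128` matches the NPU default set in the platform change above):

```python
import torch

block_size = 128
positions = torch.tensor([0, 127, 128, 300])
logical_block_idx = positions // block_size
print(logical_block_idx)  # tensor([0, 0, 1, 2])
```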


```diff
@@ -280,6 +280,8 @@ class NPUModelRunner(LoRAModelRunnerMixin):
         else:
             self.kv_cache_dtype = STR_DTYPE_TO_TORCH_DTYPE[
                 self.cache_config.cache_dtype]
+        # use_hybrid_blocks: whether hybrid KV cache blocks are used.
+        self.use_hybrid_blocks: bool = False
         self.is_multimodal_model = self.model_config.is_multimodal_model
         self.is_pooling_model = self.model_config.pooler_config is not None
@@ -2440,8 +2442,9 @@ class NPUModelRunner(LoRAModelRunnerMixin):
         """
         kv_cache_config = deepcopy(kv_cache_config)
         self.kv_cache_config = kv_cache_config
-        self.may_reinitialize_input_batch(kv_cache_config)
         self.initialize_attn_backend(kv_cache_config)
+        self.use_hybrid_blocks = (len(self.attn_groups) > 1)
+        self.may_reinitialize_input_batch(kv_cache_config)
         if self.model_config.is_deepseek_mla:
             kv_caches = self.initialize_kv_cache_tensors_deepseek(
@@ -2452,6 +2455,13 @@ class NPUModelRunner(LoRAModelRunnerMixin):
         if has_kv_transfer_group():
             get_kv_transfer_group().register_kv_caches(kv_caches)

+    def _align_memory(self, tensor: torch.Tensor,
+                      alignment: int) -> torch.Tensor:
+        data_ptr = tensor.data_ptr()
+        aligned_addr = (data_ptr + alignment - 1) // alignment * alignment
+        offset = (aligned_addr - data_ptr) // tensor.element_size()
+        return tensor[int(offset):]
+
     def initialize_kv_cache_tensors_deepseek(
             self, kv_cache_config: KVCacheConfig) -> dict[str, torch.Tensor]:
         kv_cache_sizes = {}
@@ -2461,12 +2471,6 @@ class NPUModelRunner(LoRAModelRunnerMixin):
                     "NPU.")
             kv_cache_sizes[kv_cache_tensor.shared_by[0]] = kv_cache_tensor.size

-        def align_memory(tensor: torch.Tensor, alignment: int) -> torch.Tensor:
-            data_ptr = tensor.data_ptr()
-            aligned_addr = (data_ptr + alignment - 1) // alignment * alignment
-            offset = (aligned_addr - data_ptr) // tensor.element_size()
-            return tensor[int(offset):]
-
         kv_caches: Dict[str, torch.Tensor] = {}
         for kv_cache_spec, kv_cache_group in self._kv_cache_spec_attn_group_iterator(
         ):
@@ -2529,10 +2533,10 @@ class NPUModelRunner(LoRAModelRunnerMixin):
             rope_cache = torch.zeros(rope_allocate_shape_alignment,
                                      dtype=dtype,
                                      device=self.device)
-            nope_cache = align_memory(
+            nope_cache = self._align_memory(
                 nope_cache,
                 alignment)[:nope_allocate_shape].view(nope_cache_shape)
-            rope_cache = align_memory(
+            rope_cache = self._align_memory(
                 rope_cache,
                 alignment)[:rope_allocate_shape].view(rope_cache_shape)
             kv_caches[layer_name] = (nope_cache, rope_cache)
@@ -2555,7 +2559,10 @@ class NPUModelRunner(LoRAModelRunnerMixin):
             corresponding memory buffer for KV cache.
         """
         # init kv cache tensors
-        kv_cache_raw_tensors: dict[str, torch.Tensor] = {}
+        kv_cache_raw_tensors: dict[str, Union[torch.Tensor,
+                                              Optional[torch.Tensor]]] = {}
+        # llmdatadist needs the addr of the cache tensor to be aligned to 2M
+        alignment = 2 * 1024 * 1024
         for kv_cache_tensor in kv_cache_config.kv_cache_tensors:
             # TODO: REFACTOR ME to sharing hybrid cache
             for idx in range(len(kv_cache_tensor.shared_by)):
@@ -2565,15 +2572,40 @@ class NPUModelRunner(LoRAModelRunnerMixin):
                     if "self_attn" in layer_name_inner or layer_name_inner in kv_cache_raw_tensors.keys(
                     ):
                         continue
-                    tensor = torch.zeros(kv_cache_tensor.size,
-                                         dtype=torch.int8,
-                                         device=self.device)
+                    if self.vllm_config.kv_transfer_config is None:
+                        tensor = torch.zeros(kv_cache_tensor.size,
+                                             dtype=torch.int8,
+                                             device=self.device)
+                    else:
+                        cache_size_aligned = kv_cache_tensor.size + alignment
+                        tensor = torch.zeros(cache_size_aligned,
+                                             dtype=torch.int8,
+                                             device=self.device)
+                        tensor = self._align_memory(
+                            tensor, alignment)[:kv_cache_tensor.size]
                     kv_cache_raw_tensors[layer_name_inner] = tensor
                 elif "self_attn" in layer_name:
-                    tensor = torch.zeros(kv_cache_tensor.size,
-                                         dtype=torch.int8,
-                                         device=self.device)
-                    kv_cache_raw_tensors[layer_name] = tensor
+                    if self.vllm_config.kv_transfer_config is None:
+                        k_tensor = torch.zeros(kv_cache_tensor.size // 2,
+                                               dtype=torch.int8,
+                                               device=self.device)
+                        v_tensor = torch.zeros(kv_cache_tensor.size // 2,
+                                               dtype=torch.int8,
+                                               device=self.device)
+                    else:
+                        cache_size = kv_cache_tensor.size // 2
+                        cache_size_aligned = kv_cache_tensor.size // 2 + alignment
+                        k_tensor = torch.zeros(cache_size_aligned,
+                                               dtype=torch.int8,
+                                               device=self.device)
+                        v_tensor = torch.zeros(cache_size_aligned,
+                                               dtype=torch.int8,
+                                               device=self.device)
+                        k_tensor = self._align_memory(k_tensor,
+                                                      alignment)[:cache_size]
+                        v_tensor = self._align_memory(v_tensor,
+                                                      alignment)[:cache_size]
+                    kv_cache_raw_tensors[layer_name] = (k_tensor, v_tensor)

         layer_names = set()
         for group in kv_cache_config.kv_cache_groups:
@@ -2591,24 +2623,28 @@ class NPUModelRunner(LoRAModelRunnerMixin):
             for layer_name in kv_cache_group.layer_names:
                 if layer_name in self.runner_only_attn_layers:
                     continue
-                raw_tensor = kv_cache_raw_tensors[layer_name]
-                assert raw_tensor.numel() % kv_cache_spec.page_size_bytes == 0
-                num_blocks = raw_tensor.numel(
-                ) // kv_cache_spec.page_size_bytes
-                # `num_blocks` is the number of blocks the model runner can use.
-                # `kv_cache_config.num_blocks` is the number of blocks that
-                # KVCacheManager may allocate.
-                # Since different GPUs may have different number of layers and
-                # different memory capacities, `num_blocks` can be different on
-                # different GPUs, and `kv_cache_config.num_blocks` is set to
-                # the min of all `num_blocks`. Verify it here.
-                assert num_blocks >= kv_cache_config.num_blocks
-                # TODO: remove this after the OOM issue is located and fixed, otherwise, some model may
-                # encounter OOM issue
                 if isinstance(kv_cache_spec, FullAttentionSpec):
+                    raw_k_tensor, raw_v_tensor = kv_cache_raw_tensors[  # type: ignore
+                        layer_name]
+                    assert raw_k_tensor is not None
+                    assert raw_v_tensor is not None
+                    assert (raw_k_tensor.numel() + raw_v_tensor.numel()
+                            ) % kv_cache_spec.page_size_bytes == 0
+                    num_blocks = (raw_k_tensor.numel() + raw_v_tensor.numel()
+                                  ) // kv_cache_spec.page_size_bytes
+                    # `num_blocks` is the number of blocks the model runner can use.
+                    # `kv_cache_config.num_blocks` is the number of blocks that
+                    # KVCacheManager may allocate.
+                    # Since different GPUs may have different number of layers and
+                    # different memory capacities, `num_blocks` can be different on
+                    # different GPUs, and `kv_cache_config.num_blocks` is set to
+                    # the min of all `num_blocks`. Verify it here.
+                    assert num_blocks >= kv_cache_config.num_blocks
                     if self.vllm_config.additional_config.get(
                             "kv_cache_dtype", None) == 'int8':
                         kv_cache_shape = attn_backend.get_bsh_kv_cache_shape(
@@ -2616,8 +2652,9 @@
                             kv_cache_spec.num_kv_heads,
                             kv_cache_spec.head_size)
                     elif hasattr(attn_backend, "get_supported_block_size"
-                                 ) and not self.model_config.is_deepseek_mla:
+                                 ) and self.use_hybrid_blocks:
                         block_size = attn_backend.get_supported_block_size()[0]
+                        block_size_chunk = kv_cache_spec.block_size // block_size
                         kv_cache_shape = attn_backend.get_kv_cache_shape(
                             num_blocks * block_size_chunk, block_size,
@@ -2629,11 +2666,28 @@
                             kv_cache_spec.num_kv_heads,
                             kv_cache_spec.head_size)
                     dtype = kv_cache_spec.dtype
-                    kv_cache = raw_tensor.view(dtype).view(kv_cache_shape)
-                    kv_cache = self._convert_torch_format(kv_cache)
-                    kv_caches[layer_name] = kv_cache
+                    k_cache = raw_k_tensor.view(dtype).view(kv_cache_shape[1:])
+                    k_cache = self._convert_torch_format(k_cache)
+                    v_cache = raw_v_tensor.view(dtype).view(kv_cache_shape[1:])
+                    v_cache = self._convert_torch_format(v_cache)
+                    kv_caches[layer_name] = (k_cache, v_cache)
                 elif isinstance(kv_cache_spec, MambaSpec):
+                    raw_tensor = kv_cache_raw_tensors[layer_name]
+                    assert raw_tensor is not None
+                    assert raw_tensor.numel(
+                    ) % kv_cache_spec.page_size_bytes == 0
+                    num_blocks = raw_tensor.numel(
+                    ) // kv_cache_spec.page_size_bytes
+                    # `num_blocks` is the number of blocks the model runner can use.
+                    # `kv_cache_config.num_blocks` is the number of blocks that
+                    # KVCacheManager may allocate.
+                    # Since different GPUs may have different number of layers and
+                    # different memory capacities, `num_blocks` can be different on
+                    # different GPUs, and `kv_cache_config.num_blocks` is set to
+                    # the min of all `num_blocks`. Verify it here.
+                    assert num_blocks >= kv_cache_config.num_blocks
                     state_tensors = []
                     storage_offset_bytes = 0
                     for (shape, dtype) in zip(kv_cache_spec.shapes,
@@ -2702,7 +2756,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
                 attn_groups = self.attn_groups[kv_cache_group_id]
             except IndexError:
                 attn_groups = None
-            if attn_groups:
+            if attn_groups and self.use_hybrid_blocks:
                 # Use the backend's supported block size list
                 backend = attn_groups[0].backend
                 supported_sizes = backend.get_supported_block_size()
@@ -2713,13 +2767,14 @@
                         [self.cache_config.block_size])
                 else:
                     # Fallback to cache config block_size if no backend found
-                    kernel_block_size_list = [
-                        64
-                    ] if not self.model_config.is_deepseek_mla else [0]
+                    kernel_block_size_list = [self.cache_config.block_size]
                     kernel_block_sizes.append(kernel_block_size_list)
             else:
                 # This is likely Mamba or other non-attention cache,
                 # no splitting.
+                # NOTE: set kernel_block_sizes to [0] to disable slot-mapping
+                # computation for mamba blocks; BlockTable.block_size will then
+                # never equal kernel_block_sizes[0].
                 kernel_block_sizes.append([0])
         if kernel_block_sizes != [self.cache_config.block_size]:
             assert self.cache_config.cpu_offload_gb == 0, (
```