[HybridKV] Fix prefill disaggregation kvcache addr alignment & use hybrid kv cache only when running qwen3_next (#3007)
### What this PR does / why we need it?
This PR fixes a few issues with prefill disaggregation:
1. Fix the prefill disaggregation kvcache addr alignment issue: llmdatadist
needs the addresses of tensors to be aligned to 2 MB.
2. Fix the prefill disaggregation kvcache shape error: llmdatadist requires
k/v tensors with shape [num_blocks, ...]; however, the implementation before
this PR used [2, num_blocks, ...], which breaks prefill disaggregation.
3. Use the hybrid kv cache only when running qwen3_next, to fix an accuracy
issue with prefill disaggregation.
### Does this PR introduce _any_ user-facing change?
N/A
### How was this patch tested?
Tested locally by @liziyu179
- vLLM version: v0.10.2
- vLLM main:
4f02b77de4
---------
Signed-off-by: MengqingCao <cmq0113@163.com>
This commit is contained in:
@@ -25,7 +25,6 @@ from torch.distributed import ProcessGroup
|
|||||||
from torch.distributed.distributed_c10d import PrefixStore
|
from torch.distributed.distributed_c10d import PrefixStore
|
||||||
from vllm.logger import logger
|
from vllm.logger import logger
|
||||||
from vllm.platforms import Platform, PlatformEnum
|
from vllm.platforms import Platform, PlatformEnum
|
||||||
from vllm.utils import cdiv
|
|
||||||
|
|
||||||
from vllm_ascend.ascend_config import (check_ascend_config, get_ascend_config,
|
from vllm_ascend.ascend_config import (check_ascend_config, get_ascend_config,
|
||||||
init_ascend_config)
|
init_ascend_config)
|
||||||
@@ -247,10 +246,6 @@ class NPUPlatform(Platform):
|
|||||||
if cache_config:
|
if cache_config:
|
||||||
if cache_config.block_size is None:
|
if cache_config.block_size is None:
|
||||||
cache_config.block_size = 128
|
cache_config.block_size = 128
|
||||||
else:
|
|
||||||
if not vllm_config.model_config.is_deepseek_mla:
|
|
||||||
cache_config.block_size = cdiv(cache_config.block_size,
|
|
||||||
64) * 64
|
|
||||||
|
|
||||||
if cache_config.enable_prefix_caching and cache_config.block_size != 128:
|
if cache_config.enable_prefix_caching and cache_config.block_size != 128:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
|
|||||||
@@ -167,8 +167,7 @@ class BlockTable:
|
|||||||
mask, slot_mapping, -1)
|
mask, slot_mapping, -1)
|
||||||
else:
|
else:
|
||||||
assert self.kernel_sizes is not None
|
assert self.kernel_sizes is not None
|
||||||
if self.block_size == self.kernel_sizes[0] or self.kernel_sizes[
|
if self.block_size == self.kernel_sizes[0]:
|
||||||
0] == 0:
|
|
||||||
# IMPORTANT: In hybrid mode, positions are in logical block space,
|
# IMPORTANT: In hybrid mode, positions are in logical block space,
|
||||||
# but we need to map them to the correct logical block table indices
|
# but we need to map them to the correct logical block table indices
|
||||||
logical_block_idx = positions // self.block_size
|
logical_block_idx = positions // self.block_size
|
||||||
|
|||||||
@@ -280,6 +280,8 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
else:
|
else:
|
||||||
self.kv_cache_dtype = STR_DTYPE_TO_TORCH_DTYPE[
|
self.kv_cache_dtype = STR_DTYPE_TO_TORCH_DTYPE[
|
||||||
self.cache_config.cache_dtype]
|
self.cache_config.cache_dtype]
|
||||||
|
# use_hybrid_blocks: if hybrid blocks is used.
|
||||||
|
self.use_hybrid_blocks: bool = False
|
||||||
|
|
||||||
self.is_multimodal_model = self.model_config.is_multimodal_model
|
self.is_multimodal_model = self.model_config.is_multimodal_model
|
||||||
self.is_pooling_model = self.model_config.pooler_config is not None
|
self.is_pooling_model = self.model_config.pooler_config is not None
|
||||||
@@ -2440,8 +2442,9 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
"""
|
"""
|
||||||
kv_cache_config = deepcopy(kv_cache_config)
|
kv_cache_config = deepcopy(kv_cache_config)
|
||||||
self.kv_cache_config = kv_cache_config
|
self.kv_cache_config = kv_cache_config
|
||||||
self.may_reinitialize_input_batch(kv_cache_config)
|
|
||||||
self.initialize_attn_backend(kv_cache_config)
|
self.initialize_attn_backend(kv_cache_config)
|
||||||
|
self.use_hybrid_blocks = (len(self.attn_groups) > 1)
|
||||||
|
self.may_reinitialize_input_batch(kv_cache_config)
|
||||||
|
|
||||||
if self.model_config.is_deepseek_mla:
|
if self.model_config.is_deepseek_mla:
|
||||||
kv_caches = self.initialize_kv_cache_tensors_deepseek(
|
kv_caches = self.initialize_kv_cache_tensors_deepseek(
|
||||||
@@ -2452,6 +2455,13 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
if has_kv_transfer_group():
|
if has_kv_transfer_group():
|
||||||
get_kv_transfer_group().register_kv_caches(kv_caches)
|
get_kv_transfer_group().register_kv_caches(kv_caches)
|
||||||
|
|
||||||
|
def _align_memory(self, tensor: torch.Tensor,
|
||||||
|
alignment: int) -> torch.Tensor:
|
||||||
|
data_ptr = tensor.data_ptr()
|
||||||
|
aligned_addr = (data_ptr + alignment - 1) // alignment * alignment
|
||||||
|
offset = (aligned_addr - data_ptr) // tensor.element_size()
|
||||||
|
return tensor[int(offset):]
|
||||||
|
|
||||||
def initialize_kv_cache_tensors_deepseek(
|
def initialize_kv_cache_tensors_deepseek(
|
||||||
self, kv_cache_config: KVCacheConfig) -> dict[str, torch.Tensor]:
|
self, kv_cache_config: KVCacheConfig) -> dict[str, torch.Tensor]:
|
||||||
kv_cache_sizes = {}
|
kv_cache_sizes = {}
|
||||||
@@ -2461,12 +2471,6 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
"NPU.")
|
"NPU.")
|
||||||
kv_cache_sizes[kv_cache_tensor.shared_by[0]] = kv_cache_tensor.size
|
kv_cache_sizes[kv_cache_tensor.shared_by[0]] = kv_cache_tensor.size
|
||||||
|
|
||||||
def align_memory(tensor: torch.Tensor, alignment: int) -> torch.Tensor:
|
|
||||||
data_ptr = tensor.data_ptr()
|
|
||||||
aligned_addr = (data_ptr + alignment - 1) // alignment * alignment
|
|
||||||
offset = (aligned_addr - data_ptr) // tensor.element_size()
|
|
||||||
return tensor[int(offset):]
|
|
||||||
|
|
||||||
kv_caches: Dict[str, torch.Tensor] = {}
|
kv_caches: Dict[str, torch.Tensor] = {}
|
||||||
for kv_cache_spec, kv_cache_group in self._kv_cache_spec_attn_group_iterator(
|
for kv_cache_spec, kv_cache_group in self._kv_cache_spec_attn_group_iterator(
|
||||||
):
|
):
|
||||||
@@ -2529,10 +2533,10 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
rope_cache = torch.zeros(rope_allocate_shape_alignment,
|
rope_cache = torch.zeros(rope_allocate_shape_alignment,
|
||||||
dtype=dtype,
|
dtype=dtype,
|
||||||
device=self.device)
|
device=self.device)
|
||||||
nope_cache = align_memory(
|
nope_cache = self._align_memory(
|
||||||
nope_cache,
|
nope_cache,
|
||||||
alignment)[:nope_allocate_shape].view(nope_cache_shape)
|
alignment)[:nope_allocate_shape].view(nope_cache_shape)
|
||||||
rope_cache = align_memory(
|
rope_cache = self._align_memory(
|
||||||
rope_cache,
|
rope_cache,
|
||||||
alignment)[:rope_allocate_shape].view(rope_cache_shape)
|
alignment)[:rope_allocate_shape].view(rope_cache_shape)
|
||||||
kv_caches[layer_name] = (nope_cache, rope_cache)
|
kv_caches[layer_name] = (nope_cache, rope_cache)
|
||||||
@@ -2555,7 +2559,10 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
corresponding memory buffer for KV cache.
|
corresponding memory buffer for KV cache.
|
||||||
"""
|
"""
|
||||||
# init kv cache tensors
|
# init kv cache tensors
|
||||||
kv_cache_raw_tensors: dict[str, torch.Tensor] = {}
|
kv_cache_raw_tensors: dict[str, Union[torch.Tensor,
|
||||||
|
Optional[torch.Tensor]]] = {}
|
||||||
|
# llmdatadist need the addr of cache tensor be aligned with 2M
|
||||||
|
alignment = 2 * 1024 * 1024
|
||||||
for kv_cache_tensor in kv_cache_config.kv_cache_tensors:
|
for kv_cache_tensor in kv_cache_config.kv_cache_tensors:
|
||||||
# TODO: REFACTOR ME to sharing hybrid cache
|
# TODO: REFACTOR ME to sharing hybrid cache
|
||||||
for idx in range(len(kv_cache_tensor.shared_by)):
|
for idx in range(len(kv_cache_tensor.shared_by)):
|
||||||
@@ -2565,15 +2572,40 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
if "self_attn" in layer_name_inner or layer_name_inner in kv_cache_raw_tensors.keys(
|
if "self_attn" in layer_name_inner or layer_name_inner in kv_cache_raw_tensors.keys(
|
||||||
):
|
):
|
||||||
continue
|
continue
|
||||||
tensor = torch.zeros(kv_cache_tensor.size,
|
if self.vllm_config.kv_transfer_config is None:
|
||||||
dtype=torch.int8,
|
tensor = torch.zeros(kv_cache_tensor.size,
|
||||||
device=self.device)
|
dtype=torch.int8,
|
||||||
|
device=self.device)
|
||||||
|
else:
|
||||||
|
cache_size_aligned = kv_cache_tensor.size + alignment
|
||||||
|
tensor = torch.zeros(cache_size_aligned,
|
||||||
|
dtype=torch.int8,
|
||||||
|
device=self.device)
|
||||||
|
tensor = self._align_memory(
|
||||||
|
tensor, alignment)[:kv_cache_tensor.size]
|
||||||
kv_cache_raw_tensors[layer_name_inner] = tensor
|
kv_cache_raw_tensors[layer_name_inner] = tensor
|
||||||
elif "self_attn" in layer_name:
|
elif "self_attn" in layer_name:
|
||||||
tensor = torch.zeros(kv_cache_tensor.size,
|
if self.vllm_config.kv_transfer_config is None:
|
||||||
dtype=torch.int8,
|
k_tensor = torch.zeros(kv_cache_tensor.size // 2,
|
||||||
device=self.device)
|
dtype=torch.int8,
|
||||||
kv_cache_raw_tensors[layer_name] = tensor
|
device=self.device)
|
||||||
|
v_tensor = torch.zeros(kv_cache_tensor.size // 2,
|
||||||
|
dtype=torch.int8,
|
||||||
|
device=self.device)
|
||||||
|
else:
|
||||||
|
cache_size = kv_cache_tensor.size // 2
|
||||||
|
cache_size_aligned = kv_cache_tensor.size // 2 + alignment
|
||||||
|
k_tensor = torch.zeros(cache_size_aligned,
|
||||||
|
dtype=torch.int8,
|
||||||
|
device=self.device)
|
||||||
|
v_tensor = torch.zeros(cache_size_aligned,
|
||||||
|
dtype=torch.int8,
|
||||||
|
device=self.device)
|
||||||
|
k_tensor = self._align_memory(k_tensor,
|
||||||
|
alignment)[:cache_size]
|
||||||
|
v_tensor = self._align_memory(v_tensor,
|
||||||
|
alignment)[:cache_size]
|
||||||
|
kv_cache_raw_tensors[layer_name] = (k_tensor, v_tensor)
|
||||||
|
|
||||||
layer_names = set()
|
layer_names = set()
|
||||||
for group in kv_cache_config.kv_cache_groups:
|
for group in kv_cache_config.kv_cache_groups:
|
||||||
@@ -2591,24 +2623,28 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
for layer_name in kv_cache_group.layer_names:
|
for layer_name in kv_cache_group.layer_names:
|
||||||
if layer_name in self.runner_only_attn_layers:
|
if layer_name in self.runner_only_attn_layers:
|
||||||
continue
|
continue
|
||||||
raw_tensor = kv_cache_raw_tensors[layer_name]
|
|
||||||
|
|
||||||
assert raw_tensor.numel() % kv_cache_spec.page_size_bytes == 0
|
|
||||||
num_blocks = raw_tensor.numel(
|
|
||||||
) // kv_cache_spec.page_size_bytes
|
|
||||||
|
|
||||||
# `num_blocks` is the number of blocks the model runner can use.
|
|
||||||
# `kv_cache_config.num_blocks` is the number of blocks that
|
|
||||||
# KVCacheManager may allocate.
|
|
||||||
# Since different GPUs may have different number of layers and
|
|
||||||
# different memory capacities, `num_blocks` can be different on
|
|
||||||
# different GPUs, and `kv_cache_config.num_blocks` is set to
|
|
||||||
# the min of all `num_blocks`. Verify it here.
|
|
||||||
assert num_blocks >= kv_cache_config.num_blocks
|
|
||||||
|
|
||||||
# TODO: remove this after the OOM issue is located and fixed, otherwise, some model may
|
# TODO: remove this after the OOM issue is located and fixed, otherwise, some model may
|
||||||
# encounter OOM issue
|
# encounter OOM issue
|
||||||
if isinstance(kv_cache_spec, FullAttentionSpec):
|
if isinstance(kv_cache_spec, FullAttentionSpec):
|
||||||
|
raw_k_tensor, raw_v_tensor = kv_cache_raw_tensors[ # type: ignore
|
||||||
|
layer_name]
|
||||||
|
assert raw_k_tensor is not None
|
||||||
|
assert raw_v_tensor is not None
|
||||||
|
assert (raw_k_tensor.numel() + raw_v_tensor.numel()
|
||||||
|
) % kv_cache_spec.page_size_bytes == 0
|
||||||
|
num_blocks = (raw_k_tensor.numel() + raw_v_tensor.numel()
|
||||||
|
) // kv_cache_spec.page_size_bytes
|
||||||
|
|
||||||
|
# `num_blocks` is the number of blocks the model runner can use.
|
||||||
|
# `kv_cache_config.num_blocks` is the number of blocks that
|
||||||
|
# KVCacheManager may allocate.
|
||||||
|
# Since different GPUs may have different number of layers and
|
||||||
|
# different memory capacities, `num_blocks` can be different on
|
||||||
|
# different GPUs, and `kv_cache_config.num_blocks` is set to
|
||||||
|
# the min of all `num_blocks`. Verify it here.
|
||||||
|
assert num_blocks >= kv_cache_config.num_blocks
|
||||||
|
|
||||||
if self.vllm_config.additional_config.get(
|
if self.vllm_config.additional_config.get(
|
||||||
"kv_cache_dtype", None) == 'int8':
|
"kv_cache_dtype", None) == 'int8':
|
||||||
kv_cache_shape = attn_backend.get_bsh_kv_cache_shape(
|
kv_cache_shape = attn_backend.get_bsh_kv_cache_shape(
|
||||||
@@ -2616,8 +2652,9 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
kv_cache_spec.num_kv_heads,
|
kv_cache_spec.num_kv_heads,
|
||||||
kv_cache_spec.head_size)
|
kv_cache_spec.head_size)
|
||||||
elif hasattr(attn_backend, "get_supported_block_size"
|
elif hasattr(attn_backend, "get_supported_block_size"
|
||||||
) and not self.model_config.is_deepseek_mla:
|
) and self.use_hybrid_blocks:
|
||||||
block_size = attn_backend.get_supported_block_size()[0]
|
block_size = attn_backend.get_supported_block_size()[0]
|
||||||
|
|
||||||
block_size_chunk = kv_cache_spec.block_size // block_size
|
block_size_chunk = kv_cache_spec.block_size // block_size
|
||||||
kv_cache_shape = attn_backend.get_kv_cache_shape(
|
kv_cache_shape = attn_backend.get_kv_cache_shape(
|
||||||
num_blocks * block_size_chunk, block_size,
|
num_blocks * block_size_chunk, block_size,
|
||||||
@@ -2629,11 +2666,28 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
kv_cache_spec.num_kv_heads,
|
kv_cache_spec.num_kv_heads,
|
||||||
kv_cache_spec.head_size)
|
kv_cache_spec.head_size)
|
||||||
dtype = kv_cache_spec.dtype
|
dtype = kv_cache_spec.dtype
|
||||||
kv_cache = raw_tensor.view(dtype).view(kv_cache_shape)
|
k_cache = raw_k_tensor.view(dtype).view(kv_cache_shape[1:])
|
||||||
kv_cache = self._convert_torch_format(kv_cache)
|
k_cache = self._convert_torch_format(k_cache)
|
||||||
kv_caches[layer_name] = kv_cache
|
v_cache = raw_v_tensor.view(dtype).view(kv_cache_shape[1:])
|
||||||
|
v_cache = self._convert_torch_format(v_cache)
|
||||||
|
kv_caches[layer_name] = (k_cache, v_cache)
|
||||||
elif isinstance(kv_cache_spec, MambaSpec):
|
elif isinstance(kv_cache_spec, MambaSpec):
|
||||||
raw_tensor = kv_cache_raw_tensors[layer_name]
|
raw_tensor = kv_cache_raw_tensors[layer_name]
|
||||||
|
assert raw_tensor is not None
|
||||||
|
assert raw_tensor.numel(
|
||||||
|
) % kv_cache_spec.page_size_bytes == 0
|
||||||
|
num_blocks = raw_tensor.numel(
|
||||||
|
) // kv_cache_spec.page_size_bytes
|
||||||
|
|
||||||
|
# `num_blocks` is the number of blocks the model runner can use.
|
||||||
|
# `kv_cache_config.num_blocks` is the number of blocks that
|
||||||
|
# KVCacheManager may allocate.
|
||||||
|
# Since different GPUs may have different number of layers and
|
||||||
|
# different memory capacities, `num_blocks` can be different on
|
||||||
|
# different GPUs, and `kv_cache_config.num_blocks` is set to
|
||||||
|
# the min of all `num_blocks`. Verify it here.
|
||||||
|
assert num_blocks >= kv_cache_config.num_blocks
|
||||||
|
|
||||||
state_tensors = []
|
state_tensors = []
|
||||||
storage_offset_bytes = 0
|
storage_offset_bytes = 0
|
||||||
for (shape, dtype) in zip(kv_cache_spec.shapes,
|
for (shape, dtype) in zip(kv_cache_spec.shapes,
|
||||||
@@ -2702,7 +2756,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
attn_groups = self.attn_groups[kv_cache_group_id]
|
attn_groups = self.attn_groups[kv_cache_group_id]
|
||||||
except IndexError:
|
except IndexError:
|
||||||
attn_groups = None
|
attn_groups = None
|
||||||
if attn_groups:
|
if attn_groups and self.use_hybrid_blocks:
|
||||||
# Use the backend's supported block size list
|
# Use the backend's supported block size list
|
||||||
backend = attn_groups[0].backend
|
backend = attn_groups[0].backend
|
||||||
supported_sizes = backend.get_supported_block_size()
|
supported_sizes = backend.get_supported_block_size()
|
||||||
@@ -2713,13 +2767,14 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
[self.cache_config.block_size])
|
[self.cache_config.block_size])
|
||||||
else:
|
else:
|
||||||
# Fallback to cache config block_size if no backend found
|
# Fallback to cache config block_size if no backend found
|
||||||
kernel_block_size_list = [
|
kernel_block_size_list = [self.cache_config.block_size]
|
||||||
64
|
|
||||||
] if not self.model_config.is_deepseek_mla else [0]
|
|
||||||
kernel_block_sizes.append(kernel_block_size_list)
|
kernel_block_sizes.append(kernel_block_size_list)
|
||||||
else:
|
else:
|
||||||
# This is likely Mamba or other non-attention cache,
|
# This is likely Mamba or other non-attention cache,
|
||||||
# no splitting.
|
# no splitting.
|
||||||
|
# NOTE: set kernel_block_sizes to 0 to disable slotmapping computation
|
||||||
|
# of mamba block. In this case, BlockTable.block_size will never equal
|
||||||
|
# to kernel_block_sizes[0]
|
||||||
kernel_block_sizes.append([0])
|
kernel_block_sizes.append([0])
|
||||||
if kernel_block_sizes != [self.cache_config.block_size]:
|
if kernel_block_sizes != [self.cache_config.block_size]:
|
||||||
assert self.cache_config.cpu_offload_gb == 0, (
|
assert self.cache_config.cpu_offload_gb == 0, (
|
||||||
|
|||||||
Reference in New Issue
Block a user