[bugfix] adapt to new implemented get_kv_cache_spec in cpuoffload connector (#4311)

### What this PR does / why we need it?
func `get_kv_cache_spec` in model_runner changed a lot and caused error
in cpuoffloading connector which is copied from model_runner, this PR
adapts to new implemented `get_kv_cache_spec` to fix it.

### How was this patch tested?

- vLLM version: v0.11.0
- vLLM main:
2918c1b49c

Signed-off-by: lidenghui <lidenghui1110@gmail.com>
This commit is contained in:
lidenghui1110
2026-01-08 09:15:09 +08:00
committed by GitHub
parent f7db812ed7
commit 481138e1d2
2 changed files with 85 additions and 31 deletions

View File

@@ -10,18 +10,20 @@ from typing import TYPE_CHECKING, Any, Optional, Sequence
import torch import torch
from vllm.attention.backends.abstract import AttentionType from vllm.attention.backends.abstract import AttentionType
from vllm.attention.layer import Attention from vllm.attention.layer import Attention, MLAAttention
from vllm.config import VllmConfig from vllm.config import VllmConfig, get_layers_from_vllm_config
from vllm.distributed.ec_transfer import get_ec_transfer, has_ec_transfer
from vllm.distributed.kv_transfer.kv_connector.v1.base import ( from vllm.distributed.kv_transfer.kv_connector.v1.base import (
KVConnectorBase_V1, KVConnectorMetadata, KVConnectorRole) KVConnectorBase_V1, KVConnectorMetadata, KVConnectorRole)
from vllm.distributed.parallel_state import get_pp_group, get_tp_group from vllm.distributed.parallel_state import get_pp_group, get_tp_group
from vllm.logger import logger from vllm.logger import logger
from vllm.model_executor.layers.fused_moe import FusedMoE from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
from vllm.model_executor.layers.mamba.abstract import MambaBase
from vllm.utils.torch_utils import STR_DTYPE_TO_TORCH_DTYPE
from vllm.v1.core.sched.output import SchedulerOutput from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheSpec, from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheSpec,
MLAAttentionSpec) MambaSpec, MLAAttentionSpec)
from vllm_ascend.ascend_config import get_ascend_config
from vllm_ascend.distributed.cpu_offload_manager.metadata import ( from vllm_ascend.distributed.cpu_offload_manager.metadata import (
MetadataServer, MetadataServerProc, MLAConfig) MetadataServer, MetadataServerProc, MLAConfig)
@@ -435,41 +437,92 @@ class CPUOffloadingConnectorWorker:
save_block_mapping.clear() save_block_mapping.clear()
# Copied from vllm_ascend/worker/model_runner_v1.py. # copied and modified from vllm_ascend/worker/model_runner_v1.py
def get_kv_cache_spec(vllm_config: VllmConfig) -> dict[str, KVCacheSpec]: def get_kv_cache_spec(vllm_config: VllmConfig) -> dict[str, KVCacheSpec]:
forward_ctx = vllm_config.compilation_config.static_forward_context """
Generates the KVCacheSpec by parsing the kv cache format from each
Attention module in the static forward context.
Returns:
KVCacheSpec: A dictionary mapping layer names to their KV cache
format. Layers that do not need KV cache are not included.
"""
if has_ec_transfer() and get_ec_transfer().is_producer:
return {}
block_size = vllm_config.cache_config.block_size block_size = vllm_config.cache_config.block_size
use_mla = vllm_config.model_config.use_mla use_mla = vllm_config.model_config.use_mla
ascend_config = get_ascend_config() use_sparse = hasattr(vllm_config.model_config.hf_config, "index_topk")
use_sfa = ascend_config.use_sfa if vllm_config.cache_config.cache_dtype == "auto":
kv_cache_spec: dict[str, KVCacheSpec] = {} kv_cache_dtype = vllm_config.model_config.dtype
for layer_name, attn_module in forward_ctx.items():
if isinstance(attn_module, FusedMoE):
continue
assert isinstance(attn_module, Attention)
if attn_module.attn_type == AttentionType.DECODER:
if use_mla and not use_sfa:
kv_cache_spec[layer_name] = MLAAttentionSpec(
block_size=block_size,
num_kv_heads=attn_module.num_kv_heads,
head_size=attn_module.head_size,
dtype=attn_module.dtype,
cache_dtype_str=vllm_config.cache_config.cache_dtype)
else: else:
# TODO(cmq): This is a hack way to fix deepseek kvcache when kv_cache_dtype = STR_DTYPE_TO_TORCH_DTYPE[
# using DSA. Fix the spec in vLLM is a finnal way. vllm_config.cache_config.cache_dtype]
kv_cache_spec: dict[str, KVCacheSpec] = {}
attn_layers = get_layers_from_vllm_config(vllm_config, AttentionLayerBase)
for layer_name, attn_module in attn_layers.items():
if isinstance(attn_module, Attention):
# TODO: Support other attention modules, e.g., cross-attention
# TODO(lucas): move the attention specs into the model layers like
# the attention backends
if attn_module.attn_type == AttentionType.DECODER:
kv_cache_spec[layer_name] = FullAttentionSpec( kv_cache_spec[layer_name] = FullAttentionSpec(
block_size=block_size, block_size=block_size,
num_kv_heads=attn_module.num_kv_heads, num_kv_heads=attn_module.num_kv_heads,
head_size=attn_module.head_size, head_size=attn_module.head_size,
dtype=attn_module.dtype) dtype=kv_cache_dtype)
elif attn_module.attn_type in (AttentionType.ENCODER, elif attn_module.attn_type in (AttentionType.ENCODER,
AttentionType.ENCODER_ONLY): AttentionType.ENCODER_ONLY):
# encoder-only attention does not need KV cache.
continue continue
elif attn_module.attn_type == AttentionType.ENCODER_DECODER: elif attn_module.attn_type == AttentionType.ENCODER_DECODER:
raise NotImplementedError raise NotImplementedError
else: else:
raise ValueError( raise ValueError(
f"Unknown attention type: {attn_module.attn_type}") f"Unknown attention type: {attn_module.attn_type}")
elif isinstance(attn_module, MLAAttention):
if use_mla and not use_sparse:
kv_cache_spec[layer_name] = MLAAttentionSpec(
block_size=block_size,
num_kv_heads=1,
head_size=attn_module.head_size,
dtype=kv_cache_dtype,
cache_dtype_str=vllm_config.cache_config.cache_dtype)
else:
# TODO(cmq): This is a hack way to fix deepseek kvcache when
# using DSA. Fix the spec in vLLM is a finnal way.
kv_cache_spec[layer_name] = FullAttentionSpec(
block_size=block_size,
num_kv_heads=1,
head_size=attn_module.head_size,
dtype=kv_cache_dtype)
mamba_layers = get_layers_from_vllm_config(vllm_config, MambaBase)
if len(mamba_layers) > 0:
if (vllm_config.speculative_config is not None
and vllm_config.model_config.hf_config.model_type
not in ["qwen3_next"]):
raise NotImplementedError(
"Mamba with speculative decoding is not supported yet.")
if vllm_config.cache_config.enable_prefix_caching:
raise NotImplementedError(
"Prefix caching is not supported for Mamba yet.")
max_model_len = vllm_config.model_config.max_model_len
page_size_padded = (vllm_config.cache_config.mamba_page_size_padded)
# Set block_size to max_model_len, so that mamba model will always
# have only one block in the KV cache.
for layer_name, mamba_module in mamba_layers.items():
kv_cache_spec[layer_name] = MambaSpec(
shapes=mamba_module.get_state_shape(),
dtypes=mamba_module.get_state_dtype(),
block_size=max_model_len,
page_size_padded=page_size_padded,
mamba_type=mamba_module.mamba_type,
num_speculative_blocks=(
vllm_config.speculative_config.num_speculative_tokens
if vllm_config.speculative_config else 0),
)
return kv_cache_spec return kv_cache_spec

View File

@@ -12,7 +12,7 @@ from vllm.config import KVTransferConfig, VllmConfig
from vllm.logger import logger from vllm.logger import logger
from vllm.utils.network_utils import make_zmq_socket from vllm.utils.network_utils import make_zmq_socket
from vllm.utils.torch_utils import get_dtype_size from vllm.utils.torch_utils import get_dtype_size
from vllm.v1.kv_cache_interface import AttentionSpec from vllm.v1.kv_cache_interface import AttentionSpec, MLAAttentionSpec
from vllm_ascend.distributed.cpu_offload_manager.cpu_kv_cache_manager import \ from vllm_ascend.distributed.cpu_offload_manager.cpu_kv_cache_manager import \
CPUKVCacheManager CPUKVCacheManager
@@ -140,14 +140,15 @@ class MetadataServer:
layer.page_size_bytes == any.page_size_bytes layer.page_size_bytes == any.page_size_bytes
for any in kv_cache_specs.values() for any in kv_cache_specs.values()
]) ])
use_mla = isinstance(layer, MLAAttentionSpec)
# mla shares the same kv cache among different tp # mla shares the same kv cache among different tp
if layer.use_mla: if use_mla:
tp_rank = 0 tp_rank = 0
if (pp_rank, tp_rank) in self.shared_memory: if (pp_rank, tp_rank) in self.shared_memory:
return self.shared_memory[(pp_rank, tp_rank)] return self.shared_memory[(pp_rank, tp_rank)]
available_memory = self.available_memory available_memory = self.available_memory
shared_memory_dict = {} shared_memory_dict = {}
if layer.use_mla: if use_mla:
available_memory //= self.pipeline_parallel_size available_memory //= self.pipeline_parallel_size
available_memory //= len(kv_cache_specs) available_memory //= len(kv_cache_specs)
num_blocks = available_memory // layer.page_size_bytes num_blocks = available_memory // layer.page_size_bytes
@@ -165,7 +166,7 @@ class MetadataServer:
shared_memory_dict[ shared_memory_dict[
layer_name] = MetadataServer._safe_create_shared_memory( layer_name] = MetadataServer._safe_create_shared_memory(
f"cpu_kv_cache_{pp_rank}_{tp_rank}_{layer_name}", nbytes) f"cpu_kv_cache_{pp_rank}_{tp_rank}_{layer_name}", nbytes)
if layer.use_mla: if use_mla:
assert mla_config is not None assert mla_config is not None
assert layer.head_size == mla_config.rope_dim + mla_config.nope_dim assert layer.head_size == mla_config.rope_dim + mla_config.nope_dim
self.shared_memory[(pp_rank, self.shared_memory[(pp_rank,