[bugfix] adapt to new implemented get_kv_cache_spec in cpuoffload connector (#4311)

### What this PR does / why we need it?
func `get_kv_cache_spec` in model_runner changed a lot and caused error
in cpuoffloading connector which is copied from model_runner, this PR
adapts to new implemented `get_kv_cache_spec` to fix it.

### How was this patch tested?

- vLLM version: v0.11.0
- vLLM main:
2918c1b49c

Signed-off-by: lidenghui <lidenghui1110@gmail.com>
This commit is contained in:
lidenghui1110
2026-01-08 09:15:09 +08:00
committed by GitHub
parent f7db812ed7
commit 481138e1d2
2 changed files with 85 additions and 31 deletions

View File

@@ -10,18 +10,20 @@ from typing import TYPE_CHECKING, Any, Optional, Sequence
import torch
from vllm.attention.backends.abstract import AttentionType
from vllm.attention.layer import Attention
from vllm.config import VllmConfig
from vllm.attention.layer import Attention, MLAAttention
from vllm.config import VllmConfig, get_layers_from_vllm_config
from vllm.distributed.ec_transfer import get_ec_transfer, has_ec_transfer
from vllm.distributed.kv_transfer.kv_connector.v1.base import (
KVConnectorBase_V1, KVConnectorMetadata, KVConnectorRole)
from vllm.distributed.parallel_state import get_pp_group, get_tp_group
from vllm.logger import logger
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
from vllm.model_executor.layers.mamba.abstract import MambaBase
from vllm.utils.torch_utils import STR_DTYPE_TO_TORCH_DTYPE
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheSpec,
MLAAttentionSpec)
MambaSpec, MLAAttentionSpec)
from vllm_ascend.ascend_config import get_ascend_config
from vllm_ascend.distributed.cpu_offload_manager.metadata import (
MetadataServer, MetadataServerProc, MLAConfig)
@@ -435,41 +437,92 @@ class CPUOffloadingConnectorWorker:
save_block_mapping.clear()
# Copied from vllm_ascend/worker/model_runner_v1.py.
# copied and modified from vllm_ascend/worker/model_runner_v1.py
def get_kv_cache_spec(vllm_config: VllmConfig) -> dict[str, KVCacheSpec]:
forward_ctx = vllm_config.compilation_config.static_forward_context
"""
Generates the KVCacheSpec by parsing the kv cache format from each
Attention module in the static forward context.
Returns:
KVCacheSpec: A dictionary mapping layer names to their KV cache
format. Layers that do not need KV cache are not included.
"""
if has_ec_transfer() and get_ec_transfer().is_producer:
return {}
block_size = vllm_config.cache_config.block_size
use_mla = vllm_config.model_config.use_mla
ascend_config = get_ascend_config()
use_sfa = ascend_config.use_sfa
use_sparse = hasattr(vllm_config.model_config.hf_config, "index_topk")
if vllm_config.cache_config.cache_dtype == "auto":
kv_cache_dtype = vllm_config.model_config.dtype
else:
kv_cache_dtype = STR_DTYPE_TO_TORCH_DTYPE[
vllm_config.cache_config.cache_dtype]
kv_cache_spec: dict[str, KVCacheSpec] = {}
for layer_name, attn_module in forward_ctx.items():
if isinstance(attn_module, FusedMoE):
continue
assert isinstance(attn_module, Attention)
if attn_module.attn_type == AttentionType.DECODER:
if use_mla and not use_sfa:
kv_cache_spec[layer_name] = MLAAttentionSpec(
attn_layers = get_layers_from_vllm_config(vllm_config, AttentionLayerBase)
for layer_name, attn_module in attn_layers.items():
if isinstance(attn_module, Attention):
# TODO: Support other attention modules, e.g., cross-attention
# TODO(lucas): move the attention specs into the model layers like
# the attention backends
if attn_module.attn_type == AttentionType.DECODER:
kv_cache_spec[layer_name] = FullAttentionSpec(
block_size=block_size,
num_kv_heads=attn_module.num_kv_heads,
head_size=attn_module.head_size,
dtype=attn_module.dtype,
dtype=kv_cache_dtype)
elif attn_module.attn_type in (AttentionType.ENCODER,
AttentionType.ENCODER_ONLY):
# encoder-only attention does not need KV cache.
continue
elif attn_module.attn_type == AttentionType.ENCODER_DECODER:
raise NotImplementedError
else:
raise ValueError(
f"Unknown attention type: {attn_module.attn_type}")
elif isinstance(attn_module, MLAAttention):
if use_mla and not use_sparse:
kv_cache_spec[layer_name] = MLAAttentionSpec(
block_size=block_size,
num_kv_heads=1,
head_size=attn_module.head_size,
dtype=kv_cache_dtype,
cache_dtype_str=vllm_config.cache_config.cache_dtype)
else:
# TODO(cmq): This is a hack way to fix deepseek kvcache when
# using DSA. Fix the spec in vLLM is a finnal way.
kv_cache_spec[layer_name] = FullAttentionSpec(
block_size=block_size,
num_kv_heads=attn_module.num_kv_heads,
num_kv_heads=1,
head_size=attn_module.head_size,
dtype=attn_module.dtype)
dtype=kv_cache_dtype)
mamba_layers = get_layers_from_vllm_config(vllm_config, MambaBase)
if len(mamba_layers) > 0:
if (vllm_config.speculative_config is not None
and vllm_config.model_config.hf_config.model_type
not in ["qwen3_next"]):
raise NotImplementedError(
"Mamba with speculative decoding is not supported yet.")
if vllm_config.cache_config.enable_prefix_caching:
raise NotImplementedError(
"Prefix caching is not supported for Mamba yet.")
max_model_len = vllm_config.model_config.max_model_len
page_size_padded = (vllm_config.cache_config.mamba_page_size_padded)
# Set block_size to max_model_len, so that mamba model will always
# have only one block in the KV cache.
for layer_name, mamba_module in mamba_layers.items():
kv_cache_spec[layer_name] = MambaSpec(
shapes=mamba_module.get_state_shape(),
dtypes=mamba_module.get_state_dtype(),
block_size=max_model_len,
page_size_padded=page_size_padded,
mamba_type=mamba_module.mamba_type,
num_speculative_blocks=(
vllm_config.speculative_config.num_speculative_tokens
if vllm_config.speculative_config else 0),
)
elif attn_module.attn_type in (AttentionType.ENCODER,
AttentionType.ENCODER_ONLY):
continue
elif attn_module.attn_type == AttentionType.ENCODER_DECODER:
raise NotImplementedError
else:
raise ValueError(
f"Unknown attention type: {attn_module.attn_type}")
return kv_cache_spec

View File

@@ -12,7 +12,7 @@ from vllm.config import KVTransferConfig, VllmConfig
from vllm.logger import logger
from vllm.utils.network_utils import make_zmq_socket
from vllm.utils.torch_utils import get_dtype_size
from vllm.v1.kv_cache_interface import AttentionSpec
from vllm.v1.kv_cache_interface import AttentionSpec, MLAAttentionSpec
from vllm_ascend.distributed.cpu_offload_manager.cpu_kv_cache_manager import \
CPUKVCacheManager
@@ -140,14 +140,15 @@ class MetadataServer:
layer.page_size_bytes == any.page_size_bytes
for any in kv_cache_specs.values()
])
use_mla = isinstance(layer, MLAAttentionSpec)
# mla shares the same kv cache among different tp
if layer.use_mla:
if use_mla:
tp_rank = 0
if (pp_rank, tp_rank) in self.shared_memory:
return self.shared_memory[(pp_rank, tp_rank)]
available_memory = self.available_memory
shared_memory_dict = {}
if layer.use_mla:
if use_mla:
available_memory //= self.pipeline_parallel_size
available_memory //= len(kv_cache_specs)
num_blocks = available_memory // layer.page_size_bytes
@@ -165,7 +166,7 @@ class MetadataServer:
shared_memory_dict[
layer_name] = MetadataServer._safe_create_shared_memory(
f"cpu_kv_cache_{pp_rank}_{tp_rank}_{layer_name}", nbytes)
if layer.use_mla:
if use_mla:
assert mla_config is not None
assert layer.head_size == mla_config.rope_dim + mla_config.nope_dim
self.shared_memory[(pp_rank,