[bugfix] adapt to new implemented get_kv_cache_spec in cpuoffload connector (#4311)
### What this PR does / why we need it?
func `get_kv_cache_spec` in model_runner changed a lot and caused error
in cpuoffloading connector which is copied from model_runner, this PR
adapts to new implemented `get_kv_cache_spec` to fix it.
### How was this patch tested?
- vLLM version: v0.11.0
- vLLM main:
2918c1b49c
Signed-off-by: lidenghui <lidenghui1110@gmail.com>
This commit is contained in:
@@ -10,18 +10,20 @@ from typing import TYPE_CHECKING, Any, Optional, Sequence
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
from vllm.attention.backends.abstract import AttentionType
|
from vllm.attention.backends.abstract import AttentionType
|
||||||
from vllm.attention.layer import Attention
|
from vllm.attention.layer import Attention, MLAAttention
|
||||||
from vllm.config import VllmConfig
|
from vllm.config import VllmConfig, get_layers_from_vllm_config
|
||||||
|
from vllm.distributed.ec_transfer import get_ec_transfer, has_ec_transfer
|
||||||
from vllm.distributed.kv_transfer.kv_connector.v1.base import (
|
from vllm.distributed.kv_transfer.kv_connector.v1.base import (
|
||||||
KVConnectorBase_V1, KVConnectorMetadata, KVConnectorRole)
|
KVConnectorBase_V1, KVConnectorMetadata, KVConnectorRole)
|
||||||
from vllm.distributed.parallel_state import get_pp_group, get_tp_group
|
from vllm.distributed.parallel_state import get_pp_group, get_tp_group
|
||||||
from vllm.logger import logger
|
from vllm.logger import logger
|
||||||
from vllm.model_executor.layers.fused_moe import FusedMoE
|
from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
|
||||||
|
from vllm.model_executor.layers.mamba.abstract import MambaBase
|
||||||
|
from vllm.utils.torch_utils import STR_DTYPE_TO_TORCH_DTYPE
|
||||||
from vllm.v1.core.sched.output import SchedulerOutput
|
from vllm.v1.core.sched.output import SchedulerOutput
|
||||||
from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheSpec,
|
from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheSpec,
|
||||||
MLAAttentionSpec)
|
MambaSpec, MLAAttentionSpec)
|
||||||
|
|
||||||
from vllm_ascend.ascend_config import get_ascend_config
|
|
||||||
from vllm_ascend.distributed.cpu_offload_manager.metadata import (
|
from vllm_ascend.distributed.cpu_offload_manager.metadata import (
|
||||||
MetadataServer, MetadataServerProc, MLAConfig)
|
MetadataServer, MetadataServerProc, MLAConfig)
|
||||||
|
|
||||||
@@ -435,41 +437,92 @@ class CPUOffloadingConnectorWorker:
|
|||||||
save_block_mapping.clear()
|
save_block_mapping.clear()
|
||||||
|
|
||||||
|
|
||||||
# Copied from vllm_ascend/worker/model_runner_v1.py.
|
# copied and modified from vllm_ascend/worker/model_runner_v1.py
|
||||||
def get_kv_cache_spec(vllm_config: VllmConfig) -> dict[str, KVCacheSpec]:
|
def get_kv_cache_spec(vllm_config: VllmConfig) -> dict[str, KVCacheSpec]:
|
||||||
forward_ctx = vllm_config.compilation_config.static_forward_context
|
"""
|
||||||
|
Generates the KVCacheSpec by parsing the kv cache format from each
|
||||||
|
Attention module in the static forward context.
|
||||||
|
Returns:
|
||||||
|
KVCacheSpec: A dictionary mapping layer names to their KV cache
|
||||||
|
format. Layers that do not need KV cache are not included.
|
||||||
|
"""
|
||||||
|
if has_ec_transfer() and get_ec_transfer().is_producer:
|
||||||
|
return {}
|
||||||
|
|
||||||
block_size = vllm_config.cache_config.block_size
|
block_size = vllm_config.cache_config.block_size
|
||||||
use_mla = vllm_config.model_config.use_mla
|
use_mla = vllm_config.model_config.use_mla
|
||||||
ascend_config = get_ascend_config()
|
use_sparse = hasattr(vllm_config.model_config.hf_config, "index_topk")
|
||||||
use_sfa = ascend_config.use_sfa
|
if vllm_config.cache_config.cache_dtype == "auto":
|
||||||
kv_cache_spec: dict[str, KVCacheSpec] = {}
|
kv_cache_dtype = vllm_config.model_config.dtype
|
||||||
for layer_name, attn_module in forward_ctx.items():
|
|
||||||
if isinstance(attn_module, FusedMoE):
|
|
||||||
continue
|
|
||||||
assert isinstance(attn_module, Attention)
|
|
||||||
if attn_module.attn_type == AttentionType.DECODER:
|
|
||||||
if use_mla and not use_sfa:
|
|
||||||
kv_cache_spec[layer_name] = MLAAttentionSpec(
|
|
||||||
block_size=block_size,
|
|
||||||
num_kv_heads=attn_module.num_kv_heads,
|
|
||||||
head_size=attn_module.head_size,
|
|
||||||
dtype=attn_module.dtype,
|
|
||||||
cache_dtype_str=vllm_config.cache_config.cache_dtype)
|
|
||||||
else:
|
else:
|
||||||
# TODO(cmq): This is a hack way to fix deepseek kvcache when
|
kv_cache_dtype = STR_DTYPE_TO_TORCH_DTYPE[
|
||||||
# using DSA. Fix the spec in vLLM is a finnal way.
|
vllm_config.cache_config.cache_dtype]
|
||||||
|
kv_cache_spec: dict[str, KVCacheSpec] = {}
|
||||||
|
attn_layers = get_layers_from_vllm_config(vllm_config, AttentionLayerBase)
|
||||||
|
for layer_name, attn_module in attn_layers.items():
|
||||||
|
if isinstance(attn_module, Attention):
|
||||||
|
# TODO: Support other attention modules, e.g., cross-attention
|
||||||
|
# TODO(lucas): move the attention specs into the model layers like
|
||||||
|
# the attention backends
|
||||||
|
if attn_module.attn_type == AttentionType.DECODER:
|
||||||
kv_cache_spec[layer_name] = FullAttentionSpec(
|
kv_cache_spec[layer_name] = FullAttentionSpec(
|
||||||
block_size=block_size,
|
block_size=block_size,
|
||||||
num_kv_heads=attn_module.num_kv_heads,
|
num_kv_heads=attn_module.num_kv_heads,
|
||||||
head_size=attn_module.head_size,
|
head_size=attn_module.head_size,
|
||||||
dtype=attn_module.dtype)
|
dtype=kv_cache_dtype)
|
||||||
|
|
||||||
elif attn_module.attn_type in (AttentionType.ENCODER,
|
elif attn_module.attn_type in (AttentionType.ENCODER,
|
||||||
AttentionType.ENCODER_ONLY):
|
AttentionType.ENCODER_ONLY):
|
||||||
|
# encoder-only attention does not need KV cache.
|
||||||
continue
|
continue
|
||||||
elif attn_module.attn_type == AttentionType.ENCODER_DECODER:
|
elif attn_module.attn_type == AttentionType.ENCODER_DECODER:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
else:
|
else:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"Unknown attention type: {attn_module.attn_type}")
|
f"Unknown attention type: {attn_module.attn_type}")
|
||||||
|
|
||||||
|
elif isinstance(attn_module, MLAAttention):
|
||||||
|
if use_mla and not use_sparse:
|
||||||
|
kv_cache_spec[layer_name] = MLAAttentionSpec(
|
||||||
|
block_size=block_size,
|
||||||
|
num_kv_heads=1,
|
||||||
|
head_size=attn_module.head_size,
|
||||||
|
dtype=kv_cache_dtype,
|
||||||
|
cache_dtype_str=vllm_config.cache_config.cache_dtype)
|
||||||
|
else:
|
||||||
|
# TODO(cmq): This is a hack way to fix deepseek kvcache when
|
||||||
|
# using DSA. Fix the spec in vLLM is a finnal way.
|
||||||
|
kv_cache_spec[layer_name] = FullAttentionSpec(
|
||||||
|
block_size=block_size,
|
||||||
|
num_kv_heads=1,
|
||||||
|
head_size=attn_module.head_size,
|
||||||
|
dtype=kv_cache_dtype)
|
||||||
|
|
||||||
|
mamba_layers = get_layers_from_vllm_config(vllm_config, MambaBase)
|
||||||
|
if len(mamba_layers) > 0:
|
||||||
|
if (vllm_config.speculative_config is not None
|
||||||
|
and vllm_config.model_config.hf_config.model_type
|
||||||
|
not in ["qwen3_next"]):
|
||||||
|
raise NotImplementedError(
|
||||||
|
"Mamba with speculative decoding is not supported yet.")
|
||||||
|
if vllm_config.cache_config.enable_prefix_caching:
|
||||||
|
raise NotImplementedError(
|
||||||
|
"Prefix caching is not supported for Mamba yet.")
|
||||||
|
max_model_len = vllm_config.model_config.max_model_len
|
||||||
|
|
||||||
|
page_size_padded = (vllm_config.cache_config.mamba_page_size_padded)
|
||||||
|
|
||||||
|
# Set block_size to max_model_len, so that mamba model will always
|
||||||
|
# have only one block in the KV cache.
|
||||||
|
for layer_name, mamba_module in mamba_layers.items():
|
||||||
|
kv_cache_spec[layer_name] = MambaSpec(
|
||||||
|
shapes=mamba_module.get_state_shape(),
|
||||||
|
dtypes=mamba_module.get_state_dtype(),
|
||||||
|
block_size=max_model_len,
|
||||||
|
page_size_padded=page_size_padded,
|
||||||
|
mamba_type=mamba_module.mamba_type,
|
||||||
|
num_speculative_blocks=(
|
||||||
|
vllm_config.speculative_config.num_speculative_tokens
|
||||||
|
if vllm_config.speculative_config else 0),
|
||||||
|
)
|
||||||
|
|
||||||
return kv_cache_spec
|
return kv_cache_spec
|
||||||
|
|||||||
@@ -12,7 +12,7 @@ from vllm.config import KVTransferConfig, VllmConfig
|
|||||||
from vllm.logger import logger
|
from vllm.logger import logger
|
||||||
from vllm.utils.network_utils import make_zmq_socket
|
from vllm.utils.network_utils import make_zmq_socket
|
||||||
from vllm.utils.torch_utils import get_dtype_size
|
from vllm.utils.torch_utils import get_dtype_size
|
||||||
from vllm.v1.kv_cache_interface import AttentionSpec
|
from vllm.v1.kv_cache_interface import AttentionSpec, MLAAttentionSpec
|
||||||
|
|
||||||
from vllm_ascend.distributed.cpu_offload_manager.cpu_kv_cache_manager import \
|
from vllm_ascend.distributed.cpu_offload_manager.cpu_kv_cache_manager import \
|
||||||
CPUKVCacheManager
|
CPUKVCacheManager
|
||||||
@@ -140,14 +140,15 @@ class MetadataServer:
|
|||||||
layer.page_size_bytes == any.page_size_bytes
|
layer.page_size_bytes == any.page_size_bytes
|
||||||
for any in kv_cache_specs.values()
|
for any in kv_cache_specs.values()
|
||||||
])
|
])
|
||||||
|
use_mla = isinstance(layer, MLAAttentionSpec)
|
||||||
# mla shares the same kv cache among different tp
|
# mla shares the same kv cache among different tp
|
||||||
if layer.use_mla:
|
if use_mla:
|
||||||
tp_rank = 0
|
tp_rank = 0
|
||||||
if (pp_rank, tp_rank) in self.shared_memory:
|
if (pp_rank, tp_rank) in self.shared_memory:
|
||||||
return self.shared_memory[(pp_rank, tp_rank)]
|
return self.shared_memory[(pp_rank, tp_rank)]
|
||||||
available_memory = self.available_memory
|
available_memory = self.available_memory
|
||||||
shared_memory_dict = {}
|
shared_memory_dict = {}
|
||||||
if layer.use_mla:
|
if use_mla:
|
||||||
available_memory //= self.pipeline_parallel_size
|
available_memory //= self.pipeline_parallel_size
|
||||||
available_memory //= len(kv_cache_specs)
|
available_memory //= len(kv_cache_specs)
|
||||||
num_blocks = available_memory // layer.page_size_bytes
|
num_blocks = available_memory // layer.page_size_bytes
|
||||||
@@ -165,7 +166,7 @@ class MetadataServer:
|
|||||||
shared_memory_dict[
|
shared_memory_dict[
|
||||||
layer_name] = MetadataServer._safe_create_shared_memory(
|
layer_name] = MetadataServer._safe_create_shared_memory(
|
||||||
f"cpu_kv_cache_{pp_rank}_{tp_rank}_{layer_name}", nbytes)
|
f"cpu_kv_cache_{pp_rank}_{tp_rank}_{layer_name}", nbytes)
|
||||||
if layer.use_mla:
|
if use_mla:
|
||||||
assert mla_config is not None
|
assert mla_config is not None
|
||||||
assert layer.head_size == mla_config.rope_dim + mla_config.nope_dim
|
assert layer.head_size == mla_config.rope_dim + mla_config.nope_dim
|
||||||
self.shared_memory[(pp_rank,
|
self.shared_memory[(pp_rank,
|
||||||
|
|||||||
Reference in New Issue
Block a user