diff --git a/vllm_ascend/distributed/cpu_offload_connector.py b/vllm_ascend/distributed/cpu_offload_connector.py
index 5a9ddd2e..9bcde279 100644
--- a/vllm_ascend/distributed/cpu_offload_connector.py
+++ b/vllm_ascend/distributed/cpu_offload_connector.py
@@ -10,18 +10,20 @@ from typing import TYPE_CHECKING, Any, Optional, Sequence
 
 import torch
 from vllm.attention.backends.abstract import AttentionType
-from vllm.attention.layer import Attention
-from vllm.config import VllmConfig
+from vllm.attention.layer import Attention, MLAAttention
+from vllm.config import VllmConfig, get_layers_from_vllm_config
+from vllm.distributed.ec_transfer import get_ec_transfer, has_ec_transfer
 from vllm.distributed.kv_transfer.kv_connector.v1.base import (
     KVConnectorBase_V1, KVConnectorMetadata, KVConnectorRole)
 from vllm.distributed.parallel_state import get_pp_group, get_tp_group
 from vllm.logger import logger
-from vllm.model_executor.layers.fused_moe import FusedMoE
+from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
+from vllm.model_executor.layers.mamba.abstract import MambaBase
+from vllm.utils.torch_utils import STR_DTYPE_TO_TORCH_DTYPE
 from vllm.v1.core.sched.output import SchedulerOutput
 from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheSpec,
-                                        MLAAttentionSpec)
+                                        MambaSpec, MLAAttentionSpec)
 
-from vllm_ascend.ascend_config import get_ascend_config
 from vllm_ascend.distributed.cpu_offload_manager.metadata import (
     MetadataServer, MetadataServerProc, MLAConfig)
 
@@ -435,41 +437,92 @@ class CPUOffloadingConnectorWorker:
             save_block_mapping.clear()
 
 
-# Copied from vllm_ascend/worker/model_runner_v1.py.
+# copied and modified from vllm_ascend/worker/model_runner_v1.py
 def get_kv_cache_spec(vllm_config: VllmConfig) -> dict[str, KVCacheSpec]:
-    forward_ctx = vllm_config.compilation_config.static_forward_context
+    """
+    Generates the KVCacheSpec by parsing the kv cache format from each
+    Attention module in the static forward context.
+    Returns:
+        KVCacheSpec: A dictionary mapping layer names to their KV cache
+        format. Layers that do not need KV cache are not included.
+    """
+    if has_ec_transfer() and get_ec_transfer().is_producer:
+        return {}
+
     block_size = vllm_config.cache_config.block_size
     use_mla = vllm_config.model_config.use_mla
-    ascend_config = get_ascend_config()
-    use_sfa = ascend_config.use_sfa
+    use_sparse = hasattr(vllm_config.model_config.hf_config, "index_topk")
+    if vllm_config.cache_config.cache_dtype == "auto":
+        kv_cache_dtype = vllm_config.model_config.dtype
+    else:
+        kv_cache_dtype = STR_DTYPE_TO_TORCH_DTYPE[
+            vllm_config.cache_config.cache_dtype]
     kv_cache_spec: dict[str, KVCacheSpec] = {}
-    for layer_name, attn_module in forward_ctx.items():
-        if isinstance(attn_module, FusedMoE):
-            continue
-        assert isinstance(attn_module, Attention)
-        if attn_module.attn_type == AttentionType.DECODER:
-            if use_mla and not use_sfa:
-                kv_cache_spec[layer_name] = MLAAttentionSpec(
+    attn_layers = get_layers_from_vllm_config(vllm_config, AttentionLayerBase)
+    for layer_name, attn_module in attn_layers.items():
+        if isinstance(attn_module, Attention):
+            # TODO: Support other attention modules, e.g., cross-attention
+            # TODO(lucas): move the attention specs into the model layers like
+            # the attention backends
+            if attn_module.attn_type == AttentionType.DECODER:
+                kv_cache_spec[layer_name] = FullAttentionSpec(
                     block_size=block_size,
                     num_kv_heads=attn_module.num_kv_heads,
                     head_size=attn_module.head_size,
-                    dtype=attn_module.dtype,
+                    dtype=kv_cache_dtype)
+            elif attn_module.attn_type in (AttentionType.ENCODER,
+                                           AttentionType.ENCODER_ONLY):
+                # encoder-only attention does not need KV cache.
+                continue
+            elif attn_module.attn_type == AttentionType.ENCODER_DECODER:
+                raise NotImplementedError
+            else:
+                raise ValueError(
+                    f"Unknown attention type: {attn_module.attn_type}")
+
+        elif isinstance(attn_module, MLAAttention):
+            if use_mla and not use_sparse:
+                kv_cache_spec[layer_name] = MLAAttentionSpec(
+                    block_size=block_size,
+                    num_kv_heads=1,
+                    head_size=attn_module.head_size,
+                    dtype=kv_cache_dtype,
                     cache_dtype_str=vllm_config.cache_config.cache_dtype)
             else:
                 # TODO(cmq): This is a hack way to fix deepseek kvcache when
                 # using DSA. Fix the spec in vLLM is a finnal way.
                 kv_cache_spec[layer_name] = FullAttentionSpec(
                     block_size=block_size,
-                    num_kv_heads=attn_module.num_kv_heads,
+                    num_kv_heads=1,
                     head_size=attn_module.head_size,
-                    dtype=attn_module.dtype)
+                    dtype=kv_cache_dtype)
+
+    mamba_layers = get_layers_from_vllm_config(vllm_config, MambaBase)
+    if len(mamba_layers) > 0:
+        if (vllm_config.speculative_config is not None
+                and vllm_config.model_config.hf_config.model_type
+                not in ["qwen3_next"]):
+            raise NotImplementedError(
+                "Mamba with speculative decoding is not supported yet.")
+        if vllm_config.cache_config.enable_prefix_caching:
+            raise NotImplementedError(
+                "Prefix caching is not supported for Mamba yet.")
+        max_model_len = vllm_config.model_config.max_model_len
+
+        page_size_padded = (vllm_config.cache_config.mamba_page_size_padded)
+
+        # Set block_size to max_model_len, so that mamba model will always
+        # have only one block in the KV cache.
+        for layer_name, mamba_module in mamba_layers.items():
+            kv_cache_spec[layer_name] = MambaSpec(
+                shapes=mamba_module.get_state_shape(),
+                dtypes=mamba_module.get_state_dtype(),
+                block_size=max_model_len,
+                page_size_padded=page_size_padded,
+                mamba_type=mamba_module.mamba_type,
+                num_speculative_blocks=(
+                    vllm_config.speculative_config.num_speculative_tokens
+                    if vllm_config.speculative_config else 0),
+            )
 
-        elif attn_module.attn_type in (AttentionType.ENCODER,
-                                       AttentionType.ENCODER_ONLY):
-            continue
-        elif attn_module.attn_type == AttentionType.ENCODER_DECODER:
-            raise NotImplementedError
-        else:
-            raise ValueError(
-                f"Unknown attention type: {attn_module.attn_type}")
     return kv_cache_spec
diff --git a/vllm_ascend/distributed/cpu_offload_manager/metadata.py b/vllm_ascend/distributed/cpu_offload_manager/metadata.py
index 3dba8ac2..71242a86 100644
--- a/vllm_ascend/distributed/cpu_offload_manager/metadata.py
+++ b/vllm_ascend/distributed/cpu_offload_manager/metadata.py
@@ -12,7 +12,7 @@ from vllm.config import KVTransferConfig, VllmConfig
 from vllm.logger import logger
 from vllm.utils.network_utils import make_zmq_socket
 from vllm.utils.torch_utils import get_dtype_size
-from vllm.v1.kv_cache_interface import AttentionSpec
+from vllm.v1.kv_cache_interface import AttentionSpec, MLAAttentionSpec
 
 from vllm_ascend.distributed.cpu_offload_manager.cpu_kv_cache_manager import \
     CPUKVCacheManager
@@ -140,14 +140,15 @@ class MetadataServer:
             layer.page_size_bytes == any.page_size_bytes
             for any in kv_cache_specs.values()
         ])
+        use_mla = isinstance(layer, MLAAttentionSpec)
         # mla shares the same kv cache among different tp
-        if layer.use_mla:
+        if use_mla:
             tp_rank = 0
         if (pp_rank, tp_rank) in self.shared_memory:
             return self.shared_memory[(pp_rank, tp_rank)]
         available_memory = self.available_memory
         shared_memory_dict = {}
-        if layer.use_mla:
+        if use_mla:
             available_memory //= self.pipeline_parallel_size
         available_memory //= len(kv_cache_specs)
         num_blocks = available_memory // layer.page_size_bytes
@@ -165,7 +166,7 @@ class MetadataServer:
             shared_memory_dict[
                 layer_name] = MetadataServer._safe_create_shared_memory(
                     f"cpu_kv_cache_{pp_rank}_{tp_rank}_{layer_name}", nbytes)
-        if layer.use_mla:
+        if use_mla:
             assert mla_config is not None
             assert layer.head_size == mla_config.rope_dim + mla_config.nope_dim
         self.shared_memory[(pp_rank,