diff --git a/vllm_ascend/distributed/kv_transfer/kv_pool/cpu_offload/cpu_offload_connector.py b/vllm_ascend/distributed/kv_transfer/kv_pool/cpu_offload/cpu_offload_connector.py
index 2ea31953..34f6b8a8 100644
--- a/vllm_ascend/distributed/kv_transfer/kv_pool/cpu_offload/cpu_offload_connector.py
+++ b/vllm_ascend/distributed/kv_transfer/kv_pool/cpu_offload/cpu_offload_connector.py
@@ -20,8 +20,7 @@ from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
 from vllm.model_executor.layers.mamba.abstract import MambaBase
 from vllm.utils.torch_utils import STR_DTYPE_TO_TORCH_DTYPE
 from vllm.v1.core.sched.output import SchedulerOutput
-from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheSpec,
-                                        MambaSpec, MLAAttentionSpec)
+from vllm.v1.kv_cache_interface import FullAttentionSpec, KVCacheSpec
 
 from vllm_ascend.distributed.kv_transfer.kv_pool.cpu_offload.metadata import (
     MetadataServer, MetadataServerProc, MLAConfig)
@@ -461,80 +460,45 @@ def get_kv_cache_spec(vllm_config: VllmConfig) -> dict[str, KVCacheSpec]:
     if has_ec_transfer() and get_ec_transfer().is_producer:
         return {}
 
-    block_size = vllm_config.cache_config.block_size
-    use_mla = vllm_config.model_config.use_mla
     use_sparse = hasattr(vllm_config.model_config.hf_config, "index_topk")
     if vllm_config.cache_config.cache_dtype == "auto":
         kv_cache_dtype = vllm_config.model_config.dtype
     else:
         kv_cache_dtype = STR_DTYPE_TO_TORCH_DTYPE[
             vllm_config.cache_config.cache_dtype]
+
     kv_cache_spec: dict[str, KVCacheSpec] = {}
     attn_layers = get_layers_from_vllm_config(vllm_config, AttentionLayerBase)
+    # NOTE: Must process Attention/MLAAttention before MambaBase to maintain
+    # ordering expected by acl_graph.py's _update_attn_fia_params.
+    mamba_layers: dict[str, MambaBase] = {}
     for layer_name, attn_module in attn_layers.items():
         if isinstance(attn_module, Attention):
-            # TODO: Support other attention modules, e.g., cross-attention
-            # TODO(lucas): move the attention specs into the model layers like
-            # the attention backends
-            if attn_module.attn_type == AttentionType.DECODER:
-                kv_cache_spec[layer_name] = FullAttentionSpec(
-                    block_size=block_size,
-                    num_kv_heads=attn_module.num_kv_heads,
-                    head_size=attn_module.head_size,
-                    dtype=kv_cache_dtype)
-            elif attn_module.attn_type in (AttentionType.ENCODER,
-                                           AttentionType.ENCODER_ONLY):
-                # encoder-only attention does not need KV cache.
-                continue
-            elif attn_module.attn_type == AttentionType.ENCODER_DECODER:
-                raise NotImplementedError
-            else:
-                raise ValueError(
-                    f"Unknown attention type: {attn_module.attn_type}")
+            if spec := attn_module.get_kv_cache_spec(vllm_config):
+                kv_cache_spec[layer_name] = spec
         elif isinstance(attn_module, MLAAttention):
-            if use_mla and not use_sparse:
-                kv_cache_spec[layer_name] = MLAAttentionSpec(
-                    block_size=block_size,
-                    num_kv_heads=1,
-                    head_size=attn_module.head_size,
-                    dtype=kv_cache_dtype,
-                    cache_dtype_str=vllm_config.cache_config.cache_dtype)
-            else:
+            if use_sparse:
                 # TODO(cmq): This is a hack way to fix deepseek kvcache when
-                # using DSA. Fix the spec in vLLM is a finnal way.
+                # using DSA. Fixing the spec in vLLM is the final way.
+                block_size = vllm_config.cache_config.block_size
                 kv_cache_spec[layer_name] = FullAttentionSpec(
                     block_size=block_size,
                     num_kv_heads=1,
                     head_size=attn_module.head_size,
                     dtype=kv_cache_dtype)
+            elif spec := attn_module.get_kv_cache_spec(vllm_config):
+                kv_cache_spec[layer_name] = spec
+
+        elif isinstance(attn_module, MambaBase):
+            mamba_layers[layer_name] = attn_module
 
-    mamba_layers = get_layers_from_vllm_config(vllm_config, MambaBase)
     if len(mamba_layers) > 0:
-        if (vllm_config.speculative_config is not None
-                and vllm_config.model_config.hf_config.model_type
-                not in ["qwen3_next"]):
-            raise NotImplementedError(
-                "Mamba with speculative decoding is not supported yet.")
         if vllm_config.cache_config.enable_prefix_caching:
             raise NotImplementedError(
                 "Prefix caching is not supported for Mamba yet.")
-        max_model_len = vllm_config.model_config.max_model_len
-
-        page_size_padded = (vllm_config.cache_config.mamba_page_size_padded)
-
-        # Set block_size to max_model_len, so that mamba model will always
-        # have only one block in the KV cache.
         for layer_name, mamba_module in mamba_layers.items():
-            kv_cache_spec[layer_name] = MambaSpec(
-                shapes=mamba_module.get_state_shape(),
-                dtypes=mamba_module.get_state_dtype(),
-                block_size=max_model_len,
-                page_size_padded=page_size_padded,
-                mamba_type=mamba_module.mamba_type,
-                num_speculative_blocks=(
-                    vllm_config.speculative_config.num_speculative_tokens
-                    if vllm_config.speculative_config else 0),
-            )
+            if spec := mamba_module.get_kv_cache_spec(vllm_config):
+                kv_cache_spec[layer_name] = spec
 
     return kv_cache_spec
diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py
index 5779c136..210dea7b 100644
--- a/vllm_ascend/worker/model_runner_v1.py
+++ b/vllm_ascend/worker/model_runner_v1.py
@@ -53,12 +53,11 @@ from vllm.utils.mem_utils import DeviceMemoryProfiler
 from vllm.v1.attention.backends.gdn_attn import GDNAttentionMetadataBuilder
 from vllm.v1.attention.backends.utils import CommonAttentionMetadata
 from vllm.v1.core.sched.output import SchedulerOutput
-from vllm.v1.kv_cache_interface import (AttentionSpec, CrossAttentionSpec,
+from vllm.v1.kv_cache_interface import (AttentionSpec,
                                         EncoderOnlyAttentionSpec,
                                         FullAttentionSpec, KVCacheConfig,
                                         KVCacheGroupSpec, KVCacheSpec,
-                                        MambaSpec, MLAAttentionSpec,
-                                        UniformTypeKVCacheSpecs)
+                                        MambaSpec, UniformTypeKVCacheSpecs)
 from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, AsyncModelRunnerOutput,
                              LogprobsLists, LogprobsTensors,
                              ModelRunnerOutput, SamplerOutput,
@@ -2886,11 +2885,12 @@ class NPUModelRunner(GPUModelRunner):
         if has_ec_transfer() and get_ec_transfer().is_producer:
             return {}
 
-        block_size = self.vllm_config.cache_config.block_size
-        use_mla = self.vllm_config.model_config.use_mla
         kv_cache_spec: dict[str, KVCacheSpec] = {}
         attn_layers = get_layers_from_vllm_config(self.vllm_config,
                                                   AttentionLayerBase)
+        # NOTE: Must process Attention/MLAAttention before MambaBase to maintain
+        # ordering expected by acl_graph.py's _update_attn_fia_params.
+        mamba_layers: dict[str, MambaBase] = {}
         for layer_name, attn_module in attn_layers.items():
             if isinstance(attn_module, Attention):
                 if (kv_tgt_layer :=
@@ -2905,74 +2905,32 @@ class NPUModelRunner(GPUModelRunner):
                     self.shared_kv_cache_layers[layer_name] = kv_tgt_layer
                     continue
 
-                # TODO: Support other attention modules, e.g., cross-attention
-                # TODO(lucas): move the attention specs into the model layers like
-                # the attention backends
-                if attn_module.attn_type == AttentionType.DECODER:
-                    kv_cache_spec[layer_name] = FullAttentionSpec(
-                        block_size=block_size,
-                        num_kv_heads=attn_module.num_kv_heads,
-                        head_size=attn_module.head_size,
-                        dtype=self.kv_cache_dtype)
-                elif attn_module.attn_type in (AttentionType.ENCODER,
-                                               AttentionType.ENCODER_ONLY):
-                    # encoder-only attention does not need KV cache.
-                    continue
-                elif attn_module.attn_type == AttentionType.ENCODER_DECODER:
-                    kv_cache_spec[layer_name] = CrossAttentionSpec(
-                        block_size=block_size,
-                        num_kv_heads=attn_module.num_kv_heads,
-                        head_size=attn_module.head_size,
-                        dtype=self.kv_cache_dtype)
-                else:
-                    raise ValueError(
-                        f"Unknown attention type: {attn_module.attn_type}")
+                if spec := attn_module.get_kv_cache_spec(self.vllm_config):
+                    kv_cache_spec[layer_name] = spec
             elif isinstance(attn_module, MLAAttention):
-                if use_mla and not self.use_sparse:
-                    kv_cache_spec[layer_name] = MLAAttentionSpec(
-                        block_size=block_size,
-                        num_kv_heads=1,
-                        head_size=attn_module.head_size,
-                        dtype=self.kv_cache_dtype,
-                        cache_dtype_str=self.cache_config.cache_dtype)
-                else:
+                if self.use_sparse:
                     # TODO(cmq): This is a hack way to fix deepseek kvcache when
-                    # using DSA. Fix the spec in vLLM is a finnal way.
+                    # using DSA. Fixing the spec in vLLM is the final way.
+                    block_size = self.vllm_config.cache_config.block_size
                     kv_cache_spec[layer_name] = FullAttentionSpec(
                         block_size=block_size,
                         num_kv_heads=1,
                         head_size=attn_module.head_size,
                         dtype=self.kv_cache_dtype)
+                elif spec := attn_module.get_kv_cache_spec(self.vllm_config):
+                    kv_cache_spec[layer_name] = spec
+
+            elif isinstance(attn_module, MambaBase):
+                mamba_layers[layer_name] = attn_module
 
-        mamba_layers = get_layers_from_vllm_config(self.vllm_config, MambaBase)
         if len(mamba_layers) > 0:
-            if (self.vllm_config.speculative_config is not None
-                    and self.vllm_config.model_config.hf_text_config.model_type
-                    not in ["qwen3_next"]):
-                raise NotImplementedError(
-                    "Mamba with speculative decoding is not supported yet.")
             if self.vllm_config.cache_config.enable_prefix_caching:
                 raise NotImplementedError(
                     "Prefix caching is not supported for Mamba yet.")
-            max_model_len = self.vllm_config.model_config.max_model_len
-
-            page_size_padded = (
-                self.vllm_config.cache_config.mamba_page_size_padded)
-
-            # Set block_size to max_model_len, so that mamba model will always
-            # have only one block in the KV cache.
             for layer_name, mamba_module in mamba_layers.items():
-                kv_cache_spec[layer_name] = MambaSpec(
-                    shapes=mamba_module.get_state_shape(),
-                    dtypes=mamba_module.get_state_dtype(),
-                    block_size=max_model_len,
-                    page_size_padded=page_size_padded,
-                    mamba_type=mamba_module.mamba_type,
-                    num_speculative_blocks=(
-                        self.speculative_config.num_speculative_tokens
-                        if self.speculative_config else 0),
-                )
+                if spec := mamba_module.get_kv_cache_spec(self.vllm_config):
+                    kv_cache_spec[layer_name] = spec
 
         return kv_cache_spec
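The core pattern both hunks adopt: instead of the runner constructing FullAttentionSpec/MLAAttentionSpec/MambaSpec inline, each layer reports its own spec and the runner only collects non-empty results, silently skipping layers (such as encoder-only attention) that return a falsy value where the old code had an explicit continue. A minimal, self-contained sketch of that collection shape is below; the Layer and KVCacheSpec classes are simplified stand-ins for illustration, not vLLM's real AttentionLayerBase or kv_cache_interface API.

# Sketch of the delegation pattern this patch adopts: each layer type
# owns its KV-cache spec; the runner just collects non-None results.
# "Layer" and "KVCacheSpec" are hypothetical stand-ins, not vLLM classes.
from dataclasses import dataclass
from typing import Optional


@dataclass
class KVCacheSpec:
    block_size: int
    num_kv_heads: int
    head_size: int


class Layer:
    def __init__(self, num_kv_heads: int, head_size: int,
                 needs_kv_cache: bool = True):
        self.num_kv_heads = num_kv_heads
        self.head_size = head_size
        self.needs_kv_cache = needs_kv_cache

    def get_kv_cache_spec(self, block_size: int) -> Optional[KVCacheSpec]:
        # Encoder-only layers need no KV cache and return None, which
        # the caller treats as "skip this layer".
        if not self.needs_kv_cache:
            return None
        return KVCacheSpec(block_size, self.num_kv_heads, self.head_size)


def collect_kv_cache_specs(layers: dict[str, Layer],
                           block_size: int) -> dict[str, KVCacheSpec]:
    specs: dict[str, KVCacheSpec] = {}
    for name, layer in layers.items():
        # Mirrors the patched loop: `if spec := module.get_kv_cache_spec(...)`.
        if spec := layer.get_kv_cache_spec(block_size):
            specs[name] = spec
    return specs


if __name__ == "__main__":
    layers = {
        "model.layers.0.self_attn": Layer(num_kv_heads=8, head_size=128),
        "model.layers.1.encoder_attn": Layer(8, 128, needs_kv_cache=False),
    }
    # Only the decoder self-attention layer gets a spec; the encoder-only
    # layer is dropped, matching the patched behavior.
    print(collect_kv_cache_specs(layers, block_size=128))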