diff --git a/tests/ut/worker/test_model_runner_v1.py b/tests/ut/worker/test_model_runner_v1.py
new file mode 100644
index 00000000..0cb09faf
--- /dev/null
+++ b/tests/ut/worker/test_model_runner_v1.py
@@ -0,0 +1,89 @@
+import unittest
+from types import SimpleNamespace
+from unittest.mock import MagicMock
+
+import torch
+from vllm.v1.kv_cache_interface import FullAttentionSpec, KVCacheConfig, KVCacheGroupSpec, KVCacheTensor
+
+from vllm_ascend.worker.model_runner_v1 import NPUModelRunner
+
+
+class TestNPUModelRunnerKVCache(unittest.TestCase):
+
+    def _build_runner(self):
+        runner = NPUModelRunner.__new__(NPUModelRunner)
+        runner.device = torch.device("cpu")
+        runner.use_sparse = False
+        runner.use_sparse_c8_indexer = False
+        runner.use_hybrid_blocks = False
+        runner.hybrid_with_attn_and_mamba = False
+        runner.runner_only_attn_layers = set()
+        runner.is_kv_consumer = False
+        runner.vllm_config = MagicMock()
+        runner.vllm_config.kv_transfer_config = None
+        runner.model_config = MagicMock()
+        runner.model_config.use_mla = True
+        backend = MagicMock()
+        backend.get_kv_cache_shape.side_effect = lambda num_blocks, block_size, num_kv_heads, head_size: (
+            2,
+            num_blocks,
+            block_size,
+            num_kv_heads,
+            head_size,
+        )
+        runner.attn_backend = backend
+        return runner
+
+    def test_allocate_kv_cache_uses_layer_spec_for_draft_gqa(self):
+        runner = self._build_runner()
+        kv_cache_spec = FullAttentionSpec(
+            block_size=16,
+            num_kv_heads=8,
+            head_size=64,
+            head_size_v=64,
+            dtype=torch.float16,
+        )
+        kv_cache_config = KVCacheConfig(
+            num_blocks=2,
+            kv_cache_tensors=[KVCacheTensor(size=kv_cache_spec.page_size_bytes * 2, shared_by=["draft_attn"])],
+            kv_cache_groups=[KVCacheGroupSpec(layer_names=["draft_attn"], kv_cache_spec=kv_cache_spec)],
+        )
+
+        kv_cache_raw_tensors = runner._allocate_kv_cache_tensors(kv_cache_config)
+        k_cache_raw, v_cache_raw = kv_cache_raw_tensors["draft_attn"]
+
+        self.assertEqual(k_cache_raw.numel(), kv_cache_spec.page_size_bytes)
+        self.assertEqual(v_cache_raw.numel(), kv_cache_spec.page_size_bytes)
+
+    def test_reshape_kv_cache_uses_layer_spec_for_draft_gqa(self):
+        runner = self._build_runner()
+        kv_cache_spec = FullAttentionSpec(
+            block_size=16,
+            num_kv_heads=8,
+            head_size=64,
+            head_size_v=64,
+            dtype=torch.float16,
+        )
+        kv_cache_config = KVCacheConfig(
+            num_blocks=2,
+            kv_cache_tensors=[KVCacheTensor(size=kv_cache_spec.page_size_bytes * 2, shared_by=["draft_attn"])],
+            kv_cache_groups=[KVCacheGroupSpec(layer_names=["draft_attn"], kv_cache_spec=kv_cache_spec)],
+        )
+        kv_cache_raw_tensors = runner._allocate_kv_cache_tensors(kv_cache_config)
+        runner._kv_cache_spec_attn_group_iterator = lambda: [
+            SimpleNamespace(
+                kv_cache_spec=kv_cache_spec,
+                backend=runner.attn_backend,
+                layer_names=["draft_attn"],
+            )
+        ]
+
+        kv_caches = runner._reshape_kv_cache_tensors(kv_cache_config, kv_cache_raw_tensors)
+        k_cache, v_cache = kv_caches["draft_attn"]
+
+        self.assertEqual(k_cache.shape, (2, 16, 8, 64))
+        self.assertEqual(v_cache.shape, (2, 16, 8, 64))
+
+
+if __name__ == "__main__":
+    unittest.main()
\ No newline at end of file
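The expected values in these tests follow directly from the spec arithmetic; a minimal worked sketch of that bookkeeping, assuming `page_size_bytes` covers both K and V as in `vllm.v1.kv_cache_interface`:

    # Worked example for the asserts above: FullAttentionSpec with
    # block_size=16, num_kv_heads=8, head_size=64, float16 (2 bytes/element).
    block_size, num_kv_heads, head_size, dtype_size = 16, 8, 64, 2
    page_size_bytes = 2 * block_size * num_kv_heads * head_size * dtype_size  # K + V per block
    assert page_size_bytes == 32768
    # The KVCacheTensor holds page_size_bytes * 2 bytes, split evenly between
    # the raw K and V tensors, so each raw tensor is page_size_bytes long and
    # num_blocks = (k.numel() + v.numel()) // page_size_bytes = 2.
    assert (2 * page_size_bytes) // page_size_bytes == 2
    # Hence the reshaped caches: (num_blocks, block_size, num_kv_heads, head_size).
    assert (2, block_size, num_kv_heads, head_size) == (2, 16, 8, 64)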
diff --git a/vllm_ascend/spec_decode/eagle_proposer.py b/vllm_ascend/spec_decode/eagle_proposer.py
index b9b6c644..869cd287 100644
--- a/vllm_ascend/spec_decode/eagle_proposer.py
+++ b/vllm_ascend/spec_decode/eagle_proposer.py
@@ -46,7 +46,7 @@ from vllm_ascend.attention.utils import AscendCommonAttentionMetadata
 from vllm_ascend.compilation.acl_graph import ACLGraphWrapper, update_full_graph_params
 from vllm_ascend.ops.triton.spec_decode.utils import prepare_inputs_padded_kernel
 from vllm_ascend.ops.triton.triton_utils import get_vectorcore_num
-from vllm_ascend.utils import enable_sp, lmhead_tp_enable, shared_expert_dp_enabled
+from vllm_ascend.utils import enable_sp, lmhead_tp_enable, shared_expert_dp_enabled, vllm_version_is
 
 # Currently we will fix block size to a small one since `num_reqs` can't be too large
 _PREPARE_INPUTS_BLOCK_SIZE = 4
@@ -217,6 +217,8 @@ class SpecDecodeBaseProposer(EagleProposer):
             self.model.config.image_token_index = model.config.image_token_id
         elif self.get_model_name(model) == "PixtralForConditionalGeneration":
             self.model.config.image_token_index = model.config.vision_config.image_token_id
+        elif self.get_model_name(model) == "KimiK25ForConditionalGeneration":
+            self.model.config.image_token_index = model.config.media_placeholder_token_id
         else:
             self.model.config.image_token_index = model.config.image_token_index
         target_language_model = model.get_language_model()
@@ -388,7 +390,11 @@ class SpecDecodeBaseProposer(EagleProposer):
             : num_reqs * self.decode_threshold
         ]
 
-        builder = self.runner.attn_groups[0][0].get_metadata_builder()
+        if vllm_version_is("0.17.0"):
+            assert len(self.draft_attn_groups) > 0
+            builder = self.draft_attn_groups[0].get_metadata_builder()
+        else:
+            builder = self.runner.attn_groups[0][0].get_metadata_builder()
         # update the tensor's address for each step.
         for draft_step in range(self.num_speculative_tokens):
             common_attn_metadata = self.shallow_copy_metadata(common_attn_metadata)
@@ -550,7 +556,11 @@ class SpecDecodeBaseProposer(EagleProposer):
         common_attn_metadata.slot_mapping = self.slot_mapping_group[0]
         common_attn_metadata.num_input_tokens = num_input_tokens
         # FIXME(woosuk): The below two ops cause synchronization. Optimize.
-        builder = self.runner.attn_groups[0][0].get_metadata_builder()
+        if vllm_version_is("0.17.0"):
+            assert len(self.draft_attn_groups) > 0
+            builder = self.draft_attn_groups[0].get_metadata_builder()
+        else:
+            builder = self.runner.attn_groups[0][0].get_metadata_builder()
         attn_metadata = builder.build(0, common_attn_metadata, self.runner.get_model())
 
         if self.uses_mrope:
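The version-gated builder lookup above appears twice in this file; a minimal sketch of a shared helper on the proposer class that would remove the duplication (method name hypothetical, not part of the patch):

    # Hypothetical refactor: both call sites pick the metadata builder the same
    # way, so the version gate could live in one place.
    def _select_metadata_builder(self):
        if vllm_version_is("0.17.0"):
            # vLLM 0.17.0: the draft model keeps its own attention groups.
            assert len(self.draft_attn_groups) > 0
            return self.draft_attn_groups[0].get_metadata_builder()
        # Older layout: attention groups hang off the runner.
        return self.runner.attn_groups[0][0].get_metadata_builder()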
diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py
index a49182ce..d1f86bf0 100644
--- a/vllm_ascend/worker/model_runner_v1.py
+++ b/vllm_ascend/worker/model_runner_v1.py
@@ -2657,6 +2657,34 @@ class NPUModelRunner(GPUModelRunner):
         bind_kv_cache(kv_caches, self.compilation_config.static_forward_context, self.kv_caches, num_attn_module)
         return kv_caches
 
+    def _get_layer_kv_cache_specs(self, kv_cache_config: KVCacheConfig) -> dict[str, KVCacheSpec]:
+        layer_kv_cache_spec: dict[str, KVCacheSpec] = {}
+        for group_kv_cache_spec in kv_cache_config.kv_cache_groups:
+            group_spec = group_kv_cache_spec.kv_cache_spec
+            for layer_name in group_kv_cache_spec.layer_names:
+                if isinstance(group_spec, UniformTypeKVCacheSpecs):
+                    layer_kv_cache_spec[layer_name] = group_spec.kv_cache_specs[layer_name]
+                else:
+                    layer_kv_cache_spec[layer_name] = group_spec
+        return layer_kv_cache_spec
+
+    def _get_attention_kv_cache_dims(self, layer_name: str, kv_cache_spec: AttentionSpec) -> tuple[int, int]:
+        if isinstance(kv_cache_spec, MLAAttentionSpec):
+            attn_layers = get_layers_from_vllm_config(
+                self.vllm_config,
+                AttentionLayerBase,
+                [layer_name],
+            )
+            attn_layer = attn_layers[layer_name]
+            if not isinstance(attn_layer, MLAAttention):
+                raise TypeError(
+                    f"Expected MLAAttention layer for {layer_name}, got {type(attn_layer).__name__}."
+                )
+            return attn_layer.kv_lora_rank, attn_layer.qk_rope_head_dim
+
+        head_size_v = kv_cache_spec.head_size_v if hasattr(kv_cache_spec, "head_size_v") else kv_cache_spec.head_size
+        return kv_cache_spec.head_size, head_size_v
+
     def _allocate_kv_cache_tensors(self, kv_cache_config: KVCacheConfig) -> dict[str, torch.Tensor]:
         """
         Initializes the KV cache buffer with the correct size. The buffer needs
@@ -2677,10 +2705,7 @@
         kv_cache_raw_tensors: dict[str, torch.Tensor | torch.Tensor | None | None] = {}
         # prefill disaggregation need the addr of cache tensor be aligned with 2M
         alignment = 2 * 1024 * 1024
-        layer_kv_cache_spec: dict[str, KVCacheSpec] = {}
-        for group_kv_cache_spec in kv_cache_config.kv_cache_groups:
-            for layer_name in group_kv_cache_spec.layer_names:
-                layer_kv_cache_spec[layer_name] = group_kv_cache_spec.kv_cache_spec
+        layer_kv_cache_spec = self._get_layer_kv_cache_specs(kv_cache_config)
         # If some tensors are shared by linear layers and attention layers,
         # the same tensor format must be maintained even if some layers
         # have only linear or attention layers, for example, the mtp layer.
@@ -2715,11 +2740,10 @@
             # as it only support the 0-dim of kv_cache is `num_blocks`.
             # For deepseek mla, we need to spilt cache tensor accrodding to the nope head dim
             # and rope head dim.
-            if not self.model_config.use_mla:
-                # for non-mla model, use FullAttentionSpec
-                k_tensor_split_factor = 2.0
-                v_tensor_split_factor = 2.0
-            elif self.use_sparse:
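+            # Resolve the layer's own spec: a GQA draft head (e.g. EAGLE) can
+            # coexist with MLA target layers, so a per-model flag is not enough.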
""" kv_caches: dict[str, torch.Tensor] = {} - layer_kv_cache_spec = {} - for group in kv_cache_config.kv_cache_groups: - for layer_name in group.layer_names: - layer_kv_cache_spec[layer_name] = group.kv_cache_spec + layer_kv_cache_spec = self._get_layer_kv_cache_specs(kv_cache_config) for group in self._kv_cache_spec_attn_group_iterator(): - kv_cache_spec = group.kv_cache_spec attn_backend = group.backend for layer_name in group.layer_names: if layer_name in self.runner_only_attn_layers: continue + current_kv_cache_spec = layer_kv_cache_spec[layer_name] + # TODO: remove this after the OOM issue is located and fixed, otherwise, some model may # encounter OOM issue - if isinstance(kv_cache_spec, AttentionSpec): + if isinstance(current_kv_cache_spec, AttentionSpec): if self.use_sparse: if self.use_sparse_c8_indexer: raw_k_tensor, raw_v_tensor, raw_dsa_k_tensor, raw_dsa_k_scale_tensor = kv_cache_raw_tensors[ # type: ignore @@ -2862,8 +2885,8 @@ class NPUModelRunner(GPUModelRunner): sum_page_size_bytes = raw_k_tensor.numel() + raw_v_tensor.numel() assert raw_k_tensor is not None assert raw_v_tensor is not None - assert sum_page_size_bytes % kv_cache_spec.page_size_bytes == 0 - num_blocks = sum_page_size_bytes // kv_cache_spec.page_size_bytes + assert sum_page_size_bytes % current_kv_cache_spec.page_size_bytes == 0 + num_blocks = sum_page_size_bytes // current_kv_cache_spec.page_size_bytes # `num_blocks` is the number of blocks the model runner can use. # `kv_cache_config.num_blocks` is the number of blocks that @@ -2877,48 +2900,54 @@ class NPUModelRunner(GPUModelRunner): if hasattr(attn_backend, "get_supported_kernel_block_sizes") and self.use_hybrid_blocks: block_size = attn_backend.get_supported_kernel_block_sizes()[0] - block_size_chunk = kv_cache_spec.block_size // block_size + block_size_chunk = current_kv_cache_spec.block_size // block_size kv_cache_shape = attn_backend.get_kv_cache_shape( num_blocks * block_size_chunk, block_size, - kv_cache_spec.num_kv_heads, - kv_cache_spec.head_size, + current_kv_cache_spec.num_kv_heads, + current_kv_cache_spec.head_size, ) if self.hybrid_with_attn_and_mamba: attn_tensor_page_size = int(np.prod(kv_cache_shape[1:])) * get_dtype_size( - kv_cache_spec.dtype + current_kv_cache_spec.dtype ) conv_block_padding_size = raw_k_tensor.numel() - attn_tensor_page_size * 2 raw_kv_tensor = raw_k_tensor[conv_block_padding_size:] raw_k_tensor = raw_kv_tensor[:attn_tensor_page_size] raw_v_tensor = raw_kv_tensor[attn_tensor_page_size:] else: - kv_cache_shape = self.attn_backend.get_kv_cache_shape( - num_blocks, kv_cache_spec.block_size, kv_cache_spec.num_kv_heads, kv_cache_spec.head_size + kv_cache_shape = attn_backend.get_kv_cache_shape( + num_blocks, + current_kv_cache_spec.block_size, + current_kv_cache_spec.num_kv_heads, + current_kv_cache_spec.head_size, ) - - if not self.model_config.use_mla: + if not isinstance(current_kv_cache_spec, MLAAttentionSpec): k_shape = kv_cache_shape[1:] - v_shape = k_shape + if hasattr(current_kv_cache_spec, "head_size_v"): + v_shape = (*kv_cache_shape[1:-1], current_kv_cache_spec.head_size_v) + else: + v_shape = k_shape else: # k_cache: nope_cache v_cache: rope_cache mla_num_blocks, mla_block_size, num_kv_heads, _ = kv_cache_shape - k_shape = [ + k_dim, v_dim = self._get_attention_kv_cache_dims(layer_name, current_kv_cache_spec) + k_shape = ( mla_num_blocks, mla_block_size, num_kv_heads, - self.model_config.hf_text_config.kv_lora_rank, - ] - v_shape = [ + k_dim, + ) + v_shape = ( mla_num_blocks, mla_block_size, num_kv_heads, - 
+                        if hasattr(current_kv_cache_spec, "head_size_v"):
+                            v_shape = (*kv_cache_shape[1:-1], current_kv_cache_spec.head_size_v)
+                        else:
+                            v_shape = k_shape
                     else:
                         # k_cache: nope_cache  v_cache: rope_cache
                         mla_num_blocks, mla_block_size, num_kv_heads, _ = kv_cache_shape
-                        k_shape = [
+                        k_dim, v_dim = self._get_attention_kv_cache_dims(layer_name, current_kv_cache_spec)
+                        k_shape = (
                             mla_num_blocks,
                             mla_block_size,
                             num_kv_heads,
-                            self.model_config.hf_text_config.kv_lora_rank,
-                        ]
-                        v_shape = [
+                            k_dim,
+                        )
+                        v_shape = (
                             mla_num_blocks,
                             mla_block_size,
                             num_kv_heads,
-                            self.model_config.hf_text_config.qk_rope_head_dim,
-                        ]
-                    k_cache_dtype = v_cache_dtype = kv_cache_spec.dtype
+                            v_dim,
+                        )
+                    k_cache_dtype = v_cache_dtype = current_kv_cache_spec.dtype
                     if self.is_kv_consumer and self.vllm_config.quant_config is not None:
                         k_cache_dtype, v_cache_dtype = self.vllm_config.quant_config.get_kv_quant_dtype(
-                            layer_name, kv_cache_spec.dtype, self.model_config
+                            layer_name, current_kv_cache_spec.dtype, self.model_config
                         )
                     k_cache = raw_k_tensor.view(k_cache_dtype).view(k_shape)
                     v_cache = raw_v_tensor.view(v_cache_dtype).view(v_shape)
@@ -2926,8 +2955,8 @@
                     if self.use_sparse:
                         dsa_k_cache_shape = (
                             num_blocks,
-                            kv_cache_spec.block_size,
-                            kv_cache_spec.num_kv_heads,
+                            current_kv_cache_spec.block_size,
+                            current_kv_cache_spec.num_kv_heads,
                             self.model_config.hf_text_config.index_head_dim,
                         )
                         if self.use_sparse_c8_indexer:
@@ -2936,8 +2965,8 @@
                             # dsa_k_scale
                             dsa_k_scale_cache_shape = (
                                 num_blocks,
-                                kv_cache_spec.block_size,
-                                kv_cache_spec.num_kv_heads,
+                                current_kv_cache_spec.block_size,
+                                current_kv_cache_spec.num_kv_heads,
                                 1,
                             )
                             assert raw_dsa_k_scale_tensor is not None
@@ -2949,15 +2978,15 @@
                             kv_caches[layer_name] = (k_cache, v_cache, dsa_k_cache, dsa_k_scale_cache)
                         else:
                             # dsa_k
-                            dsa_k_cache = raw_dsa_k_tensor.view(kv_cache_spec.dtype).view(dsa_k_cache_shape)
+                            dsa_k_cache = raw_dsa_k_tensor.view(current_kv_cache_spec.dtype).view(dsa_k_cache_shape)
                             kv_caches[layer_name] = (k_cache, v_cache, dsa_k_cache)
                     else:
                         kv_caches[layer_name] = (k_cache, v_cache)
-                elif isinstance(kv_cache_spec, MambaSpec):
+                elif isinstance(current_kv_cache_spec, MambaSpec):
                     raw_tensor = kv_cache_raw_tensors[layer_name]
                     assert raw_tensor is not None
-                    assert raw_tensor.numel() % kv_cache_spec.page_size_bytes == 0
-                    num_blocks = raw_tensor.numel() // kv_cache_spec.page_size_bytes
+                    assert raw_tensor.numel() % current_kv_cache_spec.page_size_bytes == 0
+                    num_blocks = raw_tensor.numel() // current_kv_cache_spec.page_size_bytes
                     assert num_blocks >= kv_cache_config.num_blocks
 
                     # `num_blocks` is the number of blocks the model runner can use.
@@ -2977,7 +3006,7 @@
                     # tensor1: [(kv_padding), conv           , ...]
                     # tensor2: [k           , ssm            , ...]
                     # tensor3: [v           , (mamba_padding), ...]
-                    for shape, dtype in zip(kv_cache_spec.shapes, kv_cache_spec.dtypes):
+                    for shape, dtype in zip(current_kv_cache_spec.shapes, current_kv_cache_spec.dtypes):
                         # normally, there is conv state and ssm state in this loop. And there is only
                         # a conv state in some special models.
                         target_shape = (num_blocks, *shape)
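Taken together, the reshape path now shapes each cache from its own layer spec instead of global model config. A minimal sketch of the resulting MLA split (helper name hypothetical; DeepSeek-style dims used purely as an example):

    # Mirrors the MLA branch above: K holds the nope (latent) cache sized by
    # kv_lora_rank, V holds the rope cache sized by qk_rope_head_dim.
    def mla_cache_shapes(num_blocks, block_size, num_kv_heads, kv_lora_rank, qk_rope_head_dim):
        k_shape = (num_blocks, block_size, num_kv_heads, kv_lora_rank)
        v_shape = (num_blocks, block_size, num_kv_heads, qk_rope_head_dim)
        return k_shape, v_shape

    # Example dims only (kv_lora_rank=512, qk_rope_head_dim=64):
    assert mla_cache_shapes(2, 16, 1, 512, 64) == ((2, 16, 1, 512), (2, 16, 1, 64))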