From 38570cfeb6e590913dccf9b37a0fae8d9f71a014 Mon Sep 17 00:00:00 2001 From: Jade Zheng Date: Wed, 31 Dec 2025 14:24:04 +0800 Subject: [PATCH] [Feature] Support kv nz feature for DeepSeek decode node in disagg-prefill scenario (#3072) By converting the KV cache from ND to NZ format when the decode node receives it, this PR ensures that the KV NZ feature works correctly during the decoding phase in disagg-prefill scenario. - vLLM version: v0.11.0 - vLLM main: https://github.com/vllm-project/vllm/commit/83f478bb19489b41e9d208b47b4bb5a95ac171ac --------- Signed-off-by: Jade Zheng Co-authored-by: ghphotoframe <854746559@qq.com> Co-authored-by: alex101-ops --- .../configuration/additional_config.md | 4 +- .../ops/singlecard_ops/test_mla_preprocess.py | 6 +- .../singlecard_ops/test_mla_preprocess_nq.py | 6 +- .../test_mla_preprocess_qdown.py | 6 +- tests/ut/test_ascend_config.py | 3 + vllm_ascend/ascend_config.py | 20 ++- vllm_ascend/attention/mla_v1.py | 56 +++++-- vllm_ascend/distributed/mooncake_connector.py | 157 ++++++++++-------- 8 files changed, 163 insertions(+), 95 deletions(-) diff --git a/docs/source/user_guide/configuration/additional_config.md b/docs/source/user_guide/configuration/additional_config.md index f8f398d6..f62f8b18 100644 --- a/docs/source/user_guide/configuration/additional_config.md +++ b/docs/source/user_guide/configuration/additional_config.md @@ -48,6 +48,7 @@ The following table lists additional configuration options available in vLLM Asc | `num_wait_worker_iterations` | int | `30` | The forward iterations when the EPLB worker will finish CPU tasks. In our test default value 30 can cover most cases. | | `expert_map_record_path` | str | `None` | Save the expert load calculation results to a new expert table in the specified directory. | | `init_redundancy_expert` | int | `0` | Specify redundant experts during initialization. | +| `enable_kv_nz` | bool | `False` | Whether to enable kvcache NZ layout. This option only takes effects on models using MLA (e.g., DeepSeek). | The details of each configuration option are as follows: @@ -105,7 +106,8 @@ An example of additional configuration is as follows: "embedding_tensor_parallel_size": 8, "mlp_tensor_parallel_size": 8, }, + "enable_kv_nz": False, "multistream_overlap_shared_expert": True, - "refresh": False, + "refresh": False } ``` diff --git a/tests/e2e/nightly/single_node/ops/singlecard_ops/test_mla_preprocess.py b/tests/e2e/nightly/single_node/ops/singlecard_ops/test_mla_preprocess.py index 99b383ba..6ef95213 100644 --- a/tests/e2e/nightly/single_node/ops/singlecard_ops/test_mla_preprocess.py +++ b/tests/e2e/nightly/single_node/ops/singlecard_ops/test_mla_preprocess.py @@ -1,5 +1,6 @@ import gc +import pytest import torch import torch_npu @@ -8,8 +9,9 @@ from vllm_ascend.utils import enable_custom_op enable_custom_op() +@pytest.mark.parametrize("cache_mode", ["krope_ctkv", "nzcache"]) @torch.inference_mode() -def test_mla_preprocess_kernel(): +def test_mla_preprocess_kernel(cache_mode: str): token_num = 1 head_num = 2 N_7168 = 7168 @@ -98,7 +100,7 @@ def test_mla_preprocess_kernel(): bias1=bias1, ctkv_scale=ctkv_scale, q_nope_scale=qnope_scale, - cache_mode="krope_ctkv", + cache_mode=cache_mode, quant_mode="per_tensor_quant_asymm", enable_inner_out=False, q_out0=q_nope_out, diff --git a/tests/e2e/nightly/single_node/ops/singlecard_ops/test_mla_preprocess_nq.py b/tests/e2e/nightly/single_node/ops/singlecard_ops/test_mla_preprocess_nq.py index b18c63f6..196ffafc 100644 --- a/tests/e2e/nightly/single_node/ops/singlecard_ops/test_mla_preprocess_nq.py +++ b/tests/e2e/nightly/single_node/ops/singlecard_ops/test_mla_preprocess_nq.py @@ -1,5 +1,6 @@ import gc +import pytest import torch import torch_npu @@ -8,8 +9,9 @@ from vllm_ascend.utils import enable_custom_op enable_custom_op() +@pytest.mark.parametrize("cache_mode", ["krope_ctkv", "nzcache"]) @torch.inference_mode() -def test_mla_preprocess_kernel(): +def test_mla_preprocess_kernel(cache_mode: str): token_num = 1 head_num = 2 N_7168 = 7168 @@ -82,7 +84,7 @@ def test_mla_preprocess_kernel(): None, None, None, - cache_mode="krope_ctkv", + cache_mode=cache_mode, quant_mode="no_quant", enable_inner_out=False, q_out0=q_nope_out, diff --git a/tests/e2e/nightly/single_node/ops/singlecard_ops/test_mla_preprocess_qdown.py b/tests/e2e/nightly/single_node/ops/singlecard_ops/test_mla_preprocess_qdown.py index 9eb7e1ca..04753617 100644 --- a/tests/e2e/nightly/single_node/ops/singlecard_ops/test_mla_preprocess_qdown.py +++ b/tests/e2e/nightly/single_node/ops/singlecard_ops/test_mla_preprocess_qdown.py @@ -1,5 +1,6 @@ import gc +import pytest import torch import torch_npu @@ -8,8 +9,9 @@ from vllm_ascend.utils import enable_custom_op enable_custom_op() +@pytest.mark.parametrize("cache_mode", ["krope_ctkv", "nzcache"]) @torch.inference_mode() -def test_mla_preprocess_kernel(): +def test_mla_preprocess_kernel(cache_mode: str): token_num = 1 head_num = 2 N_7168 = 7168 @@ -99,7 +101,7 @@ def test_mla_preprocess_kernel(): bias1=bias1, ctkv_scale=ctkv_scale, q_nope_scale=qnope_scale, - cache_mode="krope_ctkv", + cache_mode=cache_mode, quant_mode="per_tensor_quant_asymm", enable_inner_out=True, q_out0=q_nope_out, diff --git a/tests/ut/test_ascend_config.py b/tests/ut/test_ascend_config.py index 1a337dea..5cccc027 100644 --- a/tests/ut/test_ascend_config.py +++ b/tests/ut/test_ascend_config.py @@ -39,6 +39,7 @@ class TestAscendConfig(TestBase): ascend_config = init_ascend_config(test_vllm_config) self.assertIsNone(ascend_config.expert_map_path) self.assertFalse(ascend_config.multistream_overlap_shared_expert) + self.assertFalse(ascend_config.enable_kv_nz) ascend_compilation_config = ascend_config.ascend_compilation_config self.assertTrue(ascend_compilation_config.fuse_norm_quant) @@ -53,6 +54,7 @@ class TestAscendConfig(TestBase): "multistream_overlap_shared_expert": True, "expert_map_path": "test_expert_map_path", "refresh": True, + "enable_kv_nz": False } ascend_config = init_ascend_config(test_vllm_config) self.assertEqual(ascend_config.expert_map_path, "test_expert_map_path") @@ -61,6 +63,7 @@ class TestAscendConfig(TestBase): ascend_compilation_config = ascend_config.ascend_compilation_config self.assertFalse(ascend_compilation_config.fuse_norm_quant) + self.assertFalse(ascend_config.enable_kv_nz) @_clean_up_ascend_config def test_init_ascend_config_enable_npugraph_ex(self): diff --git a/vllm_ascend/ascend_config.py b/vllm_ascend/ascend_config.py index 8be434a1..fec3ade8 100644 --- a/vllm_ascend/ascend_config.py +++ b/vllm_ascend/ascend_config.py @@ -13,18 +13,21 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Optional +from typing import TYPE_CHECKING, Optional from vllm.logger import logger from vllm.triton_utils import HAS_TRITON +if TYPE_CHECKING: + from vllm.config import VllmConfig + class AscendConfig: """ Configuration Object for additional_config from vllm.configs. """ - def __init__(self, vllm_config): + def __init__(self, vllm_config: "VllmConfig"): additional_config = vllm_config.additional_config if vllm_config.additional_config is not None else {} xlite_graph_config = additional_config.get("xlite_graph_config", {}) @@ -121,6 +124,19 @@ class AscendConfig: self.enable_async_exponential = bool( additional_config.get("enable_async_exponential", False)) + self.enable_kv_nz = additional_config.get("enable_kv_nz", False) + if self.enable_kv_nz: + use_sparse = hasattr(vllm_config.model_config.hf_config, + "index_topk") + if not vllm_config.model_config.is_deepseek_mla or use_sparse: + raise RuntimeError( + "enable_kv_nz is only supported for mla currently.") + if vllm_config.kv_transfer_config is None \ + or not vllm_config.kv_transfer_config.is_kv_consumer: + raise NotImplementedError( + "enable_kv_nz is only supported in pd scenario and can " + "only be used in D node.") + class FinegrainedTPConfig: """ diff --git a/vllm_ascend/attention/mla_v1.py b/vllm_ascend/attention/mla_v1.py index 76f2e410..5535660e 100644 --- a/vllm_ascend/attention/mla_v1.py +++ b/vllm_ascend/attention/mla_v1.py @@ -745,6 +745,7 @@ class AscendMLAImpl(MLAAttentionImpl): ascend_config = get_ascend_config() self.enable_shared_expert_dp = ascend_config.enable_shared_expert_dp self.enable_prefetch = ascend_config.weight_prefetch_config.enabled + self.enable_kv_nz = ascend_config.enable_kv_nz self.ring_mla_mask_size = 512 @@ -1073,7 +1074,7 @@ class AscendMLAImpl(MLAAttentionImpl): # npu_kv_rmsnorm_rope_cache needs [B, N, S, D] kv_no_split = kv_no_split.view( B, N, S, self.kv_lora_rank + self.qk_rope_head_dim) - cache_mode = "PA" + cache_mode = "PA_NZ" if self.enable_kv_nz else "PA" k_pe, k_nope, _, _ = torch_npu.npu_kv_rmsnorm_rope_cache( kv_no_split, self.kv_a_layernorm.weight, @@ -1143,37 +1144,57 @@ class AscendMLAImpl(MLAAttentionImpl): # shape of knope/k_pe for npu graph mode should be: # [num_blocks, num_kv_heads, block_size, self.kv_lora_rank/self.qk_rope_head_dim] actual_seq_lengths = None - k_nope = k_nope.view(-1, self.num_kv_heads, block_size, - self.kv_lora_rank) - k_pe = k_pe.view(-1, self.num_kv_heads, block_size, - self.qk_rope_head_dim) + if self.enable_kv_nz: + nz_fmt_last_dim = 16 + k_nope = k_nope.view(-1, self.num_kv_heads, + self.kv_lora_rank // nz_fmt_last_dim, + block_size, nz_fmt_last_dim) + k_pe = k_pe.view(-1, self.num_kv_heads, + self.qk_rope_head_dim // nz_fmt_last_dim, + block_size, nz_fmt_last_dim) + else: + k_nope = k_nope.view(-1, self.num_kv_heads, block_size, + self.kv_lora_rank) + k_pe = k_pe.view(-1, self.num_kv_heads, block_size, + self.qk_rope_head_dim) + attn_output_shape: tuple | None = None if attn_metadata.attn_state in [ AscendAttentionState.SpecDecoding, AscendAttentionState.ChunkedPrefill, AscendAttentionState.DecodeOnly, ] and self.speculative_config is not None: - # Input shape: [num_tokens, num_heads, dim] - # Output shape: [num_heads, num_tokens, dim] # The right part layout indicates the layout of the attention # output. It is set to NTD to avoid the need for a transpose # operation after attention. input_layout = "TND_NTD" # TODO: If the driver is upgraded later, the contiguous function can be deleted. + # Input shape: [num_tokens, num_heads, dim] q_nope = q_nope.view(num_tokens, self.num_heads, -1).contiguous() q_pe = q_pe.view(num_tokens, self.num_heads, -1) + # Output shape: [num_heads, num_tokens, dim] + attn_output_shape = (self.num_heads, num_tokens, self.kv_lora_rank) sparse_mode = 3 spec_attn_mask = attn_metadata.decode.attn_mask # type:ignore actual_seq_lengths = decode_meta.actual_seq_lengths_q else: - # Input shape: [num_reqs, num_heads, seq_len, dim] - # Output shape: [num_heads, num_reqs, seq_len, dim] # The output layout is set to NBSD to eliminate the need for a # transpose operation after attention. - input_layout = "BNSD_NBSD" - q_nope = q_nope.view(num_tokens, self.num_heads, 1, - -1).contiguous() - q_pe = q_pe.view(num_tokens, self.num_heads, 1, -1) + if self.enable_kv_nz: + # Input shape: [num_tokens, seq_len, num_heads, dim] + input_layout = "BSND_NBSD" + q_nope = q_nope.view(num_tokens, 1, self.num_heads, + -1).contiguous() + q_pe = q_pe.view(num_tokens, 1, self.num_heads, -1) + else: + # Input shape: [num_tokens, num_heads, seq_len, dim] + input_layout = "BNSD_NBSD" + q_nope = q_nope.view(num_tokens, self.num_heads, 1, + -1).contiguous() + q_pe = q_pe.view(num_tokens, self.num_heads, 1, -1) + # Output shape: [num_heads, num_tokens, seq_len, dim] + attn_output_shape = (self.num_heads, num_tokens, 1, + self.kv_lora_rank) sparse_mode = 0 spec_attn_mask = None @@ -1215,10 +1236,9 @@ class AscendMLAImpl(MLAAttentionImpl): else: update_graph_params_workspaces(num_tokens, workspace) - attn_output = torch.empty( - (q_nope.shape[1], q_nope.shape[0], *q_nope.shape[2:]), - dtype=q_nope.dtype, - device=q_nope.device) + attn_output = torch.empty(attn_output_shape, + dtype=q_nope.dtype, + device=q_nope.device) softmax_lse = torch.empty(num_tokens, dtype=q_nope.dtype, device=q_nope.device) @@ -1297,7 +1317,7 @@ class AscendMLAImpl(MLAAttentionImpl): bias1=self.qb_qt_bias, ctkv_scale=self.ctkv_scale, q_nope_scale=self.q_nope_scale, - cache_mode="krope_ctkv", + cache_mode="nzcache" if self.enable_kv_nz else "krope_ctkv", quant_mode="per_tensor_quant_asymm", q_out0=decode_q_nope, kv_cache_out0=decode_k_nope, diff --git a/vllm_ascend/distributed/mooncake_connector.py b/vllm_ascend/distributed/mooncake_connector.py index 38284335..6cacef83 100644 --- a/vllm_ascend/distributed/mooncake_connector.py +++ b/vllm_ascend/distributed/mooncake_connector.py @@ -412,8 +412,10 @@ class KVCacheRecvingThread(threading.Thread): logger.debug( f"Finished transferring KV cache for request {request_id}.") except Exception as e: - logger.error("Failed to transfer KV cache for request " - f"{request_id}: {e}") + logger.error( + "Failed to transfer KV cache for request " + f"{request_id}: {e}", + exc_info=True) finally: # Always send the done signal to the remote host to ensure proper # resource cleanup. Failing to do so may cause a memory leak on the @@ -539,97 +541,116 @@ class KVCacheRecvingThread(threading.Thread): request_id, req_transfer_elapsed, num_transfer_groups, num_blocks, get_ip(), self.tp_rank, session_id) - # Determine if the current position is the offset position at the end of the KV transmission. + # Determine if the current position is the offset position at the end of + # the KV transmission. is_kv_transfer_end = ( global_offset == tp_num_need_pulls * self._prefill_pp_size - 1) need_cat_cache = tp_num_need_pulls > 1 and is_kv_transfer_end - # need_nz_cache maybe caused error in non-MLA models - if need_cat_cache: - self._cat_kv_cache(grouped_local_block_ids, tp_num_need_pulls) + need_nz_cache = get_ascend_config().enable_kv_nz and is_kv_transfer_end + if need_nz_cache or need_cat_cache: + self.reformat_kv_cache(grouped_local_block_ids, tp_num_need_pulls, + need_cat_cache, need_nz_cache) - def _cat_kv_cache(self, block_ids: list[list[int]], - tp_num_need_pulls: int): + def reformat_kv_cache(self, + block_ids: list[list[int]], + tp_num_need_pulls: int, + need_cat_cache: bool = False, + need_nz_cache: bool = False): # Get necessary parameters k_cache = list(self.kv_caches.values())[0][0] dtype = k_cache.dtype device = k_cache.device - head_dim = self.model_config.hf_text_config.head_dim - block_size = self.vllm_config.cache_config.block_size - num_kv_head = max( - self.model_config.hf_text_config.num_key_value_heads // - self.tp_size, 1) flat_block_ids = [item for sublist in block_ids for item in sublist] - block_ids_tensor = torch.tensor(flat_block_ids, dtype=torch.int32) + block_ids_tensor = torch.tensor(flat_block_ids, + dtype=torch.int32, + device=device) num_blocks = len(flat_block_ids) - block_len = num_blocks * block_size + num_tokens = num_blocks * self.block_size # Create device tensors for copy operations - block_table = block_ids_tensor.view(1, -1).to(device=device) - block_len_tensor = torch.tensor([block_len], - dtype=torch.int32).to(device=device) - seq_start_tensor = torch.tensor([0], - dtype=torch.int32).to(device=device) + block_table = block_ids_tensor.view(1, -1) + block_len_tensor = torch.tensor([num_tokens], + dtype=torch.int32, + device=device) + seq_start_tensor = torch.tensor([0], dtype=torch.int32, device=device) # Initialize buffers - k_buffer = torch.empty(block_len, - num_kv_head, - head_dim, - dtype=dtype, - device=device) - v_buffer = torch.empty(block_len, - num_kv_head, - head_dim, - dtype=dtype, - device=device) + k_buffer = torch.empty( + (num_tokens, self.num_kv_heads, self.k_head_dim), + dtype=dtype, + device=device) + v_buffer = torch.empty( + (num_tokens, self.num_kv_heads, self.v_head_dim), + dtype=dtype, + device=device) # Create slot mapping for reshape operations - block_offsets = torch.arange(0, block_size, dtype=torch.int32) + block_offsets = torch.arange(0, + self.block_size, + dtype=torch.int32, + device=device) slot_mapping = (block_offsets.reshape( - (1, block_size)) + block_ids_tensor.reshape( - (num_blocks, 1)) * block_size) - slot_mapping = slot_mapping.flatten().to(device=device) + (1, self.block_size)) + block_ids_tensor.reshape( + (num_blocks, 1)) * self.block_size).flatten() + + # FIXME: Right now, if we skip synchronization at this point, the system + # will crash in GQA scenarios. However, we still haven't identified the + # root cause. + torch.npu.synchronize() # Process each layer in the KV cache for _, (k_cache_layer, v_cache_layer) in self.kv_caches.items(): # Load cache data into buffers - torch_npu.atb.npu_paged_cache_load( - k_cache_layer, - v_cache_layer, - block_table, - block_len_tensor, - seq_starts=seq_start_tensor, - key=k_buffer, - value=v_buffer, - ) - - # Transpose KV cache - k_buffer = self._transpose_kv_cache_between_head( - k_buffer, num_blocks, block_size, block_len, num_kv_head, - tp_num_need_pulls) - v_buffer = self._transpose_kv_cache_between_head( - v_buffer, num_blocks, block_size, block_len, num_kv_head, - tp_num_need_pulls) - - # Reshape and cache the processed buffers - torch_npu._npu_reshape_and_cache( - key=k_buffer, - value=v_buffer, - key_cache=k_cache_layer, - value_cache=v_cache_layer, - slot_indices=slot_mapping, - ) - + torch_npu.atb.npu_paged_cache_load(k_cache_layer, + v_cache_layer, + block_table, + block_len_tensor, + seq_starts=seq_start_tensor, + key=k_buffer, + value=v_buffer) + if need_cat_cache: + self._cat_kv_cache(k_cache_layer, v_cache_layer, k_buffer, + v_buffer, tp_num_need_pulls, num_blocks, + num_tokens, slot_mapping) + if need_nz_cache: + self._nz_kv_cache(k_cache_layer, v_cache_layer, k_buffer, + v_buffer, slot_mapping) # Clean up buffers del k_buffer, v_buffer - def _transpose_kv_cache_between_head( - self, buffer: torch.Tensor, num_blocks: int, block_size: int, - block_len: int, num_kv_head: int, - tp_num_need_pulls: int) -> torch.Tensor: - buffer = buffer.view(num_blocks, tp_num_need_pulls, block_size, -1) - buffer.transpose_(1, 2) - return buffer.contiguous().view(block_len, num_kv_head, -1) + def _cat_kv_cache(self, k_cache_layer, v_cache_layer, k_buffer, v_buffer, + tp_num_need_pulls, num_blocks, num_tokens, slot_mapping): + + def _transpose_kv_cache_between_head( + buffer: torch.Tensor) -> torch.Tensor: + buffer = buffer.view(num_blocks, tp_num_need_pulls, + self.block_size, -1) + buffer.transpose_(1, 2) + return buffer.contiguous().view(num_tokens, self.num_kv_heads, -1) + + # Transpose KV cache + k_buffer = _transpose_kv_cache_between_head(k_buffer) + v_buffer = _transpose_kv_cache_between_head(v_buffer) + + # Reshape and cache the processed buffers + torch_npu._npu_reshape_and_cache(key=k_buffer, + value=v_buffer, + key_cache=k_cache_layer, + value_cache=v_cache_layer, + slot_indices=slot_mapping) + + def _nz_kv_cache(self, k_cache_layer, v_cache_layer, k_buffer, v_buffer, + slot_mapping): + nz_fmt_last_dim = 16 + k_cache_layer = k_cache_layer.view( + -1, self.k_head_dim * self.num_kv_heads // nz_fmt_last_dim, + self.block_size, nz_fmt_last_dim) + v_cache_layer = v_cache_layer.view( + -1, self.v_head_dim * self.num_kv_heads // nz_fmt_last_dim, + self.block_size, nz_fmt_last_dim) + torch_npu.npu_scatter_pa_kv_cache(k_buffer, v_buffer, k_cache_layer, + v_cache_layer, slot_mapping) def _get_remote_metadata(self, remote_host: str, remote_handshake_port: int) -> None: