From ab928ed586f2881d455009a28ca14489576ccf2a Mon Sep 17 00:00:00 2001
From: zxr2333 <64738772+nwpu-zxr@users.noreply.github.com>
Date: Tue, 31 Mar 2026 09:25:22 +0800
Subject: [PATCH] [v0.18.0][P/D][Feature]Layerwise connector supports Mamba
 prefill prefix caching (#7796)

### What this PR does / why we need it?
Mooncake layerwise connector supports Mamba prefix caching on prefiller nodes.

### Does this PR introduce _any_ user-facing change?
Yes. Use `--enable-prefix-caching` and `--mamba-cache-mode align` to enable
mamba align mode prefix caching on P/D prefill nodes. This function is not
supported on decode nodes yet.

### How was this patch tested?
By P/D E2E test.

---------

Signed-off-by: nwpu-zxr
---
 .../kv_p2p/mooncake_layerwise_connector.py | 20 +++++++++++++++----
 1 file changed, 16 insertions(+), 4 deletions(-)

diff --git a/vllm_ascend/distributed/kv_transfer/kv_p2p/mooncake_layerwise_connector.py b/vllm_ascend/distributed/kv_transfer/kv_p2p/mooncake_layerwise_connector.py
index e5de1133..2bcc9288 100644
--- a/vllm_ascend/distributed/kv_transfer/kv_p2p/mooncake_layerwise_connector.py
+++ b/vllm_ascend/distributed/kv_transfer/kv_p2p/mooncake_layerwise_connector.py
@@ -226,6 +226,12 @@ class KVCacheSendingLayerThread(threading.Thread):
         self.kv_cache_config = kv_cache_config
         self.kv_cache_specs = kv_cache_specs
         self.attn_resharding_group_idx = attn_resharding_group_idx
+        self.mamba_cache_mode = self.vllm_config.cache_config.mamba_cache_mode
+        self.num_speculative_tokens = (
+            self.vllm_config.speculative_config.num_speculative_tokens
+            if self.vllm_config.speculative_config is not None
+            else 0
+        )
         self.tp_size = tp_size
         self.tp_rank = tp_rank
         self.pd_head_ratio = pd_head_ratio
@@ -274,6 +280,10 @@ class KVCacheSendingLayerThread(threading.Thread):

         if isinstance(layer_kv_cache_spec, MambaSpec):
             # only support one block transfer for mamba
+            if self.mamba_cache_mode == "align":
+                transfer_block_idx = len(local_block_ids) - self.num_speculative_tokens - 1
+            else:
+                transfer_block_idx = 0
             local_conv_addr, local_ssm_addr = local_layer_metadata.kv_caches_base_addr
             remote_conv_addr, remote_ssm_addr = remote_layer_metadata.kv_caches_base_addr
             local_conv_len, local_ssm_len = local_layer_metadata.block_len
@@ -281,8 +291,8 @@ class KVCacheSendingLayerThread(threading.Thread):
             if tp_ratio == 1:
                 src_list.extend(
                     [
-                        local_conv_addr + local_block_ids[0] * local_conv_len,
-                        local_ssm_addr + local_block_ids[0] * local_ssm_len,
+                        local_conv_addr + local_block_ids[transfer_block_idx] * local_conv_len,
+                        local_ssm_addr + local_block_ids[transfer_block_idx] * local_ssm_len,
                     ]
                 )
                 dst_list.extend(
@@ -320,12 +330,14 @@ class KVCacheSendingLayerThread(threading.Thread):
                         (i * conv_shape[1] + local_conv_offset) * tp_ratio
                         + (self.tp_rank % tp_ratio) * local_conv_size
                     ) * get_dtype_size(conv_dtype)
-                    src_list.append(local_conv_addr + local_block_ids[0] * local_conv_len + local_addr_offset)
+                    src_list.append(
+                        local_conv_addr + local_block_ids[transfer_block_idx] * local_conv_len + local_addr_offset
+                    )
                     dst_list.append(remote_conv_addr + remote_block_ids[0] * remote_conv_len + remote_addr_offset)
                     length_list.append(local_conv_size * get_dtype_size(conv_dtype))
                 # ssm
                 remote_addr_offset = (self.tp_rank % tp_ratio) * math.prod(ssm_shape) * get_dtype_size(ssm_dtype)
-                src_list.append(local_ssm_addr + local_block_ids[0] * local_ssm_len)
+                src_list.append(local_ssm_addr + local_block_ids[transfer_block_idx] * local_ssm_len)
                 dst_list.append(remote_ssm_addr + remote_block_ids[0] * remote_ssm_len + remote_addr_offset)
                 length_list.append(local_ssm_len)
             else: