[v0.18.0][P/D][Feature]Layerwise connector supports Mamba prefill prefix caching (#7796)

### What this PR does / why we need it?
The Mooncake layerwise connector now supports Mamba prefix caching on
prefill nodes.

### Does this PR introduce _any_ user-facing change?
Yes. Use `--enable-prefix-caching` and `--mamba-cache-mode align` to
enable Mamba align-mode prefix caching on P/D prefill nodes. This
feature is not yet supported on decode nodes.

### How was this patch tested?
By P/D E2E test.

---------

Signed-off-by: nwpu-zxr <zhouxuerong2@huawei.com>
This commit is contained in:
zxr2333
2026-03-31 09:25:22 +08:00
committed by GitHub
parent cab5d73633
commit ab928ed586

View File

@@ -226,6 +226,12 @@ class KVCacheSendingLayerThread(threading.Thread):
self.kv_cache_config = kv_cache_config
self.kv_cache_specs = kv_cache_specs
self.attn_resharding_group_idx = attn_resharding_group_idx
self.mamba_cache_mode = self.vllm_config.cache_config.mamba_cache_mode
self.num_speculative_tokens = (
self.vllm_config.speculative_config.num_speculative_tokens
if self.vllm_config.speculative_config is not None
else 0
)
self.tp_size = tp_size
self.tp_rank = tp_rank
self.pd_head_ratio = pd_head_ratio
@@ -274,6 +280,10 @@ class KVCacheSendingLayerThread(threading.Thread):
if isinstance(layer_kv_cache_spec, MambaSpec):
# only support one block transfer for mamba
if self.mamba_cache_mode == "align":
transfer_block_idx = len(local_block_ids) - self.num_speculative_tokens - 1
else:
transfer_block_idx = 0
local_conv_addr, local_ssm_addr = local_layer_metadata.kv_caches_base_addr
remote_conv_addr, remote_ssm_addr = remote_layer_metadata.kv_caches_base_addr
local_conv_len, local_ssm_len = local_layer_metadata.block_len
@@ -281,8 +291,8 @@ class KVCacheSendingLayerThread(threading.Thread):
if tp_ratio == 1:
src_list.extend(
[
local_conv_addr + local_block_ids[0] * local_conv_len,
local_ssm_addr + local_block_ids[0] * local_ssm_len,
local_conv_addr + local_block_ids[transfer_block_idx] * local_conv_len,
local_ssm_addr + local_block_ids[transfer_block_idx] * local_ssm_len,
]
)
dst_list.extend(
@@ -320,12 +330,14 @@ class KVCacheSendingLayerThread(threading.Thread):
(i * conv_shape[1] + local_conv_offset) * tp_ratio
+ (self.tp_rank % tp_ratio) * local_conv_size
) * get_dtype_size(conv_dtype)
src_list.append(local_conv_addr + local_block_ids[0] * local_conv_len + local_addr_offset)
src_list.append(
local_conv_addr + local_block_ids[transfer_block_idx] * local_conv_len + local_addr_offset
)
dst_list.append(remote_conv_addr + remote_block_ids[0] * remote_conv_len + remote_addr_offset)
length_list.append(local_conv_size * get_dtype_size(conv_dtype))
# ssm
remote_addr_offset = (self.tp_rank % tp_ratio) * math.prod(ssm_shape) * get_dtype_size(ssm_dtype)
src_list.append(local_ssm_addr + local_block_ids[0] * local_ssm_len)
src_list.append(local_ssm_addr + local_block_ids[transfer_block_idx] * local_ssm_len)
dst_list.append(remote_ssm_addr + remote_block_ids[0] * remote_ssm_len + remote_addr_offset)
length_list.append(local_ssm_len)
else: