[v0.18.0][P/D][Feature]Layerwise connector supports Mamba prefill prefix caching (#7796)
### What this PR does / why we need it? The Mooncake layerwise connector now supports Mamba prefix caching on prefill nodes. ### Does this PR introduce _any_ user-facing change? Yes. Use `--enable-prefix-caching` and `--mamba-cache-mode align` to enable Mamba align-mode prefix caching on P/D prefill nodes. This feature is not yet supported on decode nodes. ### How was this patch tested? With a P/D E2E test. --------- Signed-off-by: nwpu-zxr <zhouxuerong2@huawei.com>
This commit is contained in:
@@ -226,6 +226,12 @@ class KVCacheSendingLayerThread(threading.Thread):
|
||||
self.kv_cache_config = kv_cache_config
|
||||
self.kv_cache_specs = kv_cache_specs
|
||||
self.attn_resharding_group_idx = attn_resharding_group_idx
|
||||
self.mamba_cache_mode = self.vllm_config.cache_config.mamba_cache_mode
|
||||
self.num_speculative_tokens = (
|
||||
self.vllm_config.speculative_config.num_speculative_tokens
|
||||
if self.vllm_config.speculative_config is not None
|
||||
else 0
|
||||
)
|
||||
self.tp_size = tp_size
|
||||
self.tp_rank = tp_rank
|
||||
self.pd_head_ratio = pd_head_ratio
|
||||
@@ -274,6 +280,10 @@ class KVCacheSendingLayerThread(threading.Thread):
|
||||
|
||||
if isinstance(layer_kv_cache_spec, MambaSpec):
|
||||
# only support one block transfer for mamba
|
||||
if self.mamba_cache_mode == "align":
|
||||
transfer_block_idx = len(local_block_ids) - self.num_speculative_tokens - 1
|
||||
else:
|
||||
transfer_block_idx = 0
|
||||
local_conv_addr, local_ssm_addr = local_layer_metadata.kv_caches_base_addr
|
||||
remote_conv_addr, remote_ssm_addr = remote_layer_metadata.kv_caches_base_addr
|
||||
local_conv_len, local_ssm_len = local_layer_metadata.block_len
|
||||
@@ -281,8 +291,8 @@ class KVCacheSendingLayerThread(threading.Thread):
|
||||
if tp_ratio == 1:
|
||||
src_list.extend(
|
||||
[
|
||||
local_conv_addr + local_block_ids[0] * local_conv_len,
|
||||
local_ssm_addr + local_block_ids[0] * local_ssm_len,
|
||||
local_conv_addr + local_block_ids[transfer_block_idx] * local_conv_len,
|
||||
local_ssm_addr + local_block_ids[transfer_block_idx] * local_ssm_len,
|
||||
]
|
||||
)
|
||||
dst_list.extend(
|
||||
@@ -320,12 +330,14 @@ class KVCacheSendingLayerThread(threading.Thread):
|
||||
(i * conv_shape[1] + local_conv_offset) * tp_ratio
|
||||
+ (self.tp_rank % tp_ratio) * local_conv_size
|
||||
) * get_dtype_size(conv_dtype)
|
||||
src_list.append(local_conv_addr + local_block_ids[0] * local_conv_len + local_addr_offset)
|
||||
src_list.append(
|
||||
local_conv_addr + local_block_ids[transfer_block_idx] * local_conv_len + local_addr_offset
|
||||
)
|
||||
dst_list.append(remote_conv_addr + remote_block_ids[0] * remote_conv_len + remote_addr_offset)
|
||||
length_list.append(local_conv_size * get_dtype_size(conv_dtype))
|
||||
# ssm
|
||||
remote_addr_offset = (self.tp_rank % tp_ratio) * math.prod(ssm_shape) * get_dtype_size(ssm_dtype)
|
||||
src_list.append(local_ssm_addr + local_block_ids[0] * local_ssm_len)
|
||||
src_list.append(local_ssm_addr + local_block_ids[transfer_block_idx] * local_ssm_len)
|
||||
dst_list.append(remote_ssm_addr + remote_block_ids[0] * remote_ssm_len + remote_addr_offset)
|
||||
length_list.append(local_ssm_len)
|
||||
else:
|
||||
|
||||
Reference in New Issue
Block a user