Disaggregate prefill for kv cache register style (#950)

### What this PR does / why we need it?
This PR adopt `LLMDataDist` for kv cache register and `pull_blocks`
style disaggregate prefill implementation. The interface implementation
mainly follows the design of NIXL PR
https://github.com/vllm-project/vllm/pull/17751/files#diff-7eaad0b7dee0626bf29d10081b0f0c5e3ea15a4af97e7b182a4e0d35f8346953
.

This PR can be test with the following step:
- Generate the rank table for all machine.
- execute`toy_proxy.py` to launch the disaggregate prefill proxy server,
specify the prefill ip, port and the decode ip, port
- Run the prefill server and decode server.
- send the request to the disaggregate prefill proxy

### Does this PR introduce _any_ user-facing change?

### How was this patch tested?


- vLLM version: v0.9.2
- vLLM main:
8d0a01a5f2

---------

Signed-off-by: ganyi <pleaplusone.gy@gmail.com>
Signed-off-by: machenglong <machenglong_yewu@cmss.chinamobile.com>
Signed-off-by: liziyu179 <3475441767@qq.com>
Signed-off-by: underfitc <hucong24@huawei.com>
Signed-off-by: zouyida2052 <zouyida@huawei.com>
Signed-off-by: liziyu <liziyu16@huawei.com>
Signed-off-by: underfituu <hzhucong@163.com>
Co-authored-by: machenglong <machenglong_yewu@cmss.chinamobile.com>
Co-authored-by: liziyu179 <3475441767@qq.com>
Co-authored-by: underfitc <hucong24@huawei.com>
Co-authored-by: zouyida2052 <zouyida@huawei.com>
Co-authored-by: liziyu <liziyu16@huawei.com>
Co-authored-by: underfituu <hzhucong@163.com>
This commit is contained in:
Pleaplusone
2025-07-26 17:15:47 +08:00
committed by GitHub
parent 17a430f7b8
commit df0ec55162
28 changed files with 2833 additions and 144 deletions

View File

@@ -32,7 +32,8 @@ import torch_npu
from torch import nn
from transformers import PretrainedConfig
from vllm.attention import Attention, AttentionMetadata
from vllm.config import CacheConfig, ModelConfig, VllmConfig
from vllm.config import (CacheConfig, ModelConfig, VllmConfig,
get_current_vllm_config)
from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
get_tp_group, split_tensor_along_last_dim,
@@ -363,6 +364,10 @@ class CustomDeepseekV2MoE(nn.Module):
self.tp_group = get_tp_group().device_group
self.tp_rank = get_tp_group().rank_in_group
self.ep_group = get_ep_group()
self.kv_consumer = None
transfer_config = get_current_vllm_config().kv_transfer_config
if transfer_config is not None:
self.kv_consumer = transfer_config.kv_role == "kv_consumer"
self.params_dtype = torch.get_default_dtype()
self.rm_router_logits = self.experts.rm_router_logits
@@ -386,6 +391,11 @@ class CustomDeepseekV2MoE(nn.Module):
enable_force_load_balance = False
if hasattr(attn_metadata, 'with_prefill_across_dp'):
is_prefill = is_prefill or attn_metadata.with_prefill_across_dp
# If this node is kv_consumer, we force the moe always runs in decode path to make sure
# the behaviour aligned between dummy_run and normal model_execute.
if self.kv_consumer:
is_prefill = False
enable_force_load_balance = False
# router_logits: (num_tokens, n_experts)
router_logits = None