[Disaggregated Prefill] P2P Disaggregated Prefill based on llm_datadist (#694)

### What this PR does / why we need it?
- This PR proposes a P2P version of Disaggregated Prefill based on
llm_datadist which manages data transfer.

- This solution reconstructs previous offline single-node Disaggregated
Prefill solution, and supports multi-node and online serveing now.

- Currently this solution supports 1P1D situation of Deepseek hybrid
parallelism (P: TP+EP, D: DP+EP). Note that xPyD situation is considered
in the solution design, and will be supported soon within v1 engine.

---------

Signed-off-by: hw_whx <wanghexiang7@huawei.com>
Signed-off-by: ganyi <pleaplusone.gy@gmail.com>
Co-authored-by: hw_whx <wanghexiang7@huawei.com>
Co-authored-by: ganyi <pleaplusone.gy@gmail.com>
This commit is contained in:
whx
2025-05-01 22:31:36 +08:00
committed by GitHub
parent 84e2ed898b
commit 8b194ad12e
18 changed files with 1769 additions and 32 deletions

View File

@@ -33,7 +33,7 @@ from vllm.attention import AttentionMetadata, get_attn_backend
from vllm.attention.backends.utils import CommonAttentionState
from vllm.config import VllmConfig
from vllm.core.scheduler import SchedulerOutputs
from vllm.distributed import get_pp_group
from vllm.distributed import get_dp_group, get_pp_group
from vllm.distributed.kv_transfer import get_kv_transfer_group
from vllm.forward_context import set_forward_context
from vllm.inputs import INPUT_REGISTRY, InputRegistry
@@ -1343,6 +1343,17 @@ class NPUModelRunner(NPUModelRunnerBase[ModelInputForNPUWithSamplingMetadata]):
kv_caches=kv_caches
)
if get_dp_group().world_size > 1:
bypass_model_exec_tensor = torch.tensor(
1, dtype=torch.int32) if bypass_model_exec else torch.tensor(
0, dtype=torch.int32)
torch.distributed.all_reduce(bypass_model_exec_tensor,
op=torch.distributed.ReduceOp.MIN,
group=get_dp_group().cpu_group)
# If there is any group have not receive the necessary hidden states or kv_cache, we force all the dp group execute.
if bypass_model_exec_tensor.item() == 0:
bypass_model_exec = False
multi_modal_kwargs = model_input.multi_modal_kwargs or {}
seqlen_agnostic_kwargs = {
"finished_requests_ids": model_input.finished_requests_ids,
@@ -1399,10 +1410,21 @@ class NPUModelRunner(NPUModelRunnerBase[ModelInputForNPUWithSamplingMetadata]):
torch.tensor(model_forward_time +
orig_model_forward_time))
return hidden_or_intermediate_states
# TODO: remove the synchronize here
torch.npu.synchronize()
logits = self.model.compute_logits(hidden_or_intermediate_states,
model_input.sampling_metadata)
logits = self.model.compute_logits(hidden_or_intermediate_states,
model_input.sampling_metadata)
# Sending KV cache in distributed KV cache transfer setting
if self.need_send_kv(model_input, kv_caches):
get_kv_transfer_group().send_kv_caches_and_hidden_states(
# model_executable is used to know which layer the current
# worker is working on, so that we can send KV for only those
# layers.
model_executable,
model_input,
kv_caches,
hidden_or_intermediate_states,
)
if not self.is_driver_worker:
return []