[Disaggregated Prefill] P2P Disaggregated Prefill based on llm_datadist (#694)

### What this PR does / why we need it?
- This PR proposes a P2P version of Disaggregated Prefill based on
llm_datadist which manages data transfer.

- This solution reconstructs previous offline single-node Disaggregated
Prefill solution, and supports multi-node and online serveing now.

- Currently this solution supports 1P1D situation of Deepseek hybrid
parallelism (P: TP+EP, D: DP+EP). Note that xPyD situation is considered
in the solution design, and will be supported soon within v1 engine.

---------

Signed-off-by: hw_whx <wanghexiang7@huawei.com>
Signed-off-by: ganyi <pleaplusone.gy@gmail.com>
Co-authored-by: hw_whx <wanghexiang7@huawei.com>
Co-authored-by: ganyi <pleaplusone.gy@gmail.com>
This commit is contained in:
whx
2025-05-01 22:31:36 +08:00
committed by GitHub
parent 84e2ed898b
commit 8b194ad12e
18 changed files with 1769 additions and 32 deletions

View File

@@ -18,10 +18,13 @@
#
import gc
import os
from typing import Dict, List, Optional, Set, Tuple, Type, Union
import msgpack # type: ignore
import torch
import torch.distributed
import zmq
from torch import nn
from vllm import envs
from vllm.config import VllmConfig, set_current_vllm_config
@@ -37,7 +40,7 @@ from vllm.model_executor.model_loader.tensorizer import TensorizerConfig
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sequence import (ExecuteModelRequest, IntermediateTensors,
SequenceGroupMetadata, SequenceGroupMetadataDelta)
from vllm.utils import GiB_bytes, bind_kv_cache
from vllm.utils import GiB_bytes, bind_kv_cache, get_ip
from vllm.worker.cache_engine import CacheEngine
from vllm.worker.enc_dec_model_runner import EncoderDecoderModelRunner
from vllm.worker.model_runner_base import ModelRunnerBase
@@ -157,6 +160,33 @@ class NPUWorker(LocalOrDistributedWorkerBase):
else:
self.profiler = None
self.enable_dummy_run = False
if os.getenv("VLLM_DP_PROXY_IP", None):
logger.warning("enable dummy run for the DP")
self.enable_dummy_run = True
# dp_rank = os.environ["VLLM_DP_RANK"]
dp_master_ip = os.environ["VLLM_DP_PROXY_IP"]
dp_proxy_listener_port = os.environ["VLLM_DP_PROXY_PORT"]
dp_proxy_monitor_port = os.environ["VLLM_DP_MONITOR_PORT"]
dp_proxy_listener_addr = f"{dp_master_ip}:{dp_proxy_listener_port}"
self.dp_proxy_monitor_addr = f"{dp_master_ip}:{dp_proxy_monitor_port}"
http_ip = get_ip()
port = os.environ["VLLM_HTTP_PORT"]
self.http_addr = f"{http_ip}:{port}"
context = zmq.Context() # type: ignore
sock = context.socket(zmq.DEALER) # type: ignore
logger.debug("ping dp proxy start, DP_RANK:%s", 0)
# logger.debug("ping dp proxy start, DP_RANK:%s", dp_rank)
sock.connect(f"tcp://{dp_proxy_listener_addr}")
data = {"type": "DP", "http_address": self.http_addr}
for _ in range(10):
sock.send(msgpack.dumps(data))
self.notify_socket = context.socket(zmq.PUSH) # type: ignore
self.notify_socket.connect(f"tcp://{self.dp_proxy_monitor_addr}")
def sleep(self, level: int = 1) -> None:
NPUPlatform.set_device(self.device)
free_bytes_before_sleep = NPUPlatform.mem_get_info()[0]
@@ -375,6 +405,11 @@ class NPUWorker(LocalOrDistributedWorkerBase):
@torch.inference_mode()
def execute_worker(self, worker_input: WorkerInput) -> None:
if self.enable_dummy_run:
logger.debug(
f"send notify to the dp proxy: {self.dp_proxy_monitor_addr}")
data = {"info": "notify_step", "http_address": self.http_addr}
self.notify_socket.send(msgpack.dumps(data))
virtual_engine = worker_input.virtual_engine
# Issue cache operations.
if (worker_input.blocks_to_swap_in is not None