[Disaggregated Prefill] P2P Disaggregated Prefill based on llm_datadist (#694)
### What this PR does / why we need it? - This PR proposes a P2P version of Disaggregated Prefill based on llm_datadist which manages data transfer. - This solution reconstructs previous offline single-node Disaggregated Prefill solution, and supports multi-node and online serveing now. - Currently this solution supports 1P1D situation of Deepseek hybrid parallelism (P: TP+EP, D: DP+EP). Note that xPyD situation is considered in the solution design, and will be supported soon within v1 engine. --------- Signed-off-by: hw_whx <wanghexiang7@huawei.com> Signed-off-by: ganyi <pleaplusone.gy@gmail.com> Co-authored-by: hw_whx <wanghexiang7@huawei.com> Co-authored-by: ganyi <pleaplusone.gy@gmail.com>
This commit is contained in:
@@ -33,7 +33,7 @@ from vllm.attention import AttentionMetadata, get_attn_backend
|
||||
from vllm.attention.backends.utils import CommonAttentionState
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.core.scheduler import SchedulerOutputs
|
||||
from vllm.distributed import get_pp_group
|
||||
from vllm.distributed import get_dp_group, get_pp_group
|
||||
from vllm.distributed.kv_transfer import get_kv_transfer_group
|
||||
from vllm.forward_context import set_forward_context
|
||||
from vllm.inputs import INPUT_REGISTRY, InputRegistry
|
||||
@@ -1343,6 +1343,17 @@ class NPUModelRunner(NPUModelRunnerBase[ModelInputForNPUWithSamplingMetadata]):
|
||||
kv_caches=kv_caches
|
||||
)
|
||||
|
||||
if get_dp_group().world_size > 1:
|
||||
bypass_model_exec_tensor = torch.tensor(
|
||||
1, dtype=torch.int32) if bypass_model_exec else torch.tensor(
|
||||
0, dtype=torch.int32)
|
||||
torch.distributed.all_reduce(bypass_model_exec_tensor,
|
||||
op=torch.distributed.ReduceOp.MIN,
|
||||
group=get_dp_group().cpu_group)
|
||||
# If there is any group have not receive the necessary hidden states or kv_cache, we force all the dp group execute.
|
||||
if bypass_model_exec_tensor.item() == 0:
|
||||
bypass_model_exec = False
|
||||
|
||||
multi_modal_kwargs = model_input.multi_modal_kwargs or {}
|
||||
seqlen_agnostic_kwargs = {
|
||||
"finished_requests_ids": model_input.finished_requests_ids,
|
||||
@@ -1399,10 +1410,21 @@ class NPUModelRunner(NPUModelRunnerBase[ModelInputForNPUWithSamplingMetadata]):
|
||||
torch.tensor(model_forward_time +
|
||||
orig_model_forward_time))
|
||||
return hidden_or_intermediate_states
|
||||
# TODO: remove the synchronize here
|
||||
torch.npu.synchronize()
|
||||
logits = self.model.compute_logits(hidden_or_intermediate_states,
|
||||
model_input.sampling_metadata)
|
||||
|
||||
logits = self.model.compute_logits(hidden_or_intermediate_states,
|
||||
model_input.sampling_metadata)
|
||||
|
||||
# Sending KV cache in distributed KV cache transfer setting
|
||||
if self.need_send_kv(model_input, kv_caches):
|
||||
get_kv_transfer_group().send_kv_caches_and_hidden_states(
|
||||
# model_executable is used to know which layer the current
|
||||
# worker is working on, so that we can send KV for only those
|
||||
# layers.
|
||||
model_executable,
|
||||
model_input,
|
||||
kv_caches,
|
||||
hidden_or_intermediate_states,
|
||||
)
|
||||
|
||||
if not self.is_driver_worker:
|
||||
return []
|
||||
|
||||
@@ -18,10 +18,13 @@
|
||||
#
|
||||
|
||||
import gc
|
||||
import os
|
||||
from typing import Dict, List, Optional, Set, Tuple, Type, Union
|
||||
|
||||
import msgpack # type: ignore
|
||||
import torch
|
||||
import torch.distributed
|
||||
import zmq
|
||||
from torch import nn
|
||||
from vllm import envs
|
||||
from vllm.config import VllmConfig, set_current_vllm_config
|
||||
@@ -37,7 +40,7 @@ from vllm.model_executor.model_loader.tensorizer import TensorizerConfig
|
||||
from vllm.prompt_adapter.request import PromptAdapterRequest
|
||||
from vllm.sequence import (ExecuteModelRequest, IntermediateTensors,
|
||||
SequenceGroupMetadata, SequenceGroupMetadataDelta)
|
||||
from vllm.utils import GiB_bytes, bind_kv_cache
|
||||
from vllm.utils import GiB_bytes, bind_kv_cache, get_ip
|
||||
from vllm.worker.cache_engine import CacheEngine
|
||||
from vllm.worker.enc_dec_model_runner import EncoderDecoderModelRunner
|
||||
from vllm.worker.model_runner_base import ModelRunnerBase
|
||||
@@ -157,6 +160,33 @@ class NPUWorker(LocalOrDistributedWorkerBase):
|
||||
else:
|
||||
self.profiler = None
|
||||
|
||||
self.enable_dummy_run = False
|
||||
if os.getenv("VLLM_DP_PROXY_IP", None):
|
||||
logger.warning("enable dummy run for the DP")
|
||||
self.enable_dummy_run = True
|
||||
# dp_rank = os.environ["VLLM_DP_RANK"]
|
||||
dp_master_ip = os.environ["VLLM_DP_PROXY_IP"]
|
||||
dp_proxy_listener_port = os.environ["VLLM_DP_PROXY_PORT"]
|
||||
dp_proxy_monitor_port = os.environ["VLLM_DP_MONITOR_PORT"]
|
||||
dp_proxy_listener_addr = f"{dp_master_ip}:{dp_proxy_listener_port}"
|
||||
self.dp_proxy_monitor_addr = f"{dp_master_ip}:{dp_proxy_monitor_port}"
|
||||
http_ip = get_ip()
|
||||
port = os.environ["VLLM_HTTP_PORT"]
|
||||
self.http_addr = f"{http_ip}:{port}"
|
||||
context = zmq.Context() # type: ignore
|
||||
sock = context.socket(zmq.DEALER) # type: ignore
|
||||
|
||||
logger.debug("ping dp proxy start, DP_RANK:%s", 0)
|
||||
# logger.debug("ping dp proxy start, DP_RANK:%s", dp_rank)
|
||||
|
||||
sock.connect(f"tcp://{dp_proxy_listener_addr}")
|
||||
data = {"type": "DP", "http_address": self.http_addr}
|
||||
for _ in range(10):
|
||||
sock.send(msgpack.dumps(data))
|
||||
|
||||
self.notify_socket = context.socket(zmq.PUSH) # type: ignore
|
||||
self.notify_socket.connect(f"tcp://{self.dp_proxy_monitor_addr}")
|
||||
|
||||
def sleep(self, level: int = 1) -> None:
|
||||
NPUPlatform.set_device(self.device)
|
||||
free_bytes_before_sleep = NPUPlatform.mem_get_info()[0]
|
||||
@@ -375,6 +405,11 @@ class NPUWorker(LocalOrDistributedWorkerBase):
|
||||
|
||||
@torch.inference_mode()
|
||||
def execute_worker(self, worker_input: WorkerInput) -> None:
|
||||
if self.enable_dummy_run:
|
||||
logger.debug(
|
||||
f"send notify to the dp proxy: {self.dp_proxy_monitor_addr}")
|
||||
data = {"info": "notify_step", "http_address": self.http_addr}
|
||||
self.notify_socket.send(msgpack.dumps(data))
|
||||
virtual_engine = worker_input.virtual_engine
|
||||
# Issue cache operations.
|
||||
if (worker_input.blocks_to_swap_in is not None
|
||||
|
||||
Reference in New Issue
Block a user