[PD]: Support Muti Prefill in one node (#5704)

Co-authored-by: shuaills <shishuaiuoe@gmail.com>
This commit is contained in:
IAN
2025-04-26 00:30:47 +08:00
committed by GitHub
parent 50eda8398e
commit 11e27d0926
6 changed files with 55 additions and 9 deletions

View File

@@ -97,6 +97,7 @@ class GenerateReqInput:
# For disaggregated inference
bootstrap_host: Optional[Union[List[str], str]] = None
bootstrap_port: Optional[Union[List[int], int]] = None
bootstrap_room: Optional[Union[List[int], int]] = None
def normalize_batch_and_arguments(self):
@@ -400,6 +401,9 @@ class GenerateReqInput:
bootstrap_host=(
self.bootstrap_host[i] if self.bootstrap_host is not None else None
),
bootstrap_port=(
self.bootstrap_port[i] if self.bootstrap_port is not None else None
),
bootstrap_room=(
self.bootstrap_room[i] if self.bootstrap_room is not None else None
),
@@ -447,6 +451,7 @@ class TokenizedGenerateReqInput:
# For disaggregated inference
bootstrap_host: Optional[str] = None
bootstrap_port: Optional[int] = None
bootstrap_room: Optional[int] = None

View File

@@ -391,6 +391,7 @@ class Req:
return_hidden_states: bool = False,
eos_token_ids: Optional[Set[int]] = None,
bootstrap_host: Optional[str] = None,
bootstrap_port: Optional[int] = None,
bootstrap_room: Optional[int] = None,
):
# Input and output info
@@ -526,6 +527,7 @@ class Req:
# For disaggregation
self.bootstrap_host: str = bootstrap_host
self.bootstrap_port: Optional[int] = bootstrap_port
self.bootstrap_room: Optional[int] = bootstrap_room
self.disagg_kv_sender: Optional[BaseKVSender] = None

View File

@@ -791,6 +791,7 @@ class Scheduler(
return_hidden_states=recv_req.return_hidden_states,
eos_token_ids=self.model_config.hf_eos_token_id,
bootstrap_host=recv_req.bootstrap_host,
bootstrap_port=recv_req.bootstrap_port,
bootstrap_room=recv_req.bootstrap_room,
)
req.tokenizer = self.tokenizer

View File

@@ -498,6 +498,7 @@ class TokenizerManager:
token_ids_logprob,
obj.stream,
bootstrap_host=obj.bootstrap_host,
bootstrap_port=obj.bootstrap_port,
bootstrap_room=obj.bootstrap_room,
lora_path=obj.lora_path,
input_embeds=input_embeds,