[PD] Support KV transfer with mooncake (#4880)

Signed-off-by: Shangming Cai <caishangming@linux.alibaba.com>
Co-authored-by: Shangming Cai <caishangming@linux.alibaba.com>
Co-authored-by: Xuchun Shang <xuchun.shang@linux.alibaba.com>
Co-authored-by: shangmingc <csmthu@gmail.com>
This commit is contained in:
Teng Ma
2025-04-10 14:23:23 +08:00
committed by GitHub
parent f730362ee2
commit 4c31ae9f6d
8 changed files with 571 additions and 30 deletions

View File

@@ -95,6 +95,10 @@ class GenerateReqInput:
# Whether to return hidden states
return_hidden_states: bool = False
# For disaggregated inference
bootstrap_host: Optional[str] = None
bootstrap_room: Optional[int] = None
def normalize_batch_and_arguments(self):
"""
Normalize the batch size and arguments for the request.
@@ -435,6 +439,10 @@ class TokenizedGenerateReqInput:
# Whether to return hidden states
return_hidden_states: bool = False
# For disaggregated inference
bootstrap_host: Optional[str] = None
bootstrap_room: Optional[int] = None
@dataclass
class EmbeddingReqInput:

View File

@@ -390,6 +390,8 @@ class Req:
custom_logit_processor: Optional[str] = None,
return_hidden_states: bool = False,
eos_token_ids: Optional[Set[int]] = None,
bootstrap_host: Optional[str] = None,
bootstrap_room: Optional[int] = None,
):
# Input and output info
self.rid = rid
@@ -523,8 +525,8 @@ class Req:
self.lora_path = lora_path
# For disaggregation
self.bootstrap_host: str = "0.0.0.0"
self.bootstrap_room: Optional[int] = None
self.bootstrap_host: str = bootstrap_host
self.bootstrap_room: Optional[int] = bootstrap_room
self.disagg_kv_sender: Optional[KVSender] = None
# used for warmup because we don't have a pair yet when init

View File

@@ -836,6 +836,8 @@ class Scheduler(
custom_logit_processor=custom_logit_processor,
return_hidden_states=recv_req.return_hidden_states,
eos_token_ids=self.model_config.hf_eos_token_id,
bootstrap_host=recv_req.bootstrap_host,
bootstrap_room=recv_req.bootstrap_room,
)
req.tokenizer = self.tokenizer
req.queue_time_start = time.time()

View File

@@ -452,6 +452,8 @@ class TokenizerManager:
top_logprobs_num,
token_ids_logprob,
obj.stream,
bootstrap_host=obj.bootstrap_host,
bootstrap_room=obj.bootstrap_room,
lora_path=obj.lora_path,
input_embeds=input_embeds,
session_params=session_params,