[PD] Support KV transfer with mooncake (#4880)
Signed-off-by: Shangming Cai <caishangming@linux.alibaba.com>
Co-authored-by: Shangming Cai <caishangming@linux.alibaba.com>
Co-authored-by: Xuchun Shang <xuchun.shang@linux.alibaba.com>
Co-authored-by: shangmingc <csmthu@gmail.com>
This commit is contained in:
@@ -95,6 +95,10 @@ class GenerateReqInput:
|
||||
# Whether to return hidden states
|
||||
return_hidden_states: bool = False
|
||||
|
||||
# For disaggregated inference
|
||||
bootstrap_host: Optional[str] = None
|
||||
bootstrap_room: Optional[int] = None
|
||||
|
||||
def normalize_batch_and_arguments(self):
|
||||
"""
|
||||
Normalize the batch size and arguments for the request.
|
||||
@@ -435,6 +439,10 @@ class TokenizedGenerateReqInput:
|
||||
# Whether to return hidden states
|
||||
return_hidden_states: bool = False
|
||||
|
||||
# For disaggregated inference
|
||||
bootstrap_host: Optional[str] = None
|
||||
bootstrap_room: Optional[int] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class EmbeddingReqInput:
|
||||
|
||||
@@ -390,6 +390,8 @@ class Req:
|
||||
custom_logit_processor: Optional[str] = None,
|
||||
return_hidden_states: bool = False,
|
||||
eos_token_ids: Optional[Set[int]] = None,
|
||||
bootstrap_host: Optional[str] = None,
|
||||
bootstrap_room: Optional[int] = None,
|
||||
):
|
||||
# Input and output info
|
||||
self.rid = rid
|
||||
@@ -523,8 +525,8 @@ class Req:
|
||||
self.lora_path = lora_path
|
||||
|
||||
# For disaggregation
|
||||
self.bootstrap_host: str = "0.0.0.0"
|
||||
self.bootstrap_room: Optional[int] = None
|
||||
self.bootstrap_host: str = bootstrap_host
|
||||
self.bootstrap_room: Optional[int] = bootstrap_room
|
||||
self.disagg_kv_sender: Optional[KVSender] = None
|
||||
|
||||
# used for warmup because we don't have a pair yet when init
|
||||
|
||||
@@ -836,6 +836,8 @@ class Scheduler(
|
||||
custom_logit_processor=custom_logit_processor,
|
||||
return_hidden_states=recv_req.return_hidden_states,
|
||||
eos_token_ids=self.model_config.hf_eos_token_id,
|
||||
bootstrap_host=recv_req.bootstrap_host,
|
||||
bootstrap_room=recv_req.bootstrap_room,
|
||||
)
|
||||
req.tokenizer = self.tokenizer
|
||||
req.queue_time_start = time.time()
|
||||
|
||||
@@ -452,6 +452,8 @@ class TokenizerManager:
|
||||
top_logprobs_num,
|
||||
token_ids_logprob,
|
||||
obj.stream,
|
||||
bootstrap_host=obj.bootstrap_host,
|
||||
bootstrap_room=obj.bootstrap_room,
|
||||
lora_path=obj.lora_path,
|
||||
input_embeds=input_embeds,
|
||||
session_params=session_params,
|
||||
|
||||
Reference in New Issue
Block a user