[PD] Support PD disaggregation with Prefill PP (#8846)
Signed-off-by: Shangming Cai <caishangming@linux.alibaba.com> Signed-off-by: Shangming Cai <csmthu@gmail.com> Co-authored-by: root <huzhiyuan@xiaohongshu.com> Co-authored-by: Ying Sheng <sqy1415@gmail.com> Co-authored-by: Francis <38564764+ssssnow@users.noreply.github.com> Co-authored-by: zitto <zhjc1124@gmail.com>
This commit is contained in:
@@ -2579,7 +2579,10 @@ def run_scheduler_process(
|
||||
if scheduler.enable_overlap:
|
||||
scheduler.event_loop_overlap_disagg_prefill()
|
||||
else:
|
||||
scheduler.event_loop_normal_disagg_prefill()
|
||||
if server_args.pp_size > 1:
|
||||
scheduler.event_loop_pp_disagg_prefill()
|
||||
else:
|
||||
scheduler.event_loop_normal_disagg_prefill()
|
||||
|
||||
elif disaggregation_mode == DisaggregationMode.DECODE:
|
||||
if scheduler.enable_overlap:
|
||||
|
||||
@@ -1,9 +1,16 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import multiprocessing as mp
|
||||
from http import HTTPStatus
|
||||
from typing import Dict, List, Optional
|
||||
from typing import TYPE_CHECKING, Dict, List, Optional
|
||||
|
||||
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
|
||||
from sglang.srt.managers.schedule_batch import FINISH_ABORT, Req
|
||||
from sglang.srt.model_executor.forward_batch_info import PPProxyTensors
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from sglang.srt.managers.scheduler import GenerationBatchResult
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -41,6 +48,57 @@ def validate_input_length(
|
||||
return None
|
||||
|
||||
|
||||
def get_logprob_dict_from_result(result: GenerationBatchResult) -> dict:
    """Flatten the logprob-related fields of a batch result into a plain dict.

    The two ``extend_*_per_req`` entries are read from ``result`` itself; every
    other entry is read from ``result.logits_output`` and stored under the
    same name as the attribute it came from.

    Args:
        result: Batch result to read from. Its ``logits_output`` attribute
            must not be ``None``.

    Returns:
        A dict holding the per-request extend lengths plus the next-token and
        input-token logprob fields, in the order listed below.

    Raises:
        AssertionError: If ``result.logits_output`` is ``None``.
    """
    logits_output = result.logits_output
    assert logits_output is not None, "result.logits_output must not be None"

    # Read every field from the local that was just checked, so the None guard
    # above covers exactly the object the values are taken from.
    return {
        "extend_input_len_per_req": result.extend_input_len_per_req,
        "extend_logprob_start_len_per_req": result.extend_logprob_start_len_per_req,
        "next_token_logprobs": logits_output.next_token_logprobs,
        "next_token_top_logprobs_val": logits_output.next_token_top_logprobs_val,
        "next_token_top_logprobs_idx": logits_output.next_token_top_logprobs_idx,
        "next_token_token_ids_logprobs_val": logits_output.next_token_token_ids_logprobs_val,
        "next_token_token_ids_logprobs_idx": logits_output.next_token_token_ids_logprobs_idx,
        "input_token_logprobs": logits_output.input_token_logprobs,
        "input_top_logprobs_val": logits_output.input_top_logprobs_val,
        "input_top_logprobs_idx": logits_output.input_top_logprobs_idx,
        "input_token_ids_logprobs_val": logits_output.input_token_ids_logprobs_val,
        "input_token_ids_logprobs_idx": logits_output.input_token_ids_logprobs_idx,
    }
|
||||
|
||||
|
||||
def get_logprob_from_pp_outputs(
    next_pp_outputs: PPProxyTensors,
) -> tuple[LogitsProcessorOutput, list[int], list[int]]:
    """Rebuild logprob results from the proxy tensors of a pipeline stage.

    Args:
        next_pp_outputs: Container indexed by string key that carries the
            logprob fields and the two per-request extend-length entries.

    Returns:
        A 3-tuple of the reconstructed ``LogitsProcessorOutput``, the value
        stored under ``"extend_input_len_per_req"`` and the value stored under
        ``"extend_logprob_start_len_per_req"``.
    """
    # Fields copied verbatim from the proxy tensors; the constructor keyword
    # and the lookup key are the same string for each of them.
    logprob_fields = (
        "next_token_logprobs",
        "next_token_top_logprobs_val",
        "next_token_top_logprobs_idx",
        "next_token_token_ids_logprobs_val",
        "next_token_token_ids_logprobs_idx",
        "input_token_logprobs",
        "input_top_logprobs_val",
        "input_top_logprobs_idx",
        "input_token_ids_logprobs_val",
        "input_token_ids_logprobs_idx",
    )
    copied = {field: next_pp_outputs[field] for field in logprob_fields}

    # Logits and hidden states are deliberately left as None: they are large,
    # so they are not sent along with the logprob data.
    rebuilt = LogitsProcessorOutput(
        next_token_logits=None,
        hidden_states=None,
        **copied,
    )

    input_lens = next_pp_outputs["extend_input_len_per_req"]
    logprob_start_lens = next_pp_outputs["extend_logprob_start_len_per_req"]
    return rebuilt, input_lens, logprob_start_lens
|
||||
|
||||
|
||||
class DPBalanceMeta:
|
||||
"""
|
||||
This class will be used in scheduler and dp controller
|
||||
|
||||
Reference in New Issue
Block a user