[PD] Support PD disaggregation with Prefill PP (#8846)

Signed-off-by: Shangming Cai <caishangming@linux.alibaba.com>
Signed-off-by: Shangming Cai <csmthu@gmail.com>
Co-authored-by: root <huzhiyuan@xiaohongshu.com>
Co-authored-by: Ying Sheng <sqy1415@gmail.com>
Co-authored-by: Francis <38564764+ssssnow@users.noreply.github.com>
Co-authored-by: zitto <zhjc1124@gmail.com>
This commit is contained in:
Shangming Cai
2025-08-17 09:31:31 +08:00
committed by GitHub
parent 6a9d6ca33c
commit 384f8ab5ce
11 changed files with 632 additions and 82 deletions

View File

@@ -2579,7 +2579,10 @@ def run_scheduler_process(
if scheduler.enable_overlap:
scheduler.event_loop_overlap_disagg_prefill()
else:
scheduler.event_loop_normal_disagg_prefill()
if server_args.pp_size > 1:
scheduler.event_loop_pp_disagg_prefill()
else:
scheduler.event_loop_normal_disagg_prefill()
elif disaggregation_mode == DisaggregationMode.DECODE:
if scheduler.enable_overlap:

View File

@@ -1,9 +1,16 @@
from __future__ import annotations
import logging
import multiprocessing as mp
from http import HTTPStatus
from typing import Dict, List, Optional
from typing import TYPE_CHECKING, Dict, List, Optional
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
from sglang.srt.managers.schedule_batch import FINISH_ABORT, Req
from sglang.srt.model_executor.forward_batch_info import PPProxyTensors
if TYPE_CHECKING:
from sglang.srt.managers.scheduler import GenerationBatchResult
logger = logging.getLogger(__name__)
@@ -41,6 +48,57 @@ def validate_input_length(
return None
def get_logprob_dict_from_result(result: GenerationBatchResult) -> dict:
logits_output = result.logits_output
assert logits_output is not None
return {
"extend_input_len_per_req": result.extend_input_len_per_req,
"extend_logprob_start_len_per_req": result.extend_logprob_start_len_per_req,
"next_token_logprobs": result.logits_output.next_token_logprobs,
"next_token_top_logprobs_val": result.logits_output.next_token_top_logprobs_val,
"next_token_top_logprobs_idx": result.logits_output.next_token_top_logprobs_idx,
"next_token_token_ids_logprobs_val": result.logits_output.next_token_token_ids_logprobs_val,
"next_token_token_ids_logprobs_idx": result.logits_output.next_token_token_ids_logprobs_idx,
"input_token_logprobs": result.logits_output.input_token_logprobs,
"input_top_logprobs_val": result.logits_output.input_top_logprobs_val,
"input_top_logprobs_idx": result.logits_output.input_top_logprobs_idx,
"input_token_ids_logprobs_val": result.logits_output.input_token_ids_logprobs_val,
"input_token_ids_logprobs_idx": result.logits_output.input_token_ids_logprobs_idx,
}
def get_logprob_from_pp_outputs(
    next_pp_outputs: PPProxyTensors,
) -> tuple[LogitsProcessorOutput, list[int], list[int]]:
    """Rebuild a ``LogitsProcessorOutput`` from proxy tensors sent by the
    previous pipeline-parallel stage.

    Inverse of ``get_logprob_dict_from_result``: the large fields
    (``next_token_logits``, ``hidden_states``) were never transmitted and are
    restored as ``None``.

    Args:
        next_pp_outputs: Proxy-tensor mapping produced by the upstream stage.

    Returns:
        A ``(logits_output, extend_input_len_per_req,
        extend_logprob_start_len_per_req)`` tuple.
    """
    logprob_fields = (
        "next_token_logprobs",
        "next_token_top_logprobs_val",
        "next_token_top_logprobs_idx",
        "next_token_token_ids_logprobs_val",
        "next_token_token_ids_logprobs_idx",
        "input_token_logprobs",
        "input_top_logprobs_val",
        "input_top_logprobs_idx",
        "input_token_ids_logprobs_val",
        "input_token_ids_logprobs_idx",
    )
    logits_output = LogitsProcessorOutput(
        # Logits and hidden states are too large to ship between stages.
        next_token_logits=None,
        hidden_states=None,
        **{field: next_pp_outputs[field] for field in logprob_fields},
    )
    return (
        logits_output,
        next_pp_outputs["extend_input_len_per_req"],
        next_pp_outputs["extend_logprob_start_len_per_req"],
    )
class DPBalanceMeta:
"""
This class will be use in scheduler and dp controller