[code move] move pp into a separate mixin (#11838)

This commit is contained in:
Lianmin Zheng
2025-10-20 18:46:56 -07:00
committed by GitHub
parent 1111030395
commit 01f14a7ad2
5 changed files with 425 additions and 422 deletions

View File

@@ -63,7 +63,6 @@ from sglang.srt.distributed import get_pp_group, get_world_group
from sglang.srt.environ import envs
from sglang.srt.eplb.expert_distribution import get_global_expert_distribution_recorder
from sglang.srt.layers.dp_attention import compute_dp_attention_world_info
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
from sglang.srt.layers.moe import initialize_moe_config
from sglang.srt.managers.io_struct import (
AbortReq,
@@ -114,7 +113,7 @@ from sglang.srt.managers.io_struct import (
UpdateWeightsFromTensorReqInput,
)
from sglang.srt.managers.mm_utils import init_embedding_cache
from sglang.srt.managers.overlap_utils import FutureIndices, FutureMap
from sglang.srt.managers.overlap_utils import FutureMap
from sglang.srt.managers.schedule_batch import (
FINISH_ABORT,
ModelWorkerBatch,
@@ -136,22 +135,21 @@ from sglang.srt.managers.scheduler_metrics_mixin import (
from sglang.srt.managers.scheduler_output_processor_mixin import (
SchedulerOutputProcessorMixin,
)
from sglang.srt.managers.scheduler_pp_mixin import SchedulerPPMixin
from sglang.srt.managers.scheduler_profiler_mixin import SchedulerProfilerMixin
from sglang.srt.managers.scheduler_recv_skipper import SchedulerRecvSkipper
from sglang.srt.managers.scheduler_update_weights_mixin import (
SchedulerUpdateWeightsMixin,
)
from sglang.srt.managers.session_controller import Session
from sglang.srt.managers.utils import validate_input_length
from sglang.srt.managers.utils import GenerationBatchResult, validate_input_length
from sglang.srt.mem_cache.chunk_cache import ChunkCache, SWAChunkCache
from sglang.srt.mem_cache.hiradix_cache import HiRadixCache
from sglang.srt.mem_cache.mamba_radix_cache import MambaRadixCache
from sglang.srt.mem_cache.radix_cache import RadixCache
from sglang.srt.mem_cache.swa_radix_cache import SWARadixCache
from sglang.srt.model_executor.forward_batch_info import PPProxyTensors
from sglang.srt.parser.reasoning_parser import ReasoningParser
from sglang.srt.server_args import PortArgs, ServerArgs, get_global_server_args
from sglang.srt.speculative.eagle_info import EagleDraftInput
from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
from sglang.srt.tracing.trace import (
process_tracing_init,
@@ -198,77 +196,6 @@ TEST_RETRACT = get_bool_env_var("SGLANG_TEST_RETRACT")
GRAMMAR_TIMEOUT = float(os.environ.get("SGLANG_GRAMMAR_TIMEOUT", 300))
@dataclass
class GenerationBatchResult:
"""Result of running one generation (decode/extend) batch on the model worker.

Carries the model outputs that the scheduler's output processor needs, plus
bookkeeping fields for pipeline parallelism, overlap scheduling, and
speculative decoding. All fields default to None/False because different
scheduling paths populate different subsets.
"""
# Logits/logprobs produced by the forward pass (None on non-last PP ranks).
logits_output: Optional[LogitsProcessorOutput] = None
# Hidden-state tensors relayed between pipeline-parallel stages.
pp_hidden_states_proxy_tensors: Optional[PPProxyTensors] = None
# Sampled token ids for this batch.
next_token_ids: Optional[torch.Tensor] = None
# NOTE(review): presumably the speculative-decoding acceptance count — confirm.
num_accepted_tokens: Optional[int] = None
# Whether this batch was (or can be) executed via a CUDA graph.
can_run_cuda_graph: bool = False
# For output processing
extend_input_len_per_req: Optional[List[int]] = None
extend_logprob_start_len_per_req: Optional[List[int]] = None
# For overlap scheduling
# CUDA event recorded after the async D2H copies in copy_to_cpu().
copy_done: Optional[torch.cuda.Event] = None
# Deferred sampling callback used by the overlap scheduler.
delay_sample_func: Optional[callable] = None
# Indices into the scheduler's FutureMap for not-yet-resolved tokens.
future_indices: Optional[FutureIndices] = None
# FIXME(lsyin): maybe move to a better place?
# sync path: forward stream -> output processor
accept_lens: Optional[torch.Tensor] = None
allocate_lens: Optional[torch.Tensor] = None
# relay path: forward stream -> next step forward
next_draft_input: Optional[EagleDraftInput] = None
def copy_to_cpu(self, return_logprob: bool = False):
"""Copy tensors to CPU in overlap scheduling.
Only the tensors which are needed for processing results are copied,
e.g., next_token_ids, logits outputs.

All copies are issued with non_blocking=True on the current stream and
then fenced with ``self.copy_done.record()``; callers must wait on
``copy_done`` before reading the CPU tensors.

Args:
    return_logprob: also copy the logprob-related tensors held in
        ``self.logits_output`` (assumed non-None in that case).
"""
if return_logprob:
if self.logits_output.next_token_logits is not None:
self.logits_output.next_token_logits = (
self.logits_output.next_token_logits.to("cpu", non_blocking=True)
)
if self.logits_output.input_token_logprobs is not None:
self.logits_output.input_token_logprobs = (
self.logits_output.input_token_logprobs.to("cpu", non_blocking=True)
)
if self.logits_output.hidden_states is not None:
self.logits_output.hidden_states = self.logits_output.hidden_states.to(
"cpu", non_blocking=True
)
# NOTE(review): assumes next_token_ids and copy_done are set on this path
# even though the fields are Optional — confirm callers guarantee this.
self.next_token_ids = self.next_token_ids.to("cpu", non_blocking=True)
if self.accept_lens is not None:
self.accept_lens = self.accept_lens.to("cpu", non_blocking=True)
if self.allocate_lens is not None:
self.allocate_lens = self.allocate_lens.to("cpu", non_blocking=True)
self.copy_done.record()
@classmethod
def from_pp_proxy(
cls, logits_output, next_pp_outputs: PPProxyTensors, can_run_cuda_graph
):
"""Alternate constructor: rebuild a result from tensors received from the
next pipeline-parallel stage (see event_loop_pp's recv_tensor_dict)."""
# TODO(lsyin): refactor PP and avoid using dict
proxy_dict = next_pp_outputs.tensors
return cls(
logits_output=logits_output,
# Proxy tensors are consumed here; the reconstructed result carries none.
pp_hidden_states_proxy_tensors=None,
next_token_ids=next_pp_outputs["next_token_ids"],
extend_input_len_per_req=proxy_dict.get("extend_input_len_per_req", None),
extend_logprob_start_len_per_req=proxy_dict.get(
"extend_logprob_start_len_per_req", None
),
can_run_cuda_graph=can_run_cuda_graph,
)
@dataclass
class EmbeddingBatchResult:
embeddings: torch.Tensor
@@ -281,6 +208,7 @@ class Scheduler(
SchedulerMetricsMixin,
SchedulerDisaggregationDecodeMixin,
SchedulerDisaggregationPrefillMixin,
SchedulerPPMixin,
):
"""A scheduler that manages a tensor parallel GPU worker."""
@@ -1058,128 +986,6 @@ class Scheduler(
self.launch_batch_sample_if_needed(batch_result)
self.last_batch = batch
@DynamicGradMode()
def event_loop_pp(self):
"""A non-overlap scheduler loop for pipeline parallelism.

Runs ``pp_size`` microbatches round-robin. Each iteration of the inner
loop: (1) receives new requests and schedules microbatch ``mb_id``,
(2) on the last PP rank, packs this step's outputs into PPProxyTensors
and sends them downstream (which wraps around to the post-processing
stage), (3) receives the previous round's outputs for the *next*
microbatch and runs result processing on them, (4) on non-last ranks,
forwards requests and hidden-state proxy tensors to the next stage.
"""
# Per-microbatch scheduling state, one slot per pipeline stage.
mbs = [None] * self.pp_size
last_mbs = [None] * self.pp_size
self.running_mbs = [
ScheduleBatch(reqs=[], batch_is_full=False) for _ in range(self.pp_size)
]
# Outputs produced in the previous round, to be relayed this round.
pp_outputs: Optional[PPProxyTensors] = None
while True:
server_is_idle = True
for mb_id in range(self.pp_size):
# Swap in the scheduling state of this microbatch before planning.
self.running_batch = self.running_mbs[mb_id]
self.last_batch = last_mbs[mb_id]
recv_reqs = self.recv_requests()
self.process_input_requests(recv_reqs)
mbs[mb_id] = self.get_next_batch_to_run()
self.running_mbs[mb_id] = self.running_batch
self.cur_batch = mbs[mb_id]
if self.cur_batch:
server_is_idle = False
result = self.run_batch(self.cur_batch)
# (last rank) send the outputs to the next step
if self.pp_group.is_last_rank:
if self.cur_batch:
next_token_ids = result.next_token_ids
if self.cur_batch.return_logprob:
# Flatten LogitsProcessorOutput fields into the proxy dict
# under "logits_output.<field>" keys so they can be sent
# as a tensor dict and reassembled on the receiving side.
pp_outputs = PPProxyTensors(
{
"next_token_ids": next_token_ids,
"extend_input_len_per_req": result.extend_input_len_per_req,
"extend_logprob_start_len_per_req": result.extend_logprob_start_len_per_req,
}
| (
{
f"logits_output.{k}": v
for k, v in result.logits_output.__dict__.items()
}
if result.logits_output is not None
else {}
)
)
else:
pp_outputs = PPProxyTensors(
{
"next_token_ids": next_token_ids,
}
)
# send the output from the last round to let the next stage worker run post processing
self.pp_group.send_tensor_dict(
pp_outputs.tensors,
all_gather_group=self.attn_tp_group,
)
# receive outputs and post-process (filter finished reqs) the coming microbatch
next_mb_id = (mb_id + 1) % self.pp_size
next_pp_outputs = None
if mbs[next_mb_id] is not None:
next_pp_outputs: Optional[PPProxyTensors] = PPProxyTensors(
self.pp_group.recv_tensor_dict(
all_gather_group=self.attn_tp_group
)
)
mbs[next_mb_id].output_ids = next_pp_outputs["next_token_ids"]
# Reassemble the LogitsProcessorOutput that the last rank
# flattened into "logits_output.*" keys (empty when the batch
# did not request logprobs).
logits_output_args = {
k[len("logits_output.") :]: v
for k, v in next_pp_outputs.tensors.items()
if k.startswith("logits_output.")
}
if len(logits_output_args) > 0:
logits_output = LogitsProcessorOutput(**logits_output_args)
else:
logits_output = None
output_result = GenerationBatchResult.from_pp_proxy(
logits_output=logits_output,
next_pp_outputs=next_pp_outputs,
can_run_cuda_graph=result.can_run_cuda_graph,
)
self.process_batch_result(mbs[next_mb_id], output_result)
last_mbs[next_mb_id] = mbs[next_mb_id]
# (not last rank)
if not self.pp_group.is_last_rank:
# carry the outputs to the next stage
# send the outputs from the last round to let the next stage worker run post processing
if pp_outputs:
self.pp_group.send_tensor_dict(
pp_outputs.tensors,
all_gather_group=self.attn_tp_group,
)
# send out reqs to the next stage
# dp_offset maps this attention-DP rank into the global device-rank
# numbering used by the world group's point-to-point channel.
dp_offset = self.attn_dp_rank * self.attn_tp_size
if self.attn_tp_rank == 0:
point_to_point_pyobj(
recv_reqs,
self.pp_rank * self.tp_size + dp_offset,
self.world_group.device_group,
self.pp_rank * self.tp_size + dp_offset,
(self.pp_rank + 1) * self.tp_size + dp_offset,
)
# send out proxy tensors to the next stage
if self.cur_batch:
# FIXME(lsyin): remove this assert
assert result.pp_hidden_states_proxy_tensors.tensors is not None
self.pp_group.send_tensor_dict(
result.pp_hidden_states_proxy_tensors.tensors,
all_gather_group=self.attn_tp_group,
)
# Hand this round's received outputs to the next loop iteration.
pp_outputs = next_pp_outputs
# When the server is idle, self-check and re-init some states
if server_is_idle:
self.self_check_during_idle()
def recv_requests(self) -> List[Req]:
"""Receive results at tp_rank = 0 and broadcast it to all other TP ranks."""