Unify forward output datastructure (#11124)

This commit is contained in:
Liangsheng Yin
2025-10-03 00:28:57 +08:00
committed by GitHub
parent 3511b37099
commit 458611de77
12 changed files with 180 additions and 135 deletions

View File

@@ -150,7 +150,11 @@ from sglang.srt.mem_cache.chunk_cache import ChunkCache, SWAChunkCache
from sglang.srt.mem_cache.hiradix_cache import HiRadixCache
from sglang.srt.mem_cache.radix_cache import RadixCache
from sglang.srt.mem_cache.swa_radix_cache import SWARadixCache
from sglang.srt.model_executor.forward_batch_info import ForwardMode, PPProxyTensors
from sglang.srt.model_executor.forward_batch_info import (
ForwardBatchOutput,
ForwardMode,
PPProxyTensors,
)
from sglang.srt.parser.reasoning_parser import ReasoningParser
from sglang.srt.server_args import PortArgs, ServerArgs
from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
@@ -175,7 +179,6 @@ from sglang.srt.utils import (
get_bool_env_var,
get_int_env_var,
get_zmq_socket,
is_cpu,
kill_itself_when_parent_died,
numa_bind_to_node,
point_to_point_pyobj,
@@ -194,24 +197,59 @@ logger = logging.getLogger(__name__)
TEST_RETRACT = get_bool_env_var("SGLANG_TEST_RETRACT")
GRAMMAR_TIMEOUT = float(os.environ.get("SGLANG_GRAMMAR_TIMEOUT", 300))
_is_cpu = is_cpu()
@dataclass
class GenerationBatchResult:
logits_output: Optional[LogitsProcessorOutput]
pp_hidden_states_proxy_tensors: Optional[torch.Tensor]
pp_hidden_states_proxy_tensors: Optional[PPProxyTensors]
next_token_ids: Optional[List[int]]
can_run_cuda_graph: bool
# For output processing
extend_input_len_per_req: List[int]
extend_logprob_start_len_per_req: List[int]
bid: int
can_run_cuda_graph: bool
@classmethod
def from_forward_batch_output(
cls,
forward_batch_output: ForwardBatchOutput,
extend_input_len_per_req: List[int],
extend_logprob_start_len_per_req: List[int],
):
# TODO(lsyin): remove this workaround logic and try to unify output classes
return cls(
logits_output=forward_batch_output.logits_output,
pp_hidden_states_proxy_tensors=forward_batch_output.pp_proxy_tensors,
next_token_ids=forward_batch_output.next_token_ids,
extend_input_len_per_req=extend_input_len_per_req,
extend_logprob_start_len_per_req=extend_logprob_start_len_per_req,
can_run_cuda_graph=forward_batch_output.can_run_cuda_graph,
)
@classmethod
def from_pp_proxy(
cls, logits_output, next_pp_outputs: PPProxyTensors, can_run_cuda_graph
):
# TODO(lsyin): also simplify this logic
# Current PP implementation in scheduler is not compatible with ForwardBatchOutput
# Maybe introduce a ProxyBatchOutput for PP and the original ForwardBatchOutput for TP
proxy_dict = next_pp_outputs.tensors
return cls(
logits_output=logits_output,
pp_hidden_states_proxy_tensors=None,
next_token_ids=next_pp_outputs["next_token_ids"],
extend_input_len_per_req=proxy_dict.get("extend_input_len_per_req", None),
extend_logprob_start_len_per_req=proxy_dict.get(
"extend_logprob_start_len_per_req", None
),
can_run_cuda_graph=can_run_cuda_graph,
)
@dataclass
class EmbeddingBatchResult:
embeddings: torch.Tensor
bid: int
class Scheduler(
@@ -403,6 +441,12 @@ class Scheduler(
else:
self.draft_worker = None
# Dispatch the model worker
if self.spec_algorithm.is_none():
self.model_worker = self.tp_worker
else:
self.model_worker = self.draft_worker
# Get token and memory info from the model worker
(
self.max_total_num_tokens,
@@ -959,7 +1003,6 @@ class Scheduler(
self.running_mbs = [
ScheduleBatch(reqs=[], batch_is_full=False) for _ in range(self.pp_size)
]
bids = [None] * self.pp_size
pp_outputs: Optional[PPProxyTensors] = None
while True:
server_is_idle = True
@@ -980,10 +1023,7 @@ class Scheduler(
# (last rank) send the outputs to the next step
if self.pp_group.is_last_rank:
if self.cur_batch:
next_token_ids, bids[mb_id] = (
result.next_token_ids,
result.bid,
)
next_token_ids = result.next_token_ids
if self.cur_batch.return_logprob:
pp_outputs = PPProxyTensors(
{
@@ -1031,17 +1071,10 @@ class Scheduler(
logits_output = LogitsProcessorOutput(**logits_output_args)
else:
logits_output = None
output_result = GenerationBatchResult(
output_result = GenerationBatchResult.from_pp_proxy(
logits_output=logits_output,
pp_hidden_states_proxy_tensors=None,
next_token_ids=next_pp_outputs["next_token_ids"],
extend_input_len_per_req=next_pp_outputs.tensors.get(
"extend_input_len_per_req", None
),
extend_logprob_start_len_per_req=next_pp_outputs.tensors.get(
"extend_logprob_start_len_per_req", None
),
bid=bids[next_mb_id],
next_pp_outputs=next_pp_outputs,
can_run_cuda_graph=result.can_run_cuda_graph,
)
self.process_batch_result(mbs[next_mb_id], output_result)
@@ -1049,8 +1082,6 @@ class Scheduler(
# (not last rank)
if not self.pp_group.is_last_rank:
if self.cur_batch:
bids[mb_id] = result.bid
# carry the outputs to the next stage
# send the outputs from the last round to let the next stage worker run post processing
if pp_outputs:
@@ -1072,8 +1103,10 @@ class Scheduler(
# send out proxy tensors to the next stage
if self.cur_batch:
# FIXME(lsyin): remove this assert
assert result.pp_hidden_states_proxy_tensors.tensors is not None
self.pp_group.send_tensor_dict(
result.pp_hidden_states_proxy_tensors,
result.pp_hidden_states_proxy_tensors.tensors,
all_gather_group=self.attn_tp_group,
)
@@ -2016,33 +2049,25 @@ class Scheduler(
# Run forward
if self.is_generation:
batch_or_worker_batch = batch
if self.spec_algorithm.is_none():
model_worker_batch = batch.get_model_worker_batch()
# FIXME(lsyin): remove this if and finally unify the abstraction
batch_or_worker_batch = batch.get_model_worker_batch()
if self.pp_group.is_last_rank:
logits_output, next_token_ids, can_run_cuda_graph = (
self.tp_worker.forward_batch_generation(model_worker_batch)
)
else:
pp_hidden_states_proxy_tensors, _, can_run_cuda_graph = (
self.tp_worker.forward_batch_generation(model_worker_batch)
)
bid = model_worker_batch.bid
else:
(
logits_output,
next_token_ids,
bid,
num_accepted_tokens,
can_run_cuda_graph,
) = self.draft_worker.forward_batch_speculative_generation(batch)
bs = batch.batch_size()
self.spec_num_total_accepted_tokens += num_accepted_tokens + bs
self.spec_num_total_forward_ct += bs
self.num_generated_tokens += num_accepted_tokens
forward_batch_output = self.model_worker.forward_batch_generation(
batch_or_worker_batch
)
if self.pp_group.is_last_rank:
batch.output_ids = next_token_ids
if not self.spec_algorithm.is_none():
# TODO(lsyin): unify this metric-updating logic with non-spec, and move it to decode processing
self.udpate_spec_metrics(
batch.batch_size(), forward_batch_output.num_accepted_tokens
)
# update batch's output ids
batch.output_ids = forward_batch_output.next_token_ids
# These 2 values are needed for processing the output, but the values can be
# modified by overlap schedule. So we have to copy them here so that
@@ -2051,6 +2076,7 @@ class Scheduler(
extend_input_len_per_req = [req.extend_input_len for req in batch.reqs]
else:
extend_input_len_per_req = None
if batch.return_logprob:
extend_logprob_start_len_per_req = [
req.extend_logprob_start_len for req in batch.reqs
@@ -2058,25 +2084,15 @@ class Scheduler(
else:
extend_logprob_start_len_per_req = None
ret = GenerationBatchResult(
logits_output=logits_output if self.pp_group.is_last_rank else None,
pp_hidden_states_proxy_tensors=(
pp_hidden_states_proxy_tensors
if not self.pp_group.is_last_rank
else None
),
next_token_ids=next_token_ids if self.pp_group.is_last_rank else None,
return GenerationBatchResult.from_forward_batch_output(
forward_batch_output=forward_batch_output,
extend_input_len_per_req=extend_input_len_per_req,
extend_logprob_start_len_per_req=extend_logprob_start_len_per_req,
bid=bid,
can_run_cuda_graph=can_run_cuda_graph,
)
else: # embedding or reward model
model_worker_batch = batch.get_model_worker_batch()
embeddings = self.tp_worker.forward_batch_embedding(model_worker_batch)
ret = EmbeddingBatchResult(
embeddings=embeddings, bid=model_worker_batch.bid
)
ret = EmbeddingBatchResult(embeddings=embeddings)
return ret
def process_batch_result(