From 458611de77d4db924f5d7d10d42e0194e387d509 Mon Sep 17 00:00:00 2001
From: Liangsheng Yin
Date: Fri, 3 Oct 2025 00:28:57 +0800
Subject: [PATCH] Unify forward output datastructure (#11124)

---
 python/sglang/srt/configs/model_config.py     |   5 +-
 python/sglang/srt/disaggregation/prefill.py   |  13 +-
 python/sglang/srt/managers/schedule_batch.py  |   9 --
 python/sglang/srt/managers/scheduler.py       | 142 ++++++++++--------
 .../srt/managers/scheduler_metrics_mixin.py   |   5 +
 .../scheduler_output_processor_mixin.py       |   3 +-
 python/sglang/srt/managers/tp_worker.py       |  21 ++-
 .../srt/managers/tp_worker_overlap_thread.py  |  24 ++-
 python/sglang/srt/managers/utils.py           |   3 +-
 .../srt/model_executor/forward_batch_info.py  |  11 ++
 python/sglang/srt/speculative/eagle_worker.py |  46 +++---
 python/sglang/srt/speculative/ngram_worker.py |  33 ++--
 12 files changed, 180 insertions(+), 135 deletions(-)

diff --git a/python/sglang/srt/configs/model_config.py b/python/sglang/srt/configs/model_config.py
index 9132fb428..f03573aac 100644
--- a/python/sglang/srt/configs/model_config.py
+++ b/python/sglang/srt/configs/model_config.py
@@ -22,6 +22,7 @@ from typing import List, Optional, Set, Union
 import torch
 from transformers import PretrainedConfig
 
+from sglang.srt.environ import envs
 from sglang.srt.hf_transformers_utils import (
     get_config,
     get_context_length,
@@ -31,7 +32,7 @@ from sglang.srt.hf_transformers_utils import (
 )
 from sglang.srt.layers.quantization import QUANTIZATION_METHODS
 from sglang.srt.server_args import ServerArgs
-from sglang.srt.utils import get_bool_env_var, is_hip, retry
+from sglang.srt.utils import is_hip, retry
 from sglang.utils import is_in_ci
 
 logger = logging.getLogger(__name__)
@@ -237,7 +238,7 @@ class ModelConfig:
                 f"This may lead to incorrect model outputs or CUDA errors. Note that the derived context_length may differ from max_position_embeddings in the model's config."
             )
            if (
-                get_bool_env_var("SGLANG_ALLOW_OVERWRITE_LONGER_CONTEXT_LEN")
+                envs.SGLANG_ALLOW_OVERWRITE_LONGER_CONTEXT_LEN.get()
                 or is_in_ci()  # FIXME: fix this special case
            ):
                logger.warning(msg)
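
The model_config.py hunk above swaps the untyped get_bool_env_var lookup for the typed accessor envs.SGLANG_ALLOW_OVERWRITE_LONGER_CONTEXT_LEN.get(). The sglang.srt.environ registry itself is not part of this patch, so the snippet below is only a rough sketch of how such a typed boolean flag can be modeled; EnvBool and the constant are illustrative names, not the real implementation.

```python
import os
from dataclasses import dataclass


@dataclass(frozen=True)
class EnvBool:
    """Typed boolean environment flag with a default (illustrative sketch)."""

    name: str
    default: bool = False

    def get(self) -> bool:
        raw = os.environ.get(self.name)
        if raw is None:
            return self.default
        return raw.strip().lower() in ("1", "true", "yes", "on")


# Hypothetical registry entry mirroring the call site in the hunk above.
SGLANG_ALLOW_OVERWRITE_LONGER_CONTEXT_LEN = EnvBool(
    "SGLANG_ALLOW_OVERWRITE_LONGER_CONTEXT_LEN", False
)

if SGLANG_ALLOW_OVERWRITE_LONGER_CONTEXT_LEN.get():
    print("allowing context_length to exceed the derived limit")
```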
diff --git a/python/sglang/srt/disaggregation/prefill.py b/python/sglang/srt/disaggregation/prefill.py
index f31c5eeea..3393a32fb 100644
--- a/python/sglang/srt/disaggregation/prefill.py
+++ b/python/sglang/srt/disaggregation/prefill.py
@@ -689,7 +689,6 @@ class SchedulerDisaggregationPrefillMixin:
         self.running_mbs = [
             ScheduleBatch(reqs=[], batch_is_full=False) for _ in range(self.pp_size)
         ]
-        bids = [None] * self.pp_size
         pp_outputs: Optional[PPProxyTensors] = None
 
         # Either success or failed
@@ -761,10 +760,7 @@ class SchedulerDisaggregationPrefillMixin:
                 # send the outputs to the next step
                 if self.pp_group.is_last_rank:
                     if self.cur_batch:
-                        next_token_ids, bids[mb_id] = (
-                            result.next_token_ids,
-                            result.bid,
-                        )
+                        next_token_ids = result.next_token_ids
                         pp_outputs = PPProxyTensors(
                             {
                                 "next_token_ids": next_token_ids,
@@ -801,7 +797,6 @@ class SchedulerDisaggregationPrefillMixin:
                         next_token_ids=next_pp_outputs["next_token_ids"],
                         extend_input_len_per_req=None,
                         extend_logprob_start_len_per_req=None,
-                        bid=bids[next_mb_id],
                         can_run_cuda_graph=result.can_run_cuda_graph,
                     )
                     self.process_batch_result_disagg_prefill(
@@ -818,8 +813,6 @@ class SchedulerDisaggregationPrefillMixin:
 
                 # carry the outputs to the next stage
                 if not self.pp_group.is_last_rank:
-                    if self.cur_batch:
-                        bids[mb_id] = result.bid
                     if pp_outputs:
                         # send the outputs from the last round to let the next stage worker run post processing
                         self.pp_group.send_tensor_dict(
@@ -838,8 +831,10 @@ class SchedulerDisaggregationPrefillMixin:
 
                 # send out proxy tensors to the next stage
                 if self.cur_batch:
+                    # FIXME(lsyin): remove this assert
+                    assert result.pp_hidden_states_proxy_tensors.tensors is not None
                     self.pp_group.send_tensor_dict(
-                        result.pp_hidden_states_proxy_tensors,
+                        result.pp_hidden_states_proxy_tensors.tensors,
                         all_gather_group=self.attn_tp_group,
                     )
 
diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py
index 5d55abe0a..31b696f9f 100644
--- a/python/sglang/srt/managers/schedule_batch.py
+++ b/python/sglang/srt/managers/schedule_batch.py
@@ -860,10 +860,6 @@ class Req:
         )
 
 
-# Batch id
-bid = 0
-
-
 @dataclasses.dataclass
 class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
     """Store all information of a batch on the scheduler."""
@@ -1829,10 +1825,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
             seq_lens_cpu_cache if seq_lens_cpu_cache is not None else self.seq_lens_cpu
         )
 
-        global bid
-        bid += 1
         return ModelWorkerBatch(
-            bid=bid,
             forward_mode=self.forward_mode,
             input_ids=self.input_ids,
             req_pool_indices=self.req_pool_indices,
@@ -1952,8 +1945,6 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
 
 @dataclasses.dataclass
 class ModelWorkerBatch:
-    # The batch id
-    bid: int
     # The forward mode
     forward_mode: ForwardMode
     # The input ids
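
In the prefill loop above, result.pp_hidden_states_proxy_tensors is now a PPProxyTensors wrapper rather than a bare tensor dict, which is why the send path unwraps it with .tensors before calling send_tensor_dict. A minimal stand-in showing the shape of that interaction; ProxyTensors and send_tensor_dict below are placeholders, not sglang's real classes.

```python
from dataclasses import dataclass
from typing import Dict

import torch


@dataclass
class ProxyTensors:
    """Minimal stand-in for PPProxyTensors: a named bundle of tensors."""

    tensors: Dict[str, torch.Tensor]

    def __getitem__(self, key: str) -> torch.Tensor:
        return self.tensors[key]


def send_tensor_dict(tensor_dict: Dict[str, torch.Tensor]) -> None:
    # Stand-in for pp_group.send_tensor_dict, which expects a plain dict.
    for name, tensor in tensor_dict.items():
        print(name, tuple(tensor.shape))


proxy = ProxyTensors({"hidden_states": torch.zeros(4, 8), "residual": torch.zeros(4, 8)})
# The scheduler now unwraps the dataclass before handing it to the transport layer:
send_tensor_dict(proxy.tensors)
```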
diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py
index c71e937f7..8fb74093c 100644
--- a/python/sglang/srt/managers/scheduler.py
+++ b/python/sglang/srt/managers/scheduler.py
@@ -150,7 +150,11 @@ from sglang.srt.mem_cache.chunk_cache import ChunkCache, SWAChunkCache
 from sglang.srt.mem_cache.hiradix_cache import HiRadixCache
 from sglang.srt.mem_cache.radix_cache import RadixCache
 from sglang.srt.mem_cache.swa_radix_cache import SWARadixCache
-from sglang.srt.model_executor.forward_batch_info import ForwardMode, PPProxyTensors
+from sglang.srt.model_executor.forward_batch_info import (
+    ForwardBatchOutput,
+    ForwardMode,
+    PPProxyTensors,
+)
 from sglang.srt.parser.reasoning_parser import ReasoningParser
 from sglang.srt.server_args import PortArgs, ServerArgs
 from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
@@ -175,7 +179,6 @@ from sglang.srt.utils import (
     get_bool_env_var,
     get_int_env_var,
     get_zmq_socket,
-    is_cpu,
     kill_itself_when_parent_died,
     numa_bind_to_node,
     point_to_point_pyobj,
@@ -194,24 +197,59 @@ logger = logging.getLogger(__name__)
 TEST_RETRACT = get_bool_env_var("SGLANG_TEST_RETRACT")
 GRAMMAR_TIMEOUT = float(os.environ.get("SGLANG_GRAMMAR_TIMEOUT", 300))
 
-_is_cpu = is_cpu()
-
 
 @dataclass
 class GenerationBatchResult:
     logits_output: Optional[LogitsProcessorOutput]
-    pp_hidden_states_proxy_tensors: Optional[torch.Tensor]
+    pp_hidden_states_proxy_tensors: Optional[PPProxyTensors]
     next_token_ids: Optional[List[int]]
+    can_run_cuda_graph: bool
+
+    # For output processing
     extend_input_len_per_req: List[int]
     extend_logprob_start_len_per_req: List[int]
-    bid: int
-    can_run_cuda_graph: bool
+
+    @classmethod
+    def from_forward_batch_output(
+        cls,
+        forward_batch_output: ForwardBatchOutput,
+        extend_input_len_per_req: List[int],
+        extend_logprob_start_len_per_req: List[int],
+    ):
+        # TODO(lsyin): remove this workaround logic and try to unify output classes
+
+        return cls(
+            logits_output=forward_batch_output.logits_output,
+            pp_hidden_states_proxy_tensors=forward_batch_output.pp_proxy_tensors,
+            next_token_ids=forward_batch_output.next_token_ids,
+            extend_input_len_per_req=extend_input_len_per_req,
+            extend_logprob_start_len_per_req=extend_logprob_start_len_per_req,
+            can_run_cuda_graph=forward_batch_output.can_run_cuda_graph,
+        )
+
+    @classmethod
+    def from_pp_proxy(
+        cls, logits_output, next_pp_outputs: PPProxyTensors, can_run_cuda_graph
+    ):
+        # TODO(lsyin): also simplify this logic
+        # Current PP implementation in scheduler is not compatible with ForwardBatchOutput
+        # Maybe introduce a ProxyBatchOutput for PP and the original ForwardBatchOutput for TP
+        proxy_dict = next_pp_outputs.tensors
+        return cls(
+            logits_output=logits_output,
+            pp_hidden_states_proxy_tensors=None,
+            next_token_ids=next_pp_outputs["next_token_ids"],
+            extend_input_len_per_req=proxy_dict.get("extend_input_len_per_req", None),
+            extend_logprob_start_len_per_req=proxy_dict.get(
+                "extend_logprob_start_len_per_req", None
+            ),
+            can_run_cuda_graph=can_run_cuda_graph,
+        )
 
 
 @dataclass
 class EmbeddingBatchResult:
     embeddings: torch.Tensor
-    bid: int
 
 
 class Scheduler(
@@ -403,6 +441,12 @@ class Scheduler(
         else:
             self.draft_worker = None
 
+        # Dispatch the model worker
+        if self.spec_algorithm.is_none():
+            self.model_worker = self.tp_worker
+        else:
+            self.model_worker = self.draft_worker
+
         # Get token and memory info from the model worker
         (
             self.max_total_num_tokens,
@@ -959,7 +1003,6 @@ class Scheduler(
         self.running_mbs = [
             ScheduleBatch(reqs=[], batch_is_full=False) for _ in range(self.pp_size)
         ]
-        bids = [None] * self.pp_size
        pp_outputs: Optional[PPProxyTensors] = None
        while True:
            server_is_idle = True
@@ -980,10 +1023,7 @@ class Scheduler(
                # (last rank) send the outputs to the next step
                if self.pp_group.is_last_rank:
                    if self.cur_batch:
-                        next_token_ids, bids[mb_id] = (
-                            result.next_token_ids,
-                            result.bid,
-                        )
+                        next_token_ids = result.next_token_ids
                        if self.cur_batch.return_logprob:
                            pp_outputs = PPProxyTensors(
                                {
@@ -1031,17 +1071,10 @@ class Scheduler(
                        logits_output = LogitsProcessorOutput(**logits_output_args)
                    else:
                        logits_output = None
-                    output_result = GenerationBatchResult(
+
+                    output_result = GenerationBatchResult.from_pp_proxy(
                        logits_output=logits_output,
-                        pp_hidden_states_proxy_tensors=None,
-                        next_token_ids=next_pp_outputs["next_token_ids"],
-                        extend_input_len_per_req=next_pp_outputs.tensors.get(
-                            "extend_input_len_per_req", None
-                        ),
-                        extend_logprob_start_len_per_req=next_pp_outputs.tensors.get(
-                            "extend_logprob_start_len_per_req", None
-                        ),
-                        bid=bids[next_mb_id],
+                        next_pp_outputs=next_pp_outputs,
                        can_run_cuda_graph=result.can_run_cuda_graph,
                    )
                    self.process_batch_result(mbs[next_mb_id], output_result)
@@ -1049,8 +1082,6 @@ class Scheduler(
 
                # (not last rank)
                if not self.pp_group.is_last_rank:
-                    if self.cur_batch:
-                        bids[mb_id] = result.bid
                    # carry the outputs to the next stage
                    # send the outputs from the last round to let the next stage worker run post processing
                    if pp_outputs:
@@ -1072,8 +1103,10 @@ class Scheduler(
 
                # send out proxy tensors to the next stage
                if self.cur_batch:
+                    # FIXME(lsyin): remove this assert
+                    assert result.pp_hidden_states_proxy_tensors.tensors is not None
                    self.pp_group.send_tensor_dict(
-                        result.pp_hidden_states_proxy_tensors,
+                        result.pp_hidden_states_proxy_tensors.tensors,
                        all_gather_group=self.attn_tp_group,
                    )
@@ -2016,33 +2049,25 @@ class Scheduler(
 
         # Run forward
         if self.is_generation:
+
+            batch_or_worker_batch = batch
+
             if self.spec_algorithm.is_none():
-                model_worker_batch = batch.get_model_worker_batch()
+                # FIXME(lsyin): remove this if and finally unify the abstraction
+                batch_or_worker_batch = batch.get_model_worker_batch()
 
-                if self.pp_group.is_last_rank:
-                    logits_output, next_token_ids, can_run_cuda_graph = (
-                        self.tp_worker.forward_batch_generation(model_worker_batch)
-                    )
-                else:
-                    pp_hidden_states_proxy_tensors, _, can_run_cuda_graph = (
-                        self.tp_worker.forward_batch_generation(model_worker_batch)
-                    )
-                bid = model_worker_batch.bid
-            else:
-                (
-                    logits_output,
-                    next_token_ids,
-                    bid,
-                    num_accepted_tokens,
-                    can_run_cuda_graph,
-                ) = self.draft_worker.forward_batch_speculative_generation(batch)
-                bs = batch.batch_size()
-                self.spec_num_total_accepted_tokens += num_accepted_tokens + bs
-                self.spec_num_total_forward_ct += bs
-                self.num_generated_tokens += num_accepted_tokens
+            forward_batch_output = self.model_worker.forward_batch_generation(
+                batch_or_worker_batch
+            )
 
-            if self.pp_group.is_last_rank:
-                batch.output_ids = next_token_ids
+            if not self.spec_algorithm.is_none():
+                # TODO(lsyin): unify this metric-updating logic with non-spec, and move it to decode processing
+                self.update_spec_metrics(
+                    batch.batch_size(), forward_batch_output.num_accepted_tokens
+                )
+
+            # update batch's output ids
+            batch.output_ids = forward_batch_output.next_token_ids
 
             # These 2 values are needed for processing the output, but the values can be
             # modified by overlap schedule. So we have to copy them here so that
@@ -2051,6 +2076,7 @@ class Scheduler(
                 extend_input_len_per_req = [req.extend_input_len for req in batch.reqs]
             else:
                 extend_input_len_per_req = None
+
             if batch.return_logprob:
                 extend_logprob_start_len_per_req = [
                     req.extend_logprob_start_len for req in batch.reqs
@@ -2058,25 +2084,15 @@ class Scheduler(
                 ]
             else:
                 extend_logprob_start_len_per_req = None
-            ret = GenerationBatchResult(
-                logits_output=logits_output if self.pp_group.is_last_rank else None,
-                pp_hidden_states_proxy_tensors=(
-                    pp_hidden_states_proxy_tensors
-                    if not self.pp_group.is_last_rank
-                    else None
-                ),
-                next_token_ids=next_token_ids if self.pp_group.is_last_rank else None,
+            return GenerationBatchResult.from_forward_batch_output(
+                forward_batch_output=forward_batch_output,
                 extend_input_len_per_req=extend_input_len_per_req,
                 extend_logprob_start_len_per_req=extend_logprob_start_len_per_req,
-                bid=bid,
-                can_run_cuda_graph=can_run_cuda_graph,
             )
         else:  # embedding or reward model
             model_worker_batch = batch.get_model_worker_batch()
             embeddings = self.tp_worker.forward_batch_embedding(model_worker_batch)
-            ret = EmbeddingBatchResult(
-                embeddings=embeddings, bid=model_worker_batch.bid
-            )
+            ret = EmbeddingBatchResult(embeddings=embeddings)
         return ret
 
     def process_batch_result(
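
On the pipeline-parallel path above, the last stage packs next_token_ids (plus optional logprob metadata) into proxy tensors, and GenerationBatchResult.from_pp_proxy later reads them back with dict-style .get lookups so absent fields simply become None. A dict-level sketch of that unpacking, using SimpleNamespace as a stand-in for the received PPProxyTensors:

```python
from types import SimpleNamespace

# Stand-in for the PPProxyTensors bundle received from the previous stage.
next_pp_outputs = SimpleNamespace(
    tensors={
        "next_token_ids": [17, 23],
        "extend_input_len_per_req": [6, 4],
        # "extend_logprob_start_len_per_req" intentionally absent
    }
)

proxy_dict = next_pp_outputs.tensors
result_kwargs = dict(
    next_token_ids=proxy_dict["next_token_ids"],
    extend_input_len_per_req=proxy_dict.get("extend_input_len_per_req", None),
    extend_logprob_start_len_per_req=proxy_dict.get(
        "extend_logprob_start_len_per_req", None
    ),
)
print(result_kwargs)  # the missing key falls back to None instead of raising
```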
diff --git a/python/sglang/srt/managers/scheduler_metrics_mixin.py b/python/sglang/srt/managers/scheduler_metrics_mixin.py
index 5966925df..7e1154dc2 100644
--- a/python/sglang/srt/managers/scheduler_metrics_mixin.py
+++ b/python/sglang/srt/managers/scheduler_metrics_mixin.py
@@ -80,6 +80,11 @@ class SchedulerMetricsMixin:
                 kv_events_config, self.attn_dp_rank
             )
 
+    def update_spec_metrics(self, bs: int, num_accepted_tokens: int):
+        self.spec_num_total_accepted_tokens += num_accepted_tokens + bs
+        self.spec_num_total_forward_ct += bs
+        self.num_generated_tokens += num_accepted_tokens
+
     def log_prefill_stats(
         self: Scheduler,
         adder: PrefillAdder,
diff --git a/python/sglang/srt/managers/scheduler_output_processor_mixin.py b/python/sglang/srt/managers/scheduler_output_processor_mixin.py
index 316c29dd6..537dedc95 100644
--- a/python/sglang/srt/managers/scheduler_output_processor_mixin.py
+++ b/python/sglang/srt/managers/scheduler_output_processor_mixin.py
@@ -173,8 +173,7 @@ class SchedulerOutputProcessorMixin:
             self.set_next_batch_sampling_info_done(batch)
 
         else:  # embedding or reward model
-            embeddings, bid = result.embeddings, result.bid
-            embeddings = embeddings.tolist()
+            embeddings = result.embeddings.tolist()
 
             # Check finish conditions
             for i, req in enumerate(batch.reqs):
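
update_spec_metrics keeps the accounting the scheduler previously did inline: a speculative step over bs requests always emits one token per request plus num_accepted_tokens extra accepted draft tokens, and the running counters divide out to an average accept length. A worked example with the numbers spelled out (this is my reading of the counters, not an official metric definition):

```python
# One speculative decode step over a batch of 4 requests where the verifier
# accepted 6 extra draft tokens in total.
bs = 4
num_accepted_tokens = 6

spec_num_total_accepted_tokens = num_accepted_tokens + bs  # 10 tokens emitted this step
spec_num_total_forward_ct = bs                             # 4 per-request forwards
num_generated_tokens = num_accepted_tokens                 # mirrors the counter update above

avg_accept_length = spec_num_total_accepted_tokens / spec_num_total_forward_ct
print(avg_accept_length)  # 2.5 tokens emitted per request per step
```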
diff --git a/python/sglang/srt/managers/tp_worker.py b/python/sglang/srt/managers/tp_worker.py
index 0d3f76658..a4e087e5b 100644
--- a/python/sglang/srt/managers/tp_worker.py
+++ b/python/sglang/srt/managers/tp_worker.py
@@ -43,7 +43,11 @@ from sglang.srt.managers.io_struct import (
 from sglang.srt.managers.schedule_batch import ModelWorkerBatch, global_server_args_dict
 from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator
 from sglang.srt.mem_cache.memory_pool import ReqToTokenPool
-from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
+from sglang.srt.model_executor.forward_batch_info import (
+    ForwardBatch,
+    ForwardBatchOutput,
+    PPProxyTensors,
+)
 from sglang.srt.model_executor.model_runner import ModelRunner
 from sglang.srt.patch_torch import monkey_patch_torch_reductions
 from sglang.srt.server_args import ServerArgs
@@ -234,9 +238,7 @@ class TpModelWorker:
         model_worker_batch: ModelWorkerBatch,
         launch_done: Optional[threading.Event] = None,
         skip_sample: bool = False,
-    ) -> Tuple[
-        Union[LogitsProcessorOutput, torch.Tensor], Optional[torch.Tensor], bool
-    ]:
+    ) -> ForwardBatchOutput:
         # update the consumer index of hicache to the running batch
         self.set_hicache_consumer(model_worker_batch.hicache_consumer_index)
 
@@ -271,13 +273,20 @@ class TpModelWorker:
             else:
                 next_token_ids = self.model_runner.sample(logits_output, forward_batch)
 
-            return logits_output, next_token_ids, can_run_cuda_graph
+            return ForwardBatchOutput(
+                logits_output=logits_output,
+                next_token_ids=next_token_ids,
+                can_run_cuda_graph=can_run_cuda_graph,
+            )
         else:
             pp_proxy_tensors, can_run_cuda_graph = self.model_runner.forward(
                 forward_batch,
                 pp_proxy_tensors=pp_proxy_tensors,
             )
-            return pp_proxy_tensors.tensors, None, can_run_cuda_graph
+            return ForwardBatchOutput(
+                pp_proxy_tensors=pp_proxy_tensors,
+                can_run_cuda_graph=can_run_cuda_graph,
+            )
 
     def forward_batch_embedding(self, model_worker_batch: ModelWorkerBatch):
         forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner)
diff --git a/python/sglang/srt/managers/tp_worker_overlap_thread.py b/python/sglang/srt/managers/tp_worker_overlap_thread.py
index 9ca68b0b8..1af05a434 100644
--- a/python/sglang/srt/managers/tp_worker_overlap_thread.py
+++ b/python/sglang/srt/managers/tp_worker_overlap_thread.py
@@ -39,6 +39,7 @@ from sglang.srt.managers.io_struct import (
 from sglang.srt.managers.overlap_utils import FutureMap
 from sglang.srt.managers.schedule_batch import ModelWorkerBatch
 from sglang.srt.managers.tp_worker import TpModelWorker
+from sglang.srt.model_executor.forward_batch_info import ForwardBatchOutput
 from sglang.srt.server_args import ServerArgs
 from sglang.srt.utils import DynamicGradMode
 from sglang.utils import get_exception_traceback
@@ -160,13 +161,17 @@ class TpModelWorkerClient:
             self.future_map.resolve_future(model_worker_batch)
 
             # Run forward
+            forward_batch_output = self.worker.forward_batch_generation(
+                model_worker_batch,
+                model_worker_batch.launch_done,
+                # Skip sampling for prefill-only requests
+                skip_sample=model_worker_batch.is_prefill_only,
+            )
+
             logits_output, next_token_ids, can_run_cuda_graph = (
-                self.worker.forward_batch_generation(
-                    model_worker_batch,
-                    model_worker_batch.launch_done,
-                    # Skip sampling for prefill-only requests
-                    skip_sample=model_worker_batch.is_prefill_only,
-                )
+                forward_batch_output.logits_output,
+                forward_batch_output.next_token_ids,
+                forward_batch_output.can_run_cuda_graph,
             )
 
             # Update the future token ids map
@@ -227,7 +232,7 @@ class TpModelWorkerClient:
 
     def forward_batch_generation(
         self, model_worker_batch: ModelWorkerBatch
-    ) -> Tuple[None, torch.Tensor, bool]:
+    ) -> ForwardBatchOutput:
         # Create a new copy of sampling_info because it will be updated in-place by the scheduler for the next batch.
         sampling_info = model_worker_batch.sampling_info
         sampling_info.update_penalties()
@@ -250,7 +255,10 @@ class TpModelWorkerClient:
         future_next_token_ids = self.future_map.update_next_future(
             cur_future_map_ct, bs
         )
-        return None, future_next_token_ids, False
+        return ForwardBatchOutput(
+            next_token_ids=future_next_token_ids,
+            can_run_cuda_graph=False,
+        )
 
     def update_weights_from_disk(self, recv_req: UpdateWeightFromDiskReqInput):
         success, message = self.worker.update_weights_from_disk(recv_req)
diff --git a/python/sglang/srt/managers/utils.py b/python/sglang/srt/managers/utils.py
index de83c4590..e25f2628f 100644
--- a/python/sglang/srt/managers/utils.py
+++ b/python/sglang/srt/managers/utils.py
@@ -2,11 +2,10 @@ from __future__ import annotations
 
 import logging
 import multiprocessing as mp
-from http import HTTPStatus
 from typing import TYPE_CHECKING, Dict, List, Optional
 
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
-from sglang.srt.managers.schedule_batch import FINISH_ABORT, Req
+from sglang.srt.managers.schedule_batch import Req
 from sglang.srt.model_executor.forward_batch_info import PPProxyTensors
 
 if TYPE_CHECKING:
diff --git a/python/sglang/srt/model_executor/forward_batch_info.py b/python/sglang/srt/model_executor/forward_batch_info.py
index 52e96016d..fce792a04 100644
--- a/python/sglang/srt/model_executor/forward_batch_info.py
+++ b/python/sglang/srt/model_executor/forward_batch_info.py
@@ -900,6 +900,17 @@ class ForwardBatch:
         return self.tbo_split_seq_index is not None
 
 
+@dataclass
+class ForwardBatchOutput:
+    # FIXME(lsyin): unify the forward batch output between different spec and parallelism
+    # need to be more organized
+    logits_output: Optional[torch.Tensor] = None
+    next_token_ids: Optional[torch.Tensor] = None
+    num_accepted_tokens: Optional[int] = None
+    pp_proxy_tensors: Optional[PPProxyTensors] = None
+    can_run_cuda_graph: bool = False
+
+
 def enable_num_token_non_padded(server_args):
     return get_moe_expert_parallel_world_size() > 1
 
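
ForwardBatchOutput above is the single container every forward path now returns; callers simply read whichever fields their path populates. A condensed, dependency-free copy (tensor fields replaced by plain values so the sketch runs anywhere) showing the three producer shapes that appear in this patch:

```python
from dataclasses import dataclass
from typing import Any, List, Optional


@dataclass
class ForwardBatchOutput:
    # Trimmed copy of the dataclass added above; tensor fields typed loosely here.
    logits_output: Optional[Any] = None
    next_token_ids: Optional[List[int]] = None
    num_accepted_tokens: Optional[int] = None
    pp_proxy_tensors: Optional[Any] = None
    can_run_cuda_graph: bool = False


# 1) Last pipeline rank, normal decode: logits plus sampled tokens.
decode_out = ForwardBatchOutput(
    logits_output="logits", next_token_ids=[5, 9], can_run_cuda_graph=True
)

# 2) Intermediate pipeline rank: only proxy tensors to hand to the next stage.
pp_out = ForwardBatchOutput(pp_proxy_tensors={"hidden_states": "..."})

# 3) Speculative worker: also reports how many draft tokens were accepted.
spec_out = ForwardBatchOutput(next_token_ids=[5, 9, 12], num_accepted_tokens=1)

for out in (decode_out, pp_out, spec_out):
    if out.pp_proxy_tensors is not None:
        print("send proxy tensors downstream")
    else:
        print("process sampled tokens:", out.next_token_ids)
```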
""" if batch.forward_mode.is_extend() or batch.is_extend_in_batch: - logits_output, next_token_ids, bid, seq_lens_cpu = ( - self.forward_target_extend(batch) + logits_output, next_token_ids, seq_lens_cpu = self.forward_target_extend( + batch ) with self.draft_tp_context(self.draft_model_runner.tp_group): self.forward_draft_extend( batch, logits_output.hidden_states, next_token_ids, seq_lens_cpu ) - return logits_output, next_token_ids, bid, 0, False + return ForwardBatchOutput( + logits_output=logits_output, + next_token_ids=next_token_ids, + num_accepted_tokens=0, + can_run_cuda_graph=False, + ) else: with self.draft_tp_context(self.draft_model_runner.tp_group): spec_info = self.draft(batch) @@ -462,12 +465,11 @@ class EAGLEWorker(TpModelWorker): # decode is not finished self.forward_draft_extend_after_decode(batch) - return ( - logits_output, - verify_output.verified_id, - model_worker_batch.bid, - sum(verify_output.accept_length_per_req_cpu), - can_run_cuda_graph, + return ForwardBatchOutput( + logits_output=logits_output, + next_token_ids=verify_output.verified_id, + num_accepted_tokens=sum(verify_output.accept_length_per_req_cpu), + can_run_cuda_graph=can_run_cuda_graph, ) def check_forward_draft_extend_after_decode(self, batch: ScheduleBatch): @@ -499,19 +501,21 @@ class EAGLEWorker(TpModelWorker): Returns: logits_output: The output of logits. It will contain the full hidden states. next_token_ids: Next token ids generated. - bid: The model batch ID. Used for overlap schedule. """ # Forward with the target model and get hidden states. # We need the full hidden states to prefill the KV cache of the draft model. model_worker_batch = batch.get_model_worker_batch() model_worker_batch.capture_hidden_mode = CaptureHiddenMode.FULL - logits_output, next_token_ids, _ = self.target_worker.forward_batch_generation( + forward_batch_output = self.target_worker.forward_batch_generation( model_worker_batch ) + logits_output, next_token_ids = ( + forward_batch_output.logits_output, + forward_batch_output.next_token_ids, + ) return ( logits_output, next_token_ids, - model_worker_batch.bid, model_worker_batch.seq_lens_cpu, ) @@ -811,10 +815,12 @@ class EAGLEWorker(TpModelWorker): ).cpu() # Forward - logits_output, _, can_run_cuda_graph = ( - self.target_worker.forward_batch_generation( - model_worker_batch, skip_sample=True - ) + forward_batch_output = self.target_worker.forward_batch_generation( + model_worker_batch, skip_sample=True + ) + logits_output, can_run_cuda_graph = ( + forward_batch_output.logits_output, + forward_batch_output.can_run_cuda_graph, ) vocab_mask = None diff --git a/python/sglang/srt/speculative/ngram_worker.py b/python/sglang/srt/speculative/ngram_worker.py index 69dc83b1f..473e040d2 100644 --- a/python/sglang/srt/speculative/ngram_worker.py +++ b/python/sglang/srt/speculative/ngram_worker.py @@ -7,7 +7,7 @@ from sgl_kernel.speculative import reconstruct_indices_from_tree_mask from sglang.srt.managers.schedule_batch import ScheduleBatch from sglang.srt.managers.tp_worker import TpModelWorker -from sglang.srt.model_executor.forward_batch_info import ForwardMode +from sglang.srt.model_executor.forward_batch_info import ForwardBatchOutput, ForwardMode from sglang.srt.server_args import ServerArgs from sglang.srt.speculative.cpp_ngram.ngram_cache import NgramCache from sglang.srt.speculative.ngram_utils import NgramVerifyInput @@ -207,17 +207,18 @@ class NGRAMWorker: batch_tokens.append(put_ids) self.ngram_cache.batch_put(batch_tokens) - def 
diff --git a/python/sglang/srt/speculative/ngram_worker.py b/python/sglang/srt/speculative/ngram_worker.py
index 69dc83b1f..473e040d2 100644
--- a/python/sglang/srt/speculative/ngram_worker.py
+++ b/python/sglang/srt/speculative/ngram_worker.py
@@ -7,7 +7,7 @@ from sgl_kernel.speculative import reconstruct_indices_from_tree_mask
 
 from sglang.srt.managers.schedule_batch import ScheduleBatch
 from sglang.srt.managers.tp_worker import TpModelWorker
-from sglang.srt.model_executor.forward_batch_info import ForwardMode
+from sglang.srt.model_executor.forward_batch_info import ForwardBatchOutput, ForwardMode
 from sglang.srt.server_args import ServerArgs
 from sglang.srt.speculative.cpp_ngram.ngram_cache import NgramCache
 from sglang.srt.speculative.ngram_utils import NgramVerifyInput
@@ -207,17 +207,18 @@ class NGRAMWorker:
                 batch_tokens.append(put_ids)
         self.ngram_cache.batch_put(batch_tokens)
 
-    def forward_batch_speculative_generation(self, batch: ScheduleBatch):
+    def forward_batch_generation(self, batch: ScheduleBatch) -> ForwardBatchOutput:
         self._prepare_for_speculative_decoding(batch)
         model_worker_batch = batch.get_model_worker_batch()
-        bid = model_worker_batch.bid
         num_accepted_tokens = 0
 
         if model_worker_batch.forward_mode.is_target_verify():
-            logits_output, _, can_run_cuda_graph = (
-                self.target_worker.forward_batch_generation(
-                    model_worker_batch, skip_sample=True
-                )
+            forward_batch_output = self.target_worker.forward_batch_generation(
+                model_worker_batch, skip_sample=True
+            )
+            logits_output, can_run_cuda_graph = (
+                forward_batch_output.logits_output,
+                forward_batch_output.can_run_cuda_graph,
             )
             verify_input = model_worker_batch.spec_info
             logits_output, next_token_ids, num_accepted_tokens = verify_input.verify(
@@ -227,14 +228,18 @@ class NGRAMWorker:
             batch.forward_mode = ForwardMode.DECODE
 
         else:
+            forward_batch_output = self.target_worker.forward_batch_generation(
+                model_worker_batch
+            )
             logits_output, next_token_ids, can_run_cuda_graph = (
-                self.target_worker.forward_batch_generation(model_worker_batch)
+                forward_batch_output.logits_output,
+                forward_batch_output.next_token_ids,
+                forward_batch_output.can_run_cuda_graph,
             )
 
-        return (
-            logits_output,
-            next_token_ids,
-            bid,
-            num_accepted_tokens,
-            can_run_cuda_graph,
+        return ForwardBatchOutput(
+            logits_output=logits_output,
+            next_token_ids=next_token_ids,
+            num_accepted_tokens=num_accepted_tokens,
+            can_run_cuda_graph=can_run_cuda_graph,
         )
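
Taken together, the patch lets the scheduler pick one worker object up front (self.model_worker is either the plain TP worker or the speculative draft worker) and call the same forward_batch_generation method on it regardless of configuration. A small duck-typed sketch of that dispatch; the two stub workers below are placeholders standing in for TpModelWorker and EAGLEWorker/NGRAMWorker:

```python
from types import SimpleNamespace


class PlainWorker:
    """Stand-in for TpModelWorker: no speculative decoding."""

    def forward_batch_generation(self, batch):
        return SimpleNamespace(next_token_ids=[7, 7], num_accepted_tokens=None)


class SpecWorker:
    """Stand-in for EAGLEWorker / NGRAMWorker: also reports accepted draft tokens."""

    def forward_batch_generation(self, batch):
        return SimpleNamespace(next_token_ids=[7, 8, 9], num_accepted_tokens=2)


def make_scheduler_worker(spec_enabled: bool):
    # Mirrors the new dispatch: model_worker = tp_worker if spec is disabled, else draft_worker.
    return SpecWorker() if spec_enabled else PlainWorker()


for spec_enabled in (False, True):
    worker = make_scheduler_worker(spec_enabled)
    out = worker.forward_batch_generation(batch=None)
    print(spec_enabled, out.next_token_ids, out.num_accepted_tokens)
```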